Compare commits
10 Commits
feat/neuro
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
60f5598542
|
|||
|
7945240646
|
|||
|
0c74d89d15
|
|||
|
c94a2ae755
|
|||
|
99920dd322
|
|||
|
c4f239ceb9
|
|||
|
ac445c1569
|
|||
|
abc6e605b8
|
|||
|
4f2957af9e
|
|||
|
75cd088b61
|
3
Cargo.lock
generated
3
Cargo.lock
generated
@@ -905,8 +905,7 @@ dependencies = [
|
|||||||
[[package]]
|
[[package]]
|
||||||
name = "cudarc"
|
name = "cudarc"
|
||||||
version = "0.19.7"
|
version = "0.19.7"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "git+https://github.com/grenade/cudarc?rev=63327a256059f8252641ae46c6bb9eefe707f382#63327a256059f8252641ae46c6bb9eefe707f382"
|
||||||
checksum = "1cea5f10a99e025c1b44ae2354c2d8326b25ddbd0baf76bde8e55cfd4018a2cc"
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"float8",
|
"float8",
|
||||||
"half",
|
"half",
|
||||||
|
|||||||
@@ -61,3 +61,12 @@ eventsource-stream = "0.2"
|
|||||||
# workspace crates
|
# workspace crates
|
||||||
cortex-core = { path = "crates/cortex-core" }
|
cortex-core = { path = "crates/cortex-core" }
|
||||||
cortex-gateway = { path = "crates/cortex-gateway" }
|
cortex-gateway = { path = "crates/cortex-gateway" }
|
||||||
|
|
||||||
|
# Patched cudarc (affects neuron's 0.19.x only; candle's 0.17.x is
|
||||||
|
# untouched since the fork is 0.19.7 and doesn't satisfy a 0.17 req). Adds
|
||||||
|
# Comm::abort / get_async_error / raw comm() — needed for #17 Stage 2 TP
|
||||||
|
# hang-recovery (abort a wedged collective from another thread, then
|
||||||
|
# rebuild the comm). Pinned to a fork revision pending upstream review
|
||||||
|
# (grenade/cudarc @ nccl-comm-abort).
|
||||||
|
[patch.crates-io]
|
||||||
|
cudarc = { git = "https://github.com/grenade/cudarc", rev = "63327a256059f8252641ae46c6bb9eefe707f382" }
|
||||||
|
|||||||
@@ -60,6 +60,17 @@ pub struct CandleHarness {
|
|||||||
/// can still load on CPU for tests, just without worker threads).
|
/// can still load on CPU for tests, just without worker threads).
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
device_workers: Arc<RwLock<HashMap<u32, Arc<super::device_worker::DeviceWorkerHandle>>>>,
|
device_workers: Arc<RwLock<HashMap<u32, Arc<super::device_worker::DeviceWorkerHandle>>>>,
|
||||||
|
/// Auto-recovery (#17): model ids whose poisoned context is being
|
||||||
|
/// rebuilt via unload+reload. Insert is the single-flight gate (one
|
||||||
|
/// recovery per model in flight); membership also lets the request
|
||||||
|
/// path answer "recovering, retry shortly" during the reload gap
|
||||||
|
/// rather than a bare "not loaded".
|
||||||
|
recovering: Arc<RwLock<std::collections::HashSet<String>>>,
|
||||||
|
/// Sender to the background recovery task. The request path enqueues
|
||||||
|
/// a poisoned model id here; the task (holding a `Weak<Self>`) runs
|
||||||
|
/// the unload→reload→health-gate. Unbounded + tiny (model ids), and
|
||||||
|
/// the `recovering` set dedupes, so it can't back up.
|
||||||
|
recovery_tx: tokio::sync::mpsc::UnboundedSender<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// One entry in the harness's loaded-model registry. Single-GPU loads
|
/// One entry in the harness's loaded-model registry. Single-GPU loads
|
||||||
@@ -86,6 +97,15 @@ impl LoadedHandle {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// The spec this model was loaded from (for auto-recovery #17).
|
||||||
|
pub fn spec(&self) -> &ModelSpec {
|
||||||
|
match self {
|
||||||
|
LoadedHandle::Single(m) => &m.spec,
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
LoadedHandle::Tp(m) => &m.spec,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn devices(&self) -> Vec<u32> {
|
pub fn devices(&self) -> Vec<u32> {
|
||||||
match self {
|
match self {
|
||||||
LoadedHandle::Single(m) => m.devices.clone(),
|
LoadedHandle::Single(m) => m.devices.clone(),
|
||||||
@@ -215,6 +235,10 @@ pub struct LoadedModel {
|
|||||||
/// `(h/factor) × (w/factor)` (#14 dynamic resolution). `None` for
|
/// `(h/factor) × (w/factor)` (#14 dynamic resolution). `None` for
|
||||||
/// text-only models. Set at load time.
|
/// text-only models. Set at load time.
|
||||||
pub image_grid_factor: Option<usize>,
|
pub image_grid_factor: Option<usize>,
|
||||||
|
/// The spec this model was loaded from — retained so auto-recovery
|
||||||
|
/// (#17) can `unload_model` + `load_model(spec)` a poisoned model
|
||||||
|
/// without an operator reconstructing it.
|
||||||
|
pub spec: ModelSpec,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LoadedModel {
|
impl LoadedModel {
|
||||||
@@ -289,6 +313,9 @@ pub struct TpLoadedModel {
|
|||||||
/// Pixel→LM-grid divisor — same as
|
/// Pixel→LM-grid divisor — same as
|
||||||
/// [`LoadedModel::image_grid_factor`].
|
/// [`LoadedModel::image_grid_factor`].
|
||||||
pub image_grid_factor: Option<usize>,
|
pub image_grid_factor: Option<usize>,
|
||||||
|
/// Loading spec, retained for auto-recovery (#17) — see
|
||||||
|
/// [`LoadedModel::spec`].
|
||||||
|
pub spec: ModelSpec,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
@@ -792,6 +819,46 @@ fn poisoned_error(model_id: &str) -> InferenceError {
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Reported while auto-recovery (#17) is rebuilding a poisoned model's
|
||||||
|
/// context. Unlike [`poisoned_error`] this is a *transient* state — the
|
||||||
|
/// model is being reloaded automatically; the client should retry.
|
||||||
|
fn recovering_error(model_id: &str) -> InferenceError {
|
||||||
|
InferenceError::Other(anyhow::anyhow!(
|
||||||
|
"model '{model_id}' is recovering (its device context was poisoned \
|
||||||
|
by an earlier failure and is being automatically rebuilt); retry \
|
||||||
|
shortly"
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Verification hook for #17 auto-recovery. When `NEURON_DEBUG_POISON`
|
||||||
|
/// names a model, the **first** request for it (process-wide) returns
|
||||||
|
/// true, so the request path can trigger recovery as if a device fault
|
||||||
|
/// had occurred — exercising the unload→reload→healthy cycle without
|
||||||
|
/// corrupting the GPU. One-shot (a `swap` latch) so it can't loop the
|
||||||
|
/// model through endless recoveries. No-op unless the env var is set.
|
||||||
|
fn debug_poison_armed(model_id: &str) -> bool {
|
||||||
|
static FIRED: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
|
||||||
|
let armed = std::env::var("NEURON_DEBUG_POISON").ok().as_deref() == Some(model_id);
|
||||||
|
armed && !FIRED.swap(true, Ordering::Relaxed)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Background auto-recovery task (#17). Drains poisoned model ids and
|
||||||
|
/// rebuilds each via [`CandleHarness::recover_one`]. Holds a `Weak` so a
|
||||||
|
/// shutting-down harness lets the task exit; processes one id at a time,
|
||||||
|
/// which (with the `recovering` set deduping enqueues) keeps recovery
|
||||||
|
/// single-flight per model.
|
||||||
|
async fn recovery_loop(
|
||||||
|
weak: std::sync::Weak<CandleHarness>,
|
||||||
|
mut rx: tokio::sync::mpsc::UnboundedReceiver<String>,
|
||||||
|
) {
|
||||||
|
while let Some(model_id) = rx.recv().await {
|
||||||
|
let Some(this) = weak.upgrade() else {
|
||||||
|
break;
|
||||||
|
};
|
||||||
|
this.recover_one(&model_id).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Free/total VRAM on the candle `Device` in MiB. Returns `(0, 0)` if
|
/// Free/total VRAM on the candle `Device` in MiB. Returns `(0, 0)` if
|
||||||
/// the query fails or the device is the CPU fallback so logging never
|
/// the query fails or the device is the CPU fallback so logging never
|
||||||
/// crashes the request path. Mirrors the existing helper in
|
/// crashes the request path. Mirrors the existing helper in
|
||||||
@@ -1146,7 +1213,7 @@ impl CandleHarness {
|
|||||||
/// Construct a new harness for `bind_url` using `config`. Resolves
|
/// Construct a new harness for `bind_url` using `config`. Resolves
|
||||||
/// every configured source's auth env var and cache dir up front so
|
/// every configured source's auth env var and cache dir up front so
|
||||||
/// the hot load path (`hf_api_for`) is a pure HashMap lookup.
|
/// the hot load path (`hf_api_for`) is a pure HashMap lookup.
|
||||||
pub fn new(bind_url: String, config: &crate::config::CandleHarnessConfig) -> Self {
|
pub fn new(bind_url: String, config: &crate::config::CandleHarnessConfig) -> Arc<Self> {
|
||||||
let raw_sources = config.effective_sources();
|
let raw_sources = config.effective_sources();
|
||||||
let default_source = config.effective_default_source().to_string();
|
let default_source = config.effective_default_source().to_string();
|
||||||
let mut sources = HashMap::with_capacity(raw_sources.len());
|
let mut sources = HashMap::with_capacity(raw_sources.len());
|
||||||
@@ -1196,13 +1263,25 @@ impl CandleHarness {
|
|||||||
bare model ids will fail to resolve until this is fixed"
|
bare model ids will fail to resolve until this is fixed"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
Self {
|
let (recovery_tx, recovery_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
|
||||||
|
let this = Arc::new(Self {
|
||||||
models: Arc::new(RwLock::new(HashMap::new())),
|
models: Arc::new(RwLock::new(HashMap::new())),
|
||||||
sources,
|
sources,
|
||||||
default_source,
|
default_source,
|
||||||
bind_url,
|
bind_url,
|
||||||
device_workers: Arc::new(RwLock::new(HashMap::new())),
|
device_workers: Arc::new(RwLock::new(HashMap::new())),
|
||||||
|
recovering: Arc::new(RwLock::new(std::collections::HashSet::new())),
|
||||||
|
recovery_tx,
|
||||||
|
});
|
||||||
|
// Background auto-recovery task (#17). Holds a `Weak` so it can't
|
||||||
|
// keep the harness alive. Spawned only when a tokio runtime is
|
||||||
|
// present — sync unit tests that build a harness without one
|
||||||
|
// simply skip it (they don't exercise recovery).
|
||||||
|
if tokio::runtime::Handle::try_current().is_ok() {
|
||||||
|
let weak = Arc::downgrade(&this);
|
||||||
|
tokio::spawn(recovery_loop(weak, recovery_rx));
|
||||||
}
|
}
|
||||||
|
this
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Scheme to substitute for bare `org/name` model ids. Mirrors the
|
/// Scheme to substitute for bare `org/name` model ids. Mirrors the
|
||||||
@@ -1627,7 +1706,17 @@ impl CandleHarness {
|
|||||||
let models = self.models.read().await;
|
let models = self.models.read().await;
|
||||||
models.get(&request.model).cloned()
|
models.get(&request.model).cloned()
|
||||||
};
|
};
|
||||||
let handle = handle.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?;
|
let handle = match handle {
|
||||||
|
Some(h) => h,
|
||||||
|
// Absent from the registry: distinguish a genuinely unloaded
|
||||||
|
// model from one whose slot is briefly gone mid auto-recovery
|
||||||
|
// (#17), so the client gets a transient "retry shortly" instead
|
||||||
|
// of a misleading "not loaded".
|
||||||
|
None if self.is_recovering(&request.model).await => {
|
||||||
|
return Err(recovering_error(&request.model));
|
||||||
|
}
|
||||||
|
None => return Err(InferenceError::ModelNotLoaded(request.model.clone())),
|
||||||
|
};
|
||||||
// The match is technically infallible without `cuda` (only Single
|
// The match is technically infallible without `cuda` (only Single
|
||||||
// exists), but the cfg-gated Tp arm makes this the right shape
|
// exists), but the cfg-gated Tp arm makes this the right shape
|
||||||
// under both feature flags.
|
// under both feature flags.
|
||||||
@@ -1657,7 +1746,12 @@ impl CandleHarness {
|
|||||||
if loaded.poisoned.load(Ordering::Acquire) {
|
if loaded.poisoned.load(Ordering::Acquire) {
|
||||||
let _g = span.enter();
|
let _g = span.enter();
|
||||||
tracing::warn!("chat_completion: refusing request, model poisoned");
|
tracing::warn!("chat_completion: refusing request, model poisoned");
|
||||||
return Err(poisoned_error(&model_id));
|
return Err(self.trigger_recovery(&model_id).await);
|
||||||
|
}
|
||||||
|
if debug_poison_armed(&model_id) {
|
||||||
|
let _g = span.enter();
|
||||||
|
tracing::warn!("NEURON_DEBUG_POISON: forcing auto-recovery (#17 verification)");
|
||||||
|
return Err(self.trigger_recovery(&model_id).await);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Serialise concurrent requests against this model. Holds for
|
// Serialise concurrent requests against this model. Holds for
|
||||||
@@ -2036,7 +2130,17 @@ impl CandleHarness {
|
|||||||
let models = self.models.read().await;
|
let models = self.models.read().await;
|
||||||
models.get(&request.model).cloned()
|
models.get(&request.model).cloned()
|
||||||
};
|
};
|
||||||
let handle = handle.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?;
|
let handle = match handle {
|
||||||
|
Some(h) => h,
|
||||||
|
// Absent from the registry: distinguish a genuinely unloaded
|
||||||
|
// model from one whose slot is briefly gone mid auto-recovery
|
||||||
|
// (#17), so the client gets a transient "retry shortly" instead
|
||||||
|
// of a misleading "not loaded".
|
||||||
|
None if self.is_recovering(&request.model).await => {
|
||||||
|
return Err(recovering_error(&request.model));
|
||||||
|
}
|
||||||
|
None => return Err(InferenceError::ModelNotLoaded(request.model.clone())),
|
||||||
|
};
|
||||||
// The match is technically infallible without `cuda` (only Single
|
// The match is technically infallible without `cuda` (only Single
|
||||||
// exists), but the cfg-gated Tp arm makes this the right shape
|
// exists), but the cfg-gated Tp arm makes this the right shape
|
||||||
// under both feature flags.
|
// under both feature flags.
|
||||||
@@ -2129,7 +2233,7 @@ impl CandleHarness {
|
|||||||
// Refuse if the model is already poisoned. No point opening
|
// Refuse if the model is already poisoned. No point opening
|
||||||
// an SSE stream just to send the Start event and then bail.
|
// an SSE stream just to send the Start event and then bail.
|
||||||
if loaded.poisoned.load(Ordering::Acquire) {
|
if loaded.poisoned.load(Ordering::Acquire) {
|
||||||
return Err(poisoned_error(&model_id));
|
return Err(self.trigger_recovery(&model_id).await);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start event: tells the wire projector to emit its
|
// Start event: tells the wire projector to emit its
|
||||||
@@ -2347,6 +2451,69 @@ pub struct InferenceStream {
|
|||||||
pub reasoning_markers: Option<ReasoningTokenPair>,
|
pub reasoning_markers: Option<ReasoningTokenPair>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Auto-recovery (#17) — rebuild a poisoned model's device context
|
||||||
|
/// automatically instead of leaving it bricked until a human reloads.
|
||||||
|
impl CandleHarness {
|
||||||
|
/// True while `model_id` is being auto-recovered (its slot is briefly
|
||||||
|
/// absent from the registry during the reload).
|
||||||
|
pub async fn is_recovering(&self, model_id: &str) -> bool {
|
||||||
|
self.recovering.read().await.contains(model_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Single-flight trigger from the request path: enqueue a rebuild for a
|
||||||
|
/// poisoned model (only the first caller per model enqueues) and return
|
||||||
|
/// the transient "recovering" error to hand back to the client.
|
||||||
|
async fn trigger_recovery(&self, model_id: &str) -> InferenceError {
|
||||||
|
let newly = self.recovering.write().await.insert(model_id.to_string());
|
||||||
|
if newly {
|
||||||
|
tracing::warn!(model = %model_id, "auto-recovery: poisoned, enqueueing rebuild");
|
||||||
|
if self.recovery_tx.send(model_id.to_string()).is_err() {
|
||||||
|
// Background task gone (harness shutting down). Drop the
|
||||||
|
// marker and fall back to the manual-reload message.
|
||||||
|
self.recovering.write().await.remove(model_id);
|
||||||
|
tracing::error!(model = %model_id, "auto-recovery: task unavailable");
|
||||||
|
return poisoned_error(model_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
recovering_error(model_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Rebuild a poisoned model: `unload_model` (drops it → cudarc aborts
|
||||||
|
/// NCCL + releases the context) then `load_model` from the retained
|
||||||
|
/// spec. A successful reload re-runs NCCL init + sanity inside the load
|
||||||
|
/// path, so it returns a fresh, healthy model; a failed reload leaves
|
||||||
|
/// the model unloaded (recoverable by the next load), never poisoned
|
||||||
|
/// forever. Runs on the background task — never inline on the request
|
||||||
|
/// path (would deadlock on the `models` write lock).
|
||||||
|
async fn recover_one(&self, model_id: &str) {
|
||||||
|
let spec = {
|
||||||
|
let models = self.models.read().await;
|
||||||
|
models.get(model_id).map(|h| h.spec().clone())
|
||||||
|
};
|
||||||
|
let Some(spec) = spec else {
|
||||||
|
self.recovering.write().await.remove(model_id);
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
tracing::warn!(model = %model_id, "auto-recovery: unload+reload starting");
|
||||||
|
if let Err(e) = self.unload_model(model_id).await {
|
||||||
|
tracing::error!(
|
||||||
|
model = %model_id,
|
||||||
|
error = %format!("{e:#}"),
|
||||||
|
"auto-recovery: unload failed (continuing to reload)"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
match self.load_model(&spec).await {
|
||||||
|
Ok(()) => tracing::info!(model = %model_id, "auto-recovery: reloaded; model healthy"),
|
||||||
|
Err(e) => tracing::error!(
|
||||||
|
model = %model_id,
|
||||||
|
error = %format!("{e:#}"),
|
||||||
|
"auto-recovery: reload failed; model left unloaded"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
self.recovering.write().await.remove(model_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl Harness for CandleHarness {
|
impl Harness for CandleHarness {
|
||||||
fn name(&self) -> &str {
|
fn name(&self) -> &str {
|
||||||
@@ -2550,6 +2717,7 @@ impl Harness for CandleHarness {
|
|||||||
has_vision: vision_meta.has_vision,
|
has_vision: vision_meta.has_vision,
|
||||||
image_token_id: vision_meta.image_token_id,
|
image_token_id: vision_meta.image_token_id,
|
||||||
image_grid_factor: vision_meta.image_grid_factor,
|
image_grid_factor: vision_meta.image_grid_factor,
|
||||||
|
spec: spec.clone(),
|
||||||
});
|
});
|
||||||
|
|
||||||
let mut models = self.models.write().await;
|
let mut models = self.models.write().await;
|
||||||
@@ -2788,6 +2956,7 @@ impl CandleHarness {
|
|||||||
has_vision: vision_meta.has_vision,
|
has_vision: vision_meta.has_vision,
|
||||||
image_token_id: vision_meta.image_token_id,
|
image_token_id: vision_meta.image_token_id,
|
||||||
image_grid_factor: vision_meta.image_grid_factor,
|
image_grid_factor: vision_meta.image_grid_factor,
|
||||||
|
spec: spec.clone(),
|
||||||
});
|
});
|
||||||
|
|
||||||
let mut models = self.models.write().await;
|
let mut models = self.models.write().await;
|
||||||
@@ -2834,7 +3003,12 @@ impl CandleHarness {
|
|||||||
if tp.poisoned.load(Ordering::Acquire) {
|
if tp.poisoned.load(Ordering::Acquire) {
|
||||||
let _g = span.enter();
|
let _g = span.enter();
|
||||||
tracing::warn!("TP chat_completion: refusing request, model poisoned");
|
tracing::warn!("TP chat_completion: refusing request, model poisoned");
|
||||||
return Err(poisoned_error(&model_id));
|
return Err(self.trigger_recovery(&model_id).await);
|
||||||
|
}
|
||||||
|
if debug_poison_armed(&model_id) {
|
||||||
|
let _g = span.enter();
|
||||||
|
tracing::warn!("NEURON_DEBUG_POISON: forcing auto-recovery (#17 verification)");
|
||||||
|
return Err(self.trigger_recovery(&model_id).await);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reject image-bearing requests against a TP model with no
|
// Reject image-bearing requests against a TP model with no
|
||||||
@@ -2923,7 +3097,7 @@ impl CandleHarness {
|
|||||||
request: ChatCompletionRequest,
|
request: ChatCompletionRequest,
|
||||||
) -> Result<InferenceStream, InferenceError> {
|
) -> Result<InferenceStream, InferenceError> {
|
||||||
if tp.poisoned.load(Ordering::Acquire) {
|
if tp.poisoned.load(Ordering::Acquire) {
|
||||||
return Err(poisoned_error(&request.model));
|
return Err(self.trigger_recovery(&request.model).await);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reject image requests against a non-vision TP model before
|
// Reject image requests against a non-vision TP model before
|
||||||
|
|||||||
@@ -201,6 +201,16 @@ pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool
|
|||||||
let _ = reply.send(resp);
|
let _ = reply.send(resp);
|
||||||
}
|
}
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
|
Job::GetLeaderComm { reply } => {
|
||||||
|
// Clone the leader's Arc<Comm> out for the async-side
|
||||||
|
// watchdog. `None` before NcclInit. (#17 Stage 2)
|
||||||
|
let comm = state
|
||||||
|
.nccl
|
||||||
|
.comm()
|
||||||
|
.map(crate::harness::tp::nccl_state::SendComm);
|
||||||
|
let _ = reply.send(comm);
|
||||||
|
}
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
Job::TpLoadShard {
|
Job::TpLoadShard {
|
||||||
model_id,
|
model_id,
|
||||||
config_json,
|
config_json,
|
||||||
@@ -1004,6 +1014,10 @@ fn drain_poisoned(job: Job, device_index: u32) {
|
|||||||
message: format!("device worker {device_index} poisoned"),
|
message: format!("device worker {device_index} poisoned"),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
Job::GetLeaderComm { reply } => {
|
||||||
|
let _ = reply.send(None);
|
||||||
|
}
|
||||||
Job::NcclSanity { reply } => {
|
Job::NcclSanity { reply } => {
|
||||||
let _ = reply.send(crate::harness::tp::rpc::WorkerResponse::Error {
|
let _ = reply.send(crate::harness::tp::rpc::WorkerResponse::Error {
|
||||||
kind: "device_worker_poisoned".into(),
|
kind: "device_worker_poisoned".into(),
|
||||||
|
|||||||
@@ -192,6 +192,17 @@ pub enum Job {
|
|||||||
NcclSanity {
|
NcclSanity {
|
||||||
reply: oneshot::Sender<crate::harness::tp::rpc::WorkerResponse>,
|
reply: oneshot::Sender<crate::harness::tp::rpc::WorkerResponse>,
|
||||||
},
|
},
|
||||||
|
/// Hand a clonable handle to the leader's NCCL `Comm` back to the
|
||||||
|
/// async side, so the TP step watchdog can call `ncclCommAbort` on
|
||||||
|
/// it from a *different* thread to unblock a wedged collective
|
||||||
|
/// (#17 Stage 2). Fetched once at init while the worker thread is
|
||||||
|
/// still responsive — a thread already wedged in a collective can't
|
||||||
|
/// service this job, which is exactly why the handle is cached
|
||||||
|
/// up front. Replies `None` before `NcclInit` has run.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
GetLeaderComm {
|
||||||
|
reply: oneshot::Sender<Option<crate::harness::tp::nccl_state::SendComm>>,
|
||||||
|
},
|
||||||
/// Load the leader's TP shard on the worker thread. The dispatch
|
/// Load the leader's TP shard on the worker thread. The dispatch
|
||||||
/// handler reads `state.nccl.comm()` directly (no cross-thread
|
/// handler reads `state.nccl.comm()` directly (no cross-thread
|
||||||
/// `Arc<Comm>` transfer, no `SendComm` wrapper) and builds the
|
/// `Arc<Comm>` transfer, no `SendComm` wrapper) and builds the
|
||||||
|
|||||||
@@ -161,6 +161,27 @@ impl DeviceWorkerHandle {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Fetch a clonable handle to the leader's NCCL `Comm` (#17 Stage 2).
|
||||||
|
/// The TP step watchdog caches this at init so it can call
|
||||||
|
/// `ncclCommAbort` from the async thread to unblock a wedged
|
||||||
|
/// collective. Returns `None` if uninitialised, poisoned, or gone —
|
||||||
|
/// the caller treats a missing handle as "can't abort" and logs it.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub async fn get_leader_comm(&self) -> Option<crate::harness::tp::nccl_state::SendComm> {
|
||||||
|
if self.poisoned.load(Ordering::Acquire) {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
|
if self
|
||||||
|
.tx
|
||||||
|
.send(Job::GetLeaderComm { reply: reply_tx })
|
||||||
|
.is_err()
|
||||||
|
{
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
reply_rx.await.ok().flatten()
|
||||||
|
}
|
||||||
|
|
||||||
/// Load a GGUF (pre-quantized) single-GPU model on the worker
|
/// Load a GGUF (pre-quantized) single-GPU model on the worker
|
||||||
/// thread. The hf-hub resolution happens on the async caller; the
|
/// thread. The hf-hub resolution happens on the async caller; the
|
||||||
/// resolved local `gguf_path` plus the spec's model_id are sent
|
/// resolved local `gguf_path` plus the spec's model_id are sent
|
||||||
|
|||||||
@@ -114,10 +114,8 @@ impl HarnessRegistry {
|
|||||||
for config in configs {
|
for config in configs {
|
||||||
match config.name.as_str() {
|
match config.name.as_str() {
|
||||||
"candle" => {
|
"candle" => {
|
||||||
let harness = Arc::new(candle::CandleHarness::new(
|
let harness =
|
||||||
bind_url.to_string(),
|
candle::CandleHarness::new(bind_url.to_string(), &settings.candle);
|
||||||
&settings.candle,
|
|
||||||
));
|
|
||||||
registry.candle = Some(Arc::clone(&harness));
|
registry.candle = Some(Arc::clone(&harness));
|
||||||
registry.harnesses.insert("candle".into(), harness);
|
registry.harnesses.insert("candle".into(), harness);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -55,12 +55,23 @@ pub struct PreprocessProfile {
|
|||||||
pub image_std: [f32; 3],
|
pub image_std: [f32; 3],
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Default pixel budget for Qwen3.6 (`256² … 1024²` → 64 … 1024 LM
|
/// The Qwen3.6 vision tower rejects any image whose **patch** count
|
||||||
/// tokens/image). Generous for documents/OCR, bounded for serving on
|
/// exceeds its learned pos-embed budget (`num_position_embeddings =
|
||||||
/// 2×RTX5090. Operators tune with `NEURON_VISION_MIN_PIXELS` /
|
/// 2304 = 48²`; see `vision.rs`). At `patch_size = 16` that is
|
||||||
/// `NEURON_VISION_MAX_PIXELS` (matching the other `NEURON_VISION_*` knobs).
|
/// `2304 × 16² = 589_824` source pixels. `max_pixels` is hard-capped to
|
||||||
|
/// this so `smart_resize` can never produce an over-budget grid — a
|
||||||
|
/// per-rank "patch count exceeds pos_embed budget" error mid-TP-forward
|
||||||
|
/// would otherwise poison the device context. The pos-embed grid is the
|
||||||
|
/// resolution Qwen3.6 was trained at, so this cap is principled, not just
|
||||||
|
/// defensive.
|
||||||
|
const QWEN3_6_MAX_PIXELS_CAP: u32 = 2304 * 16 * 16; // 589_824 → ≤ 2304 patches → ≤ 576 LM tokens
|
||||||
|
|
||||||
|
/// Default pixel budget for Qwen3.6: `256²` (64 LM tokens) up to the
|
||||||
|
/// pos-embed cap (576 LM tokens). Generous for documents/OCR, bounded
|
||||||
|
/// for serving. Operators lower it with `NEURON_VISION_MIN_PIXELS` /
|
||||||
|
/// `NEURON_VISION_MAX_PIXELS` (the upper bound is still clamped to the
|
||||||
|
/// cap above — raising it past the budget would poison the model).
|
||||||
const QWEN3_6_MIN_PIXELS: u32 = 65_536;
|
const QWEN3_6_MIN_PIXELS: u32 = 65_536;
|
||||||
const QWEN3_6_MAX_PIXELS: u32 = 1_048_576;
|
|
||||||
|
|
||||||
fn env_pixels(name: &str, default: u32) -> u32 {
|
fn env_pixels(name: &str, default: u32) -> u32 {
|
||||||
std::env::var(name)
|
std::env::var(name)
|
||||||
@@ -72,15 +83,19 @@ fn env_pixels(name: &str, default: u32) -> u32 {
|
|||||||
impl PreprocessProfile {
|
impl PreprocessProfile {
|
||||||
/// Profile for Qwen3.6. Native-aspect `smart_resize` (factor 32),
|
/// Profile for Qwen3.6. Native-aspect `smart_resize` (factor 32),
|
||||||
/// normalise to `[-1, 1]` via mean=std=0.5. Pixel budget defaults to
|
/// normalise to `[-1, 1]` via mean=std=0.5. Pixel budget defaults to
|
||||||
/// [`QWEN3_6_MIN_PIXELS`]…[`QWEN3_6_MAX_PIXELS`], overridable via the
|
/// [`QWEN3_6_MIN_PIXELS`]…[`QWEN3_6_MAX_PIXELS_CAP`], overridable via
|
||||||
/// `NEURON_VISION_MIN_PIXELS` / `NEURON_VISION_MAX_PIXELS` env vars.
|
/// `NEURON_VISION_MIN_PIXELS` / `NEURON_VISION_MAX_PIXELS`. Clamped
|
||||||
/// The budget is clamped sane: `min ≥ factor²` (at least one LM token)
|
/// sane: `factor² ≤ min ≤ max`, and `max ≤` the pos-embed cap (so the
|
||||||
/// and `max ≥ min`.
|
/// vision tower never rejects a resized image and poisons the context).
|
||||||
pub fn qwen3_6() -> Self {
|
pub fn qwen3_6() -> Self {
|
||||||
let factor = 32u32;
|
let factor = 32u32;
|
||||||
let f2 = factor * factor;
|
let f2 = factor * factor;
|
||||||
let min_pixels = env_pixels("NEURON_VISION_MIN_PIXELS", QWEN3_6_MIN_PIXELS).max(f2);
|
let min_pixels = env_pixels("NEURON_VISION_MIN_PIXELS", QWEN3_6_MIN_PIXELS)
|
||||||
let max_pixels = env_pixels("NEURON_VISION_MAX_PIXELS", QWEN3_6_MAX_PIXELS).max(min_pixels);
|
.max(f2)
|
||||||
|
.min(QWEN3_6_MAX_PIXELS_CAP);
|
||||||
|
let max_pixels = env_pixels("NEURON_VISION_MAX_PIXELS", QWEN3_6_MAX_PIXELS_CAP)
|
||||||
|
.min(QWEN3_6_MAX_PIXELS_CAP)
|
||||||
|
.max(min_pixels);
|
||||||
Self {
|
Self {
|
||||||
factor,
|
factor,
|
||||||
min_pixels,
|
min_pixels,
|
||||||
@@ -388,6 +403,28 @@ mod tests {
|
|||||||
assert!(format!("{err:#}").contains("200:1"));
|
assert!(format!("{err:#}").contains("200:1"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn qwen3_6_never_exceeds_pos_embed_patch_budget() {
|
||||||
|
// The pos-embed cap must hold for huge, tall, wide, and extreme
|
||||||
|
// images — exceeding 2304 patches errors mid-tower and poisons
|
||||||
|
// the device context, so this invariant is load-bearing.
|
||||||
|
let p = PreprocessProfile::qwen3_6();
|
||||||
|
for (sh, sw) in [
|
||||||
|
(8000u32, 6000u32),
|
||||||
|
(808, 1600),
|
||||||
|
(4000, 400),
|
||||||
|
(1, 199),
|
||||||
|
(16, 16),
|
||||||
|
] {
|
||||||
|
let (h, w) = p.resized_dims(sh, sw).unwrap();
|
||||||
|
let patches = (h / 16) * (w / 16);
|
||||||
|
assert!(
|
||||||
|
patches <= 2304,
|
||||||
|
"{sh}x{sw} → {h}x{w} = {patches} patches exceeds the 2304 budget"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn qwen3_6_default_budget_bounds_lm_tokens() {
|
fn qwen3_6_default_budget_bounds_lm_tokens() {
|
||||||
// A huge source image caps at max_pixels → the per-image LM token
|
// A huge source image caps at max_pixels → the per-image LM token
|
||||||
|
|||||||
@@ -245,9 +245,67 @@ pub struct WorkerPool {
|
|||||||
/// Phase 4 the load itself moves onto the worker and that bridge
|
/// Phase 4 the load itself moves onto the worker and that bridge
|
||||||
/// goes away.
|
/// goes away.
|
||||||
pub(crate) leader_worker: std::sync::Arc<super::device_worker::DeviceWorkerHandle>,
|
pub(crate) leader_worker: std::sync::Arc<super::device_worker::DeviceWorkerHandle>,
|
||||||
|
/// Cached handle to the leader's NCCL `Comm`, fetched at `init_nccl`
|
||||||
|
/// while the worker thread is responsive. The TP step watchdog uses
|
||||||
|
/// it to `ncclCommAbort` a wedged collective from the async thread —
|
||||||
|
/// the one NCCL op allowed concurrently with an in-flight collective,
|
||||||
|
/// and the only way to unblock the in-process leader thread so
|
||||||
|
/// recovery's `unload` doesn't itself hang (#17 Stage 2). `None` if
|
||||||
|
/// init couldn't cache it; the watchdog then logs that it can't abort.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
leader_comm: Option<nccl_state::SendComm>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Per-step deadline for a TP forward (#17 Stage 2). A healthy decode
|
||||||
|
/// step or chunked prefill completes in well under a second; a wedged
|
||||||
|
/// NCCL collective never returns. Generous default so no legitimate step
|
||||||
|
/// trips it; overridable via `NEURON_TP_STEP_TIMEOUT_S` (seconds).
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
fn tp_step_timeout() -> std::time::Duration {
|
||||||
|
let secs = std::env::var("NEURON_TP_STEP_TIMEOUT_S")
|
||||||
|
.ok()
|
||||||
|
.and_then(|v| v.trim().parse::<u64>().ok())
|
||||||
|
.filter(|&s| s > 0)
|
||||||
|
.unwrap_or(120);
|
||||||
|
std::time::Duration::from_secs(secs)
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WorkerPool {
|
impl WorkerPool {
|
||||||
|
/// Abort the leader's NCCL comm to unblock a collective the watchdog
|
||||||
|
/// found wedged (#17 Stage 2). Logs the whole sequence loudly so a
|
||||||
|
/// real-world hang leaves a greppable forensic trail
|
||||||
|
/// (`tp watchdog:` / `ncclCommAbort`). Calling abort from this async
|
||||||
|
/// thread while the worker thread is blocked inside the collective is
|
||||||
|
/// the one concurrent NCCL op the library sanctions — it is how a
|
||||||
|
/// stuck/failed collective is unblocked.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
fn watchdog_abort_leader_comm(&self, model_id: &str, secs: u64) {
|
||||||
|
tracing::error!(
|
||||||
|
model = %model_id,
|
||||||
|
timeout_s = secs,
|
||||||
|
"tp watchdog: leader forward exceeded deadline — NCCL collective wedged; \
|
||||||
|
aborting comm to unblock the leader thread for auto-recovery"
|
||||||
|
);
|
||||||
|
match &self.leader_comm {
|
||||||
|
Some(c) => match c.0.abort() {
|
||||||
|
Ok(()) => tracing::error!(
|
||||||
|
model = %model_id,
|
||||||
|
"tp watchdog: ncclCommAbort succeeded — wedged collective unblocked; \
|
||||||
|
failing the step so the model auto-recovers (unload+reload)"
|
||||||
|
),
|
||||||
|
Err(e) => tracing::error!(
|
||||||
|
model = %model_id, error = ?e,
|
||||||
|
"tp watchdog: ncclCommAbort failed — recovery may stall until a process restart"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
None => tracing::error!(
|
||||||
|
model = %model_id,
|
||||||
|
"tp watchdog: no cached leader comm handle — cannot abort; recovery will rely \
|
||||||
|
on a process restart"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Spawn `world_size - 1` worker subprocesses. Rank 0 is the
|
/// Spawn `world_size - 1` worker subprocesses. Rank 0 is the
|
||||||
/// leader (in-process) and is *not* spawned here — the leader
|
/// leader (in-process) and is *not* spawned here — the leader
|
||||||
/// holds rank 0's NCCL Comm and shard in its own address space.
|
/// holds rank 0's NCCL Comm and shard in its own address space.
|
||||||
@@ -324,6 +382,8 @@ impl WorkerPool {
|
|||||||
workers,
|
workers,
|
||||||
exe,
|
exe,
|
||||||
leader_worker,
|
leader_worker,
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
leader_comm: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -404,6 +464,23 @@ impl WorkerPool {
|
|||||||
world_size = self.world_size,
|
world_size = self.world_size,
|
||||||
"NCCL communicator established across all ranks"
|
"NCCL communicator established across all ranks"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Cache the leader's Comm handle now, while the worker thread is
|
||||||
|
// responsive, so the TP step watchdog can abort a wedged
|
||||||
|
// collective later (it can't fetch it then — the thread is stuck).
|
||||||
|
// (#17 Stage 2.)
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
{
|
||||||
|
self.leader_comm = self.leader_worker.get_leader_comm().await;
|
||||||
|
if self.leader_comm.is_some() {
|
||||||
|
tracing::debug!("cached leader NCCL comm handle for the TP step watchdog");
|
||||||
|
} else {
|
||||||
|
tracing::warn!(
|
||||||
|
"could not cache leader NCCL comm handle; the TP step watchdog will be \
|
||||||
|
unable to abort a wedged collective (a hang would need a process restart)"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -628,10 +705,27 @@ impl WorkerPool {
|
|||||||
// that's the invariant the whole refactor exists to
|
// that's the invariant the whole refactor exists to
|
||||||
// preserve.
|
// preserve.
|
||||||
let leader_start = std::time::Instant::now();
|
let leader_start = std::time::Instant::now();
|
||||||
let leader_result = self
|
let timeout = tp_step_timeout();
|
||||||
|
let leader_fut = self
|
||||||
.leader_worker
|
.leader_worker
|
||||||
.tp_forward_logits(leader_handle, tokens, offset)
|
.tp_forward_logits(leader_handle, tokens, offset);
|
||||||
.await;
|
let leader_result = match tokio::time::timeout(timeout, leader_fut).await {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(_elapsed) => {
|
||||||
|
// Watchdog (#17 Stage 2): the NCCL collective is wedged.
|
||||||
|
// Abort the leader comm to unblock its thread, then fail
|
||||||
|
// the step WITHOUT draining (the subprocess workers are
|
||||||
|
// wedged too; recovery's unload kills them). The error
|
||||||
|
// poisons the model → auto-recovery, which no longer hangs
|
||||||
|
// because the leader thread is now responsive.
|
||||||
|
self.watchdog_abort_leader_comm(model_id, timeout.as_secs());
|
||||||
|
anyhow::bail!(
|
||||||
|
"tp watchdog: leader forward exceeded {}s deadline; aborted wedged NCCL \
|
||||||
|
comm — model will auto-recover",
|
||||||
|
timeout.as_secs()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
};
|
||||||
let leader_ok = leader_result.is_ok();
|
let leader_ok = leader_result.is_ok();
|
||||||
let leader_ms = leader_start.elapsed().as_millis();
|
let leader_ms = leader_start.elapsed().as_millis();
|
||||||
// Surface the leader's own error at WARN before draining
|
// Surface the leader's own error at WARN before draining
|
||||||
@@ -767,17 +861,29 @@ impl WorkerPool {
|
|||||||
// matching collective; CPU-side logits keep the device tensor
|
// matching collective; CPU-side logits keep the device tensor
|
||||||
// from escaping the worker thread.
|
// from escaping the worker thread.
|
||||||
let leader_start = std::time::Instant::now();
|
let leader_start = std::time::Instant::now();
|
||||||
let leader_result = self
|
let timeout = tp_step_timeout();
|
||||||
.leader_worker
|
let leader_fut = self.leader_worker.tp_forward_logits_with_images(
|
||||||
.tp_forward_logits_with_images(
|
leader_handle,
|
||||||
leader_handle,
|
tokens,
|
||||||
tokens,
|
offset,
|
||||||
offset,
|
image_token_id,
|
||||||
image_token_id,
|
image_data_uris,
|
||||||
image_data_uris,
|
chunk_size,
|
||||||
chunk_size,
|
);
|
||||||
)
|
let leader_result = match tokio::time::timeout(timeout, leader_fut).await {
|
||||||
.await;
|
Ok(r) => r,
|
||||||
|
Err(_elapsed) => {
|
||||||
|
// Watchdog (#17 Stage 2) — see generate_step. Vision
|
||||||
|
// prefill is still well under the deadline on healthy
|
||||||
|
// hardware; a timeout means a wedged collective.
|
||||||
|
self.watchdog_abort_leader_comm(model_id, timeout.as_secs());
|
||||||
|
anyhow::bail!(
|
||||||
|
"tp watchdog: leader image forward exceeded {}s deadline; aborted wedged \
|
||||||
|
NCCL comm — model will auto-recover",
|
||||||
|
timeout.as_secs()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
};
|
||||||
let leader_ok = leader_result.is_ok();
|
let leader_ok = leader_result.is_ok();
|
||||||
let leader_ms = leader_start.elapsed().as_millis();
|
let leader_ms = leader_start.elapsed().as_millis();
|
||||||
if !leader_ok {
|
if !leader_ok {
|
||||||
|
|||||||
@@ -119,40 +119,25 @@ mod cuda_impl {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// `Arc<Comm>` doesn't impl `Send` because `Comm` wraps a raw
|
/// Thin newtype over `Arc<Comm>`, kept for call-site clarity — it marks
|
||||||
/// `ncclComm_t` pointer. The NCCL contract is "operations against a
|
/// the points where a comm handle is intentionally moved across threads
|
||||||
/// given comm must be serialised", not "the handle must stay on the
|
/// (e.g. cached async-side for the TP step watchdog's `ncclCommAbort`).
|
||||||
/// thread that created it" — so it's safe to move an `Arc<Comm>`
|
|
||||||
/// across threads as long as no concurrent ops are issued. The
|
|
||||||
/// pool's outer Mutex serialises us into `spawn_blocking`, so this
|
|
||||||
/// wrapper at the move boundary is the only thing missing.
|
|
||||||
///
|
///
|
||||||
/// `Sync` is also marked safe because the `Arc<Comm>` clones held
|
/// `Send`/`Sync` are provided upstream by `cudarc`'s `Comm` (which
|
||||||
/// by the row-parallel layers are only used from the
|
/// asserts the NCCL thread-safety invariant, including aborting from a
|
||||||
/// `spawn_blocking` thread driving the forward pass; concurrent
|
/// different thread than one inside a collective), so this type derives
|
||||||
/// access from another thread would still be a bug.
|
/// them automatically — no manual `unsafe impl` here.
|
||||||
pub struct SendComm(pub Arc<Comm>);
|
pub struct SendComm(pub Arc<Comm>);
|
||||||
|
|
||||||
// SAFETY: see the doc-comment above; the invariant is enforced at
|
|
||||||
// the call site (pool Mutex + single spawn_blocking thread), not at
|
|
||||||
// the type level.
|
|
||||||
unsafe impl Send for SendComm {}
|
|
||||||
unsafe impl Sync for SendComm {}
|
|
||||||
|
|
||||||
impl SendComm {
|
impl SendComm {
|
||||||
pub fn into_inner(self) -> Arc<Comm> {
|
pub fn into_inner(self) -> Arc<Comm> {
|
||||||
self.0
|
self.0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SAFETY: `cudarc::nccl::Comm` contains a raw `ncclComm_t` pointer
|
// `NcclState`'s `Send`/`Sync` are auto-derived: its `Arc<Comm>` and
|
||||||
// (libnccl-allocated state). NCCL requires that operations against
|
// `Arc<CudaContext>` fields are now `Send`/`Sync` (cudarc asserts the
|
||||||
// one Comm be issued one at a time; we serialise access by storing
|
// comm thread-safety invariant), so no manual `unsafe impl` is needed.
|
||||||
// NcclState behind a Mutex in `WorkerPool`. The Comm itself is
|
|
||||||
// move-safe — NCCL doesn't track the calling OS thread, only the
|
|
||||||
// stream the operations are dispatched against.
|
|
||||||
unsafe impl Send for NcclState {}
|
|
||||||
unsafe impl Sync for NcclState {}
|
|
||||||
|
|
||||||
/// Generate a fresh NCCL `Id` and return it hex-encoded. Used by
|
/// Generate a fresh NCCL `Id` and return it hex-encoded. Used by
|
||||||
/// the leader to mint the shared communicator id which is then
|
/// the leader to mint the shared communicator id which is then
|
||||||
|
|||||||
Reference in New Issue
Block a user