50 Commits

Author SHA1 Message Date
aa88d37509 fix(gateway): full observability + stop leaking upstream bodies
All checks were successful
Comprehensive sweep across cortex-gateway's request handling. Every
failure path now emits exactly one structured warn (or error) event
on the cortex side with the wire-level detail an operator needs;
the API response carries only a generic message plus, where useful,
the upstream status code.

proxy.rs::forward_request:
- warn on network failure (network error, target URL).
- warn on upstream non-2xx (status, target URL). Streaming body still
  passes through to the client; we just can't snippet without
  breaking the stream.
- warn on response-build failure.
- ProxyError::into_response no longer interpolates the inner error
  into the API body — generic "upstream request failed" / "failed to
  build response" instead.
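
A minimal sketch of the split between log detail and API body, assuming a two-variant error type. The variant names and the `public_message` helper are hypothetical; only the two generic strings come from this commit.

```rust
use std::fmt;

/// Hypothetical stand-in for the gateway's proxy error type.
#[derive(Debug)]
enum ProxyError {
    /// Network-level failure; the inner string is the detailed cause.
    Upstream(String),
    /// Failure while assembling the response; inner string is the cause.
    ResponseBuild(String),
}

impl fmt::Display for ProxyError {
    // Full detail, intended for the operator-facing log only.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ProxyError::Upstream(e) => write!(f, "upstream request failed: {}", e),
            ProxyError::ResponseBuild(e) => write!(f, "failed to build response: {}", e),
        }
    }
}

impl ProxyError {
    /// Generic text for the API body; never interpolates the inner error.
    fn public_message(&self) -> &'static str {
        match self {
            ProxyError::Upstream(_) => "upstream request failed",
            ProxyError::ResponseBuild(_) => "failed to build response",
        }
    }
}
```

The `Display` output would go to the warn event; only `public_message()` would reach the client.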

handlers.rs::chat_completions, handlers.rs::completions:
- warn on missing model field, with handler= label.
- warn on route resolve failure with model + error chain. The
  user-facing 404 keeps the RouteError Display string (which is
  short, informative, and contains no internal detail beyond the
  model id and config'd node names).

handlers.rs::anthropic_messages:
- warn on invalid Anthropic body, on translated-OpenAI serialise
  failure (which is internal), on route resolve, on upstream network
  error, on upstream non-2xx (with 512-char body snippet for parse
  errors), on upstream body read, on response parse.
- All warns share consistent field shape: handler, model, node, url,
  status / error / body as applicable.
- API response messages are now uniformly generic.
- Adds an info-level "proxying request" log on the non-streaming
  path so successful proxies are also visible.
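
The 512-char body snippet could be taken along these lines. This is a sketch only: `body_snippet` is a hypothetical name, and it assumes the upstream body has already been decoded to a `&str`. Counting characters rather than bytes avoids slicing through a multi-byte UTF-8 sequence.

```rust
/// Hypothetical helper: the first `max_chars` characters of `body`.
fn body_snippet(body: &str, max_chars: usize) -> &str {
    // `nth(max_chars)` is the first character *past* the snippet, so its
    // byte offset is a valid, char-aligned end for the slice.
    match body.char_indices().nth(max_chars) {
        Some((byte_idx, _)) => &body[..byte_idx],
        None => body, // body is already short enough
    }
}
```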

handlers.rs::proxy_with_metrics:
- still calls e.into_response(), but proxy::forward_request has
  already warned at the wire layer, so there is no double-log here.

Tests:
- All 32 existing unit tests + 22 gateway integration tests + 4
  new router tests pass.
- Tests that asserted on the "no healthy nodes" / "not found"
  strings still match because RouteError messages are preserved
  in the 404 user-facing path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-22 07:17:26 +03:00
0f00f72b47 fix(router,handlers): strip trailing slash from rewritten URL + log upstream failures
Some checks failed
Two coupled bugs surfaced after 9b0ed0b:

1. url::Url::parse("http://host:port").to_string() normalises the
   empty path to "/", so rewrite_loopback_host was returning
   "http://beast:13131/". Downstream callers then did
   format!("{endpoint}/v1/chat/completions") and produced a
   double-slash path that neuron's axum router 404'd with an empty
   body. Strip the trailing slash in the rewriter so the endpoint is
   a clean base string for concatenation.

2. The anthropic_messages handler returned the upstream's empty body
   to the API caller as `"upstream error: "` with no journal log on
   the cortex side. Operators had no way to see what happened. Add
   warn-level tracing on both upstream failure paths (network error
   and non-2xx) with model, node, target URL, status, and a 512-char
   body snippet. The API response now carries just `"upstream
   returned <status>"` — the implementation detail lives in the log.

Updates the two existing rewrite tests for the no-trailing-slash
output.
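
The trailing-slash fix in (1) can be sketched with std only. The helper names below are hypothetical; the real rewriter works on a parsed `url::Url`, which this sketch does not use.

```rust
/// Hypothetical helper: make an endpoint safe to use as a base string
/// for `format!("{endpoint}/...")`-style concatenation.
fn clean_base(endpoint: &str) -> &str {
    endpoint.trim_end_matches('/')
}

/// Builds the chat-completions URL from a (possibly slash-terminated) base.
fn chat_completions_url(endpoint: &str) -> String {
    format!("{}/v1/chat/completions", clean_base(endpoint))
}
```

Without the trim, a base of `http://beast:13131/` yields `http://beast:13131//v1/chat/completions`, the double-slash path described above.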

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-22 07:10:39 +03:00
9b0ed0b57f fix(router): rewrite loopback inference URLs to use neuron's host
Some checks failed
CI / Test (push) Failing after 4m34s
Neuron hardcodes its bind_url as `http://localhost:13131` (it can't
reliably know its own externally-resolvable name). When cortex runs
on a different host than the neuron it's routing to, blindly
proxying to that URL hits localhost on the cortex box instead of the
neuron.

Cortex already knows each neuron's reachable host from cortex.toml.
After fetching the inference URL from `/models/{id}/endpoint`, if
the host is a loopback name (localhost / 127.0.0.1 / 0.0.0.0 / ::1),
swap it for the configured neuron host. Preserve the port and path
from neuron's URL so a future harness serving inference on a
different port than the management API still works.

Adds `url` (already a transitive dep via reqwest) as a direct
dep for the URL parsing.

Tests cover: localhost rewrite, distinct inference port preservation,
non-loopback passthrough, malformed input.
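
A rough std-only sketch of the rewrite rule, for illustration. The commit uses the `url` crate for parsing; the hand-rolled splitting and the signature here are assumptions and handle only plain `scheme://host[:port][/path]` inputs.

```rust
/// True for the loopback names listed in this commit.
fn is_loopback_host(host: &str) -> bool {
    // Accept a bracketed IPv6 literal such as "[::1]".
    let h = host.trim_start_matches('[').trim_end_matches(']');
    matches!(h, "localhost" | "127.0.0.1" | "0.0.0.0" | "::1")
}

/// Hypothetical signature: swap a loopback host for `neuron_host`,
/// keeping scheme, port and path. Returns None for malformed input.
fn rewrite_loopback_host(endpoint: &str, neuron_host: &str) -> Option<String> {
    let (scheme, rest) = endpoint.split_once("://")?;
    // Authority is everything up to the first '/', the path is the rest.
    let (authority, path) = match rest.find('/') {
        Some(i) => (&rest[..i], &rest[i..]),
        None => (rest, ""),
    };
    // Split host from ":port" (the port keeps its leading colon).
    let (host, port) = if authority.starts_with('[') {
        let end = authority.find(']')?;
        (&authority[..=end], &authority[end + 1..])
    } else {
        match authority.rfind(':') {
            Some(i) => (&authority[..i], &authority[i..]),
            None => (authority, ""),
        }
    };
    if host.is_empty() {
        return None;
    }
    if !is_loopback_host(host) {
        return Some(endpoint.to_string()); // non-loopback passthrough
    }
    Some(format!("{}://{}{}{}", scheme, neuron_host, port, path))
}
```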

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-22 06:23:47 +03:00
dc2a803266 fix(rpm): migrate legacy helexa-cortex firewalld service to cortex
Some checks failed
Adds a %posttrans scriptlet to cortex.spec that:

- Removes the stale /etc/firewalld/services/helexa-cortex.xml left
  behind by an older packaging stream that named the service
  `helexa-cortex` and (in some build streams) carried wrong port
  numbers (9301/9302/9304).
- Walks every active firewalld zone; for any zone where the legacy
  helexa-cortex service was enabled, swaps it out for the new
  `cortex` service (which the RPM ships at
  /usr/lib/firewalld/services/cortex.xml with the right
  31313/31314 ports).
- Reloads firewalld so the change takes effect without operator
  intervention.

Affected hosts were silently dropping inbound connections to
cortex on 31313: the active zone advertised a helexa-cortex service
that listed unrelated ports, masking the correctly defined vendor
cortex service.

helexa-neuron is unaffected: that spec already ships the vendor
service as helexa-neuron.xml (namespaced from day one) and no
stale /etc override files exist in the fleet.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-22 06:12:51 +03:00
e71181499e feat(stage-8e-3): quantize lm_head in TP Qwen3-Next
All checks were successful
TpQwen3_5ForCausalLM::lm_head is now a MaybeQuantLinear. When the
load spec has quant set and tie_word_embeddings is false, lm_head's
(vocab_size, hidden_size) weight is quantized in-situ at load time
along with all the per-layer linears. The non-tied case on
Qwen3.6-27B saves ~1.7 GB per rank vs bf16 (248320 x 5120 x 2
bytes = ~2.54 GB -> ~0.87 GB at Q5K) and shaves a small amount of
decode latency from the per-token logits matmul.
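
The size estimate can be reproduced with a few lines of arithmetic. The sketch assumes the GGML Q5_K layout of 176 bytes per 256-weight super-block (5.5 bits per weight) and an unsharded lm_head; the function names are illustrative.

```rust
/// Bytes for `n` weights stored as bf16 (2 bytes each).
fn bf16_bytes(n: u64) -> u64 {
    n * 2
}

/// Bytes for `n` weights at Q5_K, assuming 176-byte super-blocks of 256
/// weights and `n` divisible by 256.
fn q5k_bytes(n: u64) -> u64 {
    n / 256 * 176
}

/// Per-rank saving from quantizing a (vocab, hidden) lm_head to Q5_K.
fn lm_head_saving(vocab: u64, hidden: u64) -> u64 {
    let n = vocab * hidden;
    bf16_bytes(n) - q5k_bytes(n)
}
```

For vocab_size = 248320 and hidden_size = 5120 this gives 1,668,710,400 bytes, in line with the ~1.7 GB quoted.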

Tied case (tie_word_embeddings=true) keeps the lm_head plain even
when quant is set — quantizing the shared tensor would corrupt the
embedding lookup, and the tied case already gets the memory win
from only holding one copy.

This is the last MaybeQuantLinear hookup in the Qwen3-Next TP path.
The dense Qwen3 path (tp_qwen3.rs) is unchanged; that work is deferred
until it is the bottleneck for a model that actually needs TP at
consumer scale.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-21 21:53:14 +03:00
ee663e5e99 fix(stage-8e-2e): bump quant prefill threshold to M > 64
Some checks failed
The M > 8 threshold from 8e-2d activated forward_via_f16 on the test
case (M=30) and slightly regressed prefill (143 -> 133 T/s). The
dequant cost (~30 MB of f16 per linear * ~480 calls per prefill =
~14 GB of dequantised weights, ~200 ms) eats the cuBLAS GEMM speedup
at small M.

Move the crossover to M > 64 so short prefills (typical for the
validate probe) stay on the GGUF GEMV kernel where per-call cost is
comparable but the dequant tax is zero. Long prefills still get the
dequant-then-cuBLAS-GEMM path where the GEMM scaling amortises the
fixed dequant cost.

Doesn't close the gap to mistralrs's 423 T/s on Q5K prefill — that
needs either a dequant cache (gives back the ISQ memory win) or a
fused dequant+gemm kernel. Both larger projects.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-21 21:50:45 +03:00
34f9b77d9d feat(stage-8e-2d): route quantized matmul by M (prefill vs decode)
All checks were successful
MaybeQuantLinear::forward picks between two QMatMul paths:

- M > 8 (prefill): QMatMul::forward_via_f16 dequantises the weight
  once into f16 and runs a real cuBLAS-backed GEMM. The dequant cost
  is fixed per call, so it's amortised across the M tokens.
- M <= 8 (decode): QMatMul::forward uses candle's GGUF GEMV kernel
  on the quantized blocks directly. Requires f32 inputs so we still
  cast in/out at the boundary in that arm.

Earlier 8e-2c sent everything through the GGUF GEMV kernel, which
is excellent at GEMV (decode) but doesn't have a real batched GEMM
path — prefill regressed ~4x. This restores prefill to roughly the
bf16 cuBLAS GEMM throughput while keeping the decode gain.
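
The routing rule reduces to a single comparison on M. In the sketch below the enum and function names are hypothetical, M is assumed to be the number of activation rows (tokens) in the call, and the threshold is left as a parameter since 8 is simply this commit's value.

```rust
/// Which quantized matmul path to take (illustrative names).
#[derive(Debug, PartialEq, Clone, Copy)]
enum QuantPath {
    /// Dequantise to f16 once, then run a dense GEMM (prefill).
    DequantGemm,
    /// Run the GEMV kernel directly on the quantized blocks (decode).
    Gemv,
}

/// Picks the path for an activation with `m` rows.
fn pick_quant_path(m: usize, prefill_threshold: usize) -> QuantPath {
    if m > prefill_threshold {
        QuantPath::DequantGemm
    } else {
        QuantPath::Gemv
    }
}
```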

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-21 21:15:32 +03:00
f084aaab8e fix(stage-8e-2c): cast bf16/f16 activations to f32 around QMatMul
All checks were successful
candle's QTensor::cuda_fwd requires f32 inputs — its on-the-fly
GGUF dequantize accumulates in f32. The model dtype flowing into
MaybeQuantLinear::forward is bf16, so QMatMul::forward errored with
"unexpected dtype, expected: F32, got: BF16".

Wrap the Quant arm to cast the activation to f32 before the matmul
and cast the result back to the input dtype. The cast is a single
launch on the activation tensor (small relative to weight traffic);
it's the price of in-situ GGUF-style quantization, and what mistralrs
does inside its own Linear wrapper.

The Plain arm is unchanged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-21 20:05:19 +03:00
68a606a79c fix(stage-8e-2b): allow quant on the TP load path
All checks were successful
The pre-existing guard in candle.rs rejected any spec.quant on the TP
path with "GGUF quantized models are not supported in the TP path" —
written when quant only ever meant GGUF. With 8e-1/8e-2 in,
quant != None on the TP path triggers in-situ quantization of the
loaded safetensors shards. resolve_dense_files only looks for
safetensors so a GGUF-source-file model with TP still errors out
cleanly downstream.

validate-neuron.sh: rebuild the load payload incrementally so
tp_size > 1 + non-empty quant produces both fields. Same script now
covers all four combos (single/TP × dense/ISQ).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-21 19:17:14 +03:00
4aa71902d0 feat(stage-8e-2): plumb quant config from ModelSpec to TP load path
All checks were successful
- LoadDenseShard RPC gains an optional `quant` string field.
- WorkerPool::load_dense_shard takes a `quant: Option<String>`,
  passes it via the RPC to workers and via parse_quant_string to
  the leader's local load.
- The Qwen3-Next TP load chain (ForCausalLM → Model → DecoderLayer
  → Attention / GatedDeltaNet / MLP) takes `quant: Option<GgmlDType>`
  end-to-end, calling Column/RowParallelLinear::load_with_quant.
- The fused in_proj_qkv inside TpQwen3_5GatedDeltaNet is now a
  MaybeQuantLinear so it also picks up quantization.
- parse_quant_string accepts q4_0/q4_1/q5_0/q5_1/q8_0/q8_1, q2k..q8k
  (with or without underscore), and f16/bf16/f32. Empty / None means
  no quantization.

Callers from candle.rs forward spec.quant through pool.load_dense_shard.
This means a `quant = "q5k"` in models.toml now flows end-to-end to a
QTensor-backed QMatMul for every per-rank linear in the Qwen3-Next
TP path. Leaves lm_head and the small replicated bias/log tensors in
their loaded dtype (Stage 8e-3).
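
A possible shape for the quant-string parser, as a sketch. `QuantType` stands in for candle's `GgmlDType`, and the signature, the whitespace trimming and the case-insensitive matching are assumptions rather than confirmed behaviour.

```rust
/// Stand-in for candle's GgmlDType (assumption: same set of variants).
#[allow(non_camel_case_types)]
#[derive(Debug, PartialEq, Clone, Copy)]
enum QuantType {
    Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, Q8_1,
    Q2K, Q3K, Q4K, Q5K, Q6K, Q8K,
    F16, BF16, F32,
}

/// Hypothetical signature: Ok(None) means "no quantization".
fn parse_quant_string(s: Option<&str>) -> Result<Option<QuantType>, String> {
    let raw = match s.map(str::trim) {
        None | Some("") => return Ok(None),
        Some(v) => v.to_ascii_lowercase(),
    };
    let q = match raw.as_str() {
        "q4_0" => QuantType::Q4_0,
        "q4_1" => QuantType::Q4_1,
        "q5_0" => QuantType::Q5_0,
        "q5_1" => QuantType::Q5_1,
        "q8_0" => QuantType::Q8_0,
        "q8_1" => QuantType::Q8_1,
        // k-quants are accepted with or without the underscore.
        "q2k" | "q2_k" => QuantType::Q2K,
        "q3k" | "q3_k" => QuantType::Q3K,
        "q4k" | "q4_k" => QuantType::Q4K,
        "q5k" | "q5_k" => QuantType::Q5K,
        "q6k" | "q6_k" => QuantType::Q6K,
        "q8k" | "q8_k" => QuantType::Q8K,
        "f16" => QuantType::F16,
        "bf16" => QuantType::BF16,
        "f32" => QuantType::F32,
        other => return Err(format!("unknown quant type: {}", other)),
    };
    Ok(Some(q))
}
```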

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-21 18:03:36 +03:00
bef159b21c feat(stage-8e-1): MaybeQuantLinear primitive + parallel-linear quant variants
Some checks failed
Introduces MaybeQuantLinear, which wraps either a plain candle Linear
or a candle QMatMul backed by a freshly-quantized QTensor. Forward
dispatches identically through the Module trait so downstream code
doesn't care which arm is active.

ColumnParallelLinear and RowParallelLinear gain `load_with_quant`
methods. The existing `load` methods stay as backward-compatible
no-quantization wrappers — no churn at the 27 existing call sites.

This is the foundation for in-situ quantization at load time. Wiring
the user-facing quant config and switching call sites to
load_with_quant follow in stages 8e-2 / 8e-3.
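
The two-arm dispatch can be illustrated with a toy, std-only version. Everything below is a stand-in: the `Module` trait mimics candle's, `from_weight` is a hypothetical constructor, and per-row symmetric int8 replaces the real GGUF block formats.

```rust
/// Minimal stand-in for a Module-style trait with one `forward` method.
trait Module {
    fn forward(&self, x: &[f32]) -> Vec<f32>;
}

enum MaybeQuantLinear {
    /// Full-precision weight, shape (out, in).
    Plain { weight: Vec<Vec<f32>> },
    /// Toy per-row symmetric int8 quantization: w is approximately q * scale.
    Quant { q: Vec<Vec<i8>>, scale: Vec<f32> },
}

impl MaybeQuantLinear {
    /// Hypothetical constructor: quantize at load time when `quant` is set.
    fn from_weight(weight: Vec<Vec<f32>>, quant: bool) -> Self {
        if !quant {
            return MaybeQuantLinear::Plain { weight };
        }
        let mut q: Vec<Vec<i8>> = Vec::with_capacity(weight.len());
        let mut scale: Vec<f32> = Vec::with_capacity(weight.len());
        for row in &weight {
            let max_abs = row.iter().fold(0.0f32, |m, w| m.max(w.abs()));
            let s = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 };
            q.push(row.iter().map(|w| (w / s).round() as i8).collect());
            scale.push(s);
        }
        MaybeQuantLinear::Quant { q, scale }
    }
}

impl Module for MaybeQuantLinear {
    // Callers see one `forward`, whichever arm is active.
    fn forward(&self, x: &[f32]) -> Vec<f32> {
        match self {
            MaybeQuantLinear::Plain { weight } => weight
                .iter()
                .map(|row| row.iter().zip(x).map(|(w, xi)| w * xi).sum::<f32>())
                .collect(),
            MaybeQuantLinear::Quant { q, scale } => q
                .iter()
                .zip(scale)
                .map(|(row, s)| {
                    row.iter().zip(x).map(|(qi, xi)| *qi as f32 * xi).sum::<f32>() * s
                })
                .collect(),
        }
    }
}
```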

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-21 17:55:26 +03:00
8d7b099b36 feat(stage-8d-7): direct safetensors fused-region loader
Some checks failed
Replaces load_fused_qkv_slice_2d/_3d with reads from a separate
MmapedSafetensors handle. Each per-rank fused tensor is built by
reading the three region byte-slices directly from the mmap,
concatenating them host-side, and uploading as one device
allocation — no full-fused-tensor device materialisation.

The prior approach allocated a ~100 MB transient device tensor
per linear-attention layer; on Qwen3.6-27B with 48 linear-attn
layers that's ~4.8 GB of allocator churn during load — enough
to fragment the cuda caching allocator on a tight-VRAM 32 GB
consumer GPU, which is what triggered the layer-22 up_proj
OOM seen on beast.

Threading: MmapedSafetensors flows worker → ForCausalLM →
Model → DecoderLayer → GatedDeltaNet::load. Both leader (mod.rs)
and worker (worker.rs) construct their own mmap; Linux's page
cache shares the underlying pages.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-21 17:49:35 +03:00
89d98d1fb2 diag(stage-8d-6): per-layer VRAM logging in TP load path
All checks were successful
Wraps each TpQwen3_5DecoderLayer::load in a with_context that captures
free/total VRAM on failure, plus an info-level log after every layer
that succeeds. Uses cudarc::driver::result::mem_get_info — same API
mistralrs uses.

Diagnostic only: forward path is unchanged. Helps distinguish true
VRAM exhaustion from allocator fragmentation when loading large
models at BF16 on 2x consumer GPUs.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-21 12:54:05 +03:00
cc95fe28d9 feat(stage-8d-5b): wire fused_gdn_gating CUDA kernel
All checks were successful
run_fused_gating helper consolidates the per-layer gating math:
  beta = sigmoid(b)
  g    = -exp(a_log) * softplus(a + dt_bias)

The CUDA path issues a single launch via fused_gdn_gating_cuda;
the CPU path falls back to the original per-op Rust sequence. Replaces
~10 candle launches per linear-attention layer (sigmoid + 2× to_dtype
+ exp + neg + broadcast_add + softplus + 2× unsqueeze + broadcast_mul)
across both single-GPU and TP forward paths.
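
A scalar reference for the two gating formulas, useful mainly as a check on the math. It is a sketch only: the real helper operates on tensors, and the numerically stable softplus form used here is an assumption about how the kernel evaluates it.

```rust
fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

/// softplus(x) = ln(1 + e^x), written so that large |x| does not overflow.
fn softplus(x: f32) -> f32 {
    x.max(0.0) + (-x.abs()).exp().ln_1p()
}

/// Returns (beta, g) for one element:
///   beta = sigmoid(b)
///   g    = -exp(a_log) * softplus(a + dt_bias)
fn gdn_gating(a: f32, b: f32, a_log: f32, dt_bias: f32) -> (f32, f32) {
    let beta = sigmoid(b);
    let g = -a_log.exp() * softplus(a + dt_bias);
    (beta, g)
}
```

Because exp and softplus are both positive, g is always negative, so exp(g) lies in (0, 1).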

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-21 11:52:38 +03:00
09c945f81e feat(stage-8d-4): dispatch chunked_gated_delta_rule_recurrence at prefill
Some checks failed
CI / Clippy (push) Failing after 52s
run_delta_rule_cuda now picks between the per-token kernel and the
BT=64 chunked variant based on seq_len. Threshold = 64 matches mistralrs.
Prefill on Qwen3.6-27B (typical seq_len in the hundreds) drops from
one block-launch per token to one per 64-token chunk.
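
The effect on step count can be sketched as follows. The names are illustrative, and whether the real comparison is `>=` or `>` at exactly 64 tokens is an assumption.

```rust
/// Chunk length of the chunked kernel (BT in the commit text).
const CHUNK: usize = 64;

#[derive(Debug, PartialEq)]
enum DeltaRuleKernel {
    PerToken,
    Chunked,
}

/// Assumed rule: use the chunked kernel once a full chunk is available.
fn pick_kernel(seq_len: usize) -> DeltaRuleKernel {
    if seq_len >= CHUNK {
        DeltaRuleKernel::Chunked
    } else {
        DeltaRuleKernel::PerToken
    }
}

/// Sequential recurrence steps needed for `seq_len` tokens.
fn recurrence_steps(seq_len: usize) -> usize {
    match pick_kernel(seq_len) {
        DeltaRuleKernel::PerToken => seq_len,
        // Ceiling division: a partial last chunk still costs one step.
        DeltaRuleKernel::Chunked => (seq_len + CHUNK - 1) / CHUNK,
    }
}
```

A 300-token prefill would take 5 chunked steps instead of 300 per-token steps.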

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-21 11:50:30 +03:00
05dc0bad18 feat(stage-8d-3): wire causal_conv1d_update/full CUDA kernels
Some checks failed
Replaces the per-layer conv1d + silu sequence in both single-GPU and
TP linear-attention forward paths with a shared run_causal_conv1d
helper that dispatches to:

- causal_conv1d_update for decode (seq_len=1 with existing conv_state)
- causal_conv1d_full for prefill / fresh start (zero-pads internally)

Both kernels fuse the depthwise conv + SiLU into a single launch — 4×
fewer cuda launches per linear-attention layer vs the candle conv1d +
candle_nn::ops::silu combo. Falls back to the original Rust path on
cpu.
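
For reference, a single-channel CPU sketch of what the two fused paths compute, in plain Rust rather than candle or CUDA. It assumes the decode state holds the previous K-1 inputs, oldest first, and that tap `K-1` multiplies the newest input; the real kernels' state layout and tap order may differ.

```rust
fn silu(x: f32) -> f32 {
    x / (1.0 + (-x).exp())
}

/// Prefill path for one channel: implicit zero padding on the left,
/// depthwise convolution, then SiLU.
fn causal_conv1d_full(x: &[f32], w: &[f32]) -> Vec<f32> {
    let k = w.len();
    (0..x.len())
        .map(|t| {
            let mut acc = 0.0f32;
            for j in 0..k {
                // Tap j reads the input `offset` steps in the past.
                let offset = k - 1 - j;
                if t >= offset {
                    acc += w[j] * x[t - offset];
                }
            }
            silu(acc)
        })
        .collect()
}

/// Decode path for one channel: consume one token and slide the state window.
fn causal_conv1d_update(state: &mut [f32], x_t: f32, w: &[f32]) -> f32 {
    let k = w.len();
    assert_eq!(state.len() + 1, k, "state must hold K-1 past inputs");
    let mut acc = 0.0f32;
    for j in 0..k - 1 {
        acc += w[j] * state[j];
    }
    acc += w[k - 1] * x_t;
    if !state.is_empty() {
        // Drop the oldest input and append the newest.
        state.rotate_left(1);
        let last = state.len() - 1;
        state[last] = x_t;
    }
    silu(acc)
}
```

Starting from an all-zero state, feeding a sequence through `causal_conv1d_update` one token at a time should reproduce the output of `causal_conv1d_full`, which is the property that lets prefill and decode share one conv state.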

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-21 11:49:41 +03:00
10c151efa5 feat(stage-8d-5): wire gated_delta_rule_recurrence kernel into tp_qwen3_5
Some checks failed
build-prerelease / Package cortex RPM (push) Blocked by required conditions
build-prerelease / Resolve version stamps (push) Successful in 36s
CI / Format (push) Successful in 39s
CI / Clippy (push) Successful in 2m21s
build-prerelease / Build neuron-blackwell (push) Successful in 3m36s
CI / Test (push) Successful in 4m39s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Successful in 4m34s
build-prerelease / Build neuron-ada (push) Has been cancelled
build-prerelease / Package helexa-neuron-ada RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
build-prerelease / Build neuron-ampere (push) Has been cancelled
TP per-token Rust loop replaced with shared run_delta_rule dispatch
from arch/qwen3_5/linear_attn.rs. Both single-GPU and TP variants now
use the cuda kernel when available, per-token Rust fallback otherwise.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-21 11:44:12 +03:00
44ae927e38 feat(stage-8d-2): wire gated_delta_rule_recurrence kernel into qwen3_5
Some checks failed
build-prerelease / Resolve version stamps (push) Successful in 35s
CI / Format (push) Successful in 38s
CI / Test (push) Failing after 45s
CI / Clippy (push) Successful in 2m16s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Has been cancelled
build-prerelease / Build neuron-ampere (push) Has been cancelled
build-prerelease / Build neuron-ada (push) Has been cancelled
build-prerelease / Package cortex RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ada RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
build-prerelease / Build neuron-blackwell (push) Has been cancelled
Replaces the per-token Rust delta-rule loop in
`arch/qwen3_5/linear_attn.rs::GatedDeltaNet::forward` with a single
dispatch to the `gated_delta_rule_recurrence` kernel imported from
mistralrs in 1ebbe87.

The kernel is V-tiled with compile-time BK (one block per (V-tile,
batch*head), one thread per V-column, BK state floats in registers).
For Qwen3.6's per-rank `(B=1, H=24, D_k=128, D_v=128)` shape this
collapses ~6 candle tensor-op launches per token per layer (each
~50µs CUDA dispatch overhead, so ~300µs/token/layer × 48 linear-
attention layers ≈ 14.4ms in launch overhead alone) to a single
kernel launch with full ILP / register residency.
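
The launch-overhead estimate above is plain multiplication; a small sketch makes the inputs explicit. The function name is hypothetical, and the 50µs figure is the commit's own rough estimate, not a measurement made here.

```rust
/// Per-token launch overhead in microseconds across all linear-attention layers.
fn launch_overhead_us(launches_per_layer: u64, dispatch_us: u64, layers: u64) -> u64 {
    launches_per_layer * dispatch_us * layers
}
```

With 6 launches at 50µs over 48 layers this gives 14,400µs per token; a single launch per layer at the same dispatch cost would give 2,400µs.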

New free function `run_delta_rule`:
- cuda branch (when q is on a CUDA device): flattens
  `(B, H, ...)` → `(BH, ...)`, dispatches the kernel via
  `crate::cuda::gdn::gated_delta_rule_recurrence_cuda`, reshapes
  outputs back to `(B, H, L, D_v)` and state to `(B, H, D_k, D_v)`.
- cpu fallback: the original per-token Rust loop, unchanged. Keeps
  cargo test --workspace passing on hosts without cuda.

Dispatch decision lives in the wrapper (`q.device().is_cuda()`).

Build: `cargo build -p neuron --features cuda` compiles + links;
clippy clean on both CPU and cuda paths. 32 lib tests still pass
(none of them exercise this code path on cuda; smoke test for the
TP variant is the deployed Tbilisi probe).

Stage 8d-3 wires the conv1d kernels; 8d-4 the chunked prefill;
8d-5 the same wiring for `tp/tp_qwen3_5.rs`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-21 11:39:30 +03:00
1ebbe87651 feat(stage-8d-1): import mistralrs GDN CUDA kernels — build infra only
Some checks failed
build-prerelease / Build cortex binary (push) Blocked by required conditions
CI / Test (push) Waiting to run
build-prerelease / Resolve version stamps (push) Successful in 29s
CI / Format (push) Successful in 40s
CI / Clippy (push) Successful in 2m23s
build-prerelease / Build neuron-blackwell (push) Has started running
build-prerelease / Build neuron-ampere (push) Has been cancelled
build-prerelease / Build neuron-ada (push) Has been cancelled
build-prerelease / Package cortex RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ada RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
CI / Build cortex SRPM (push) Has been cancelled
CI / Build neuron SRPM (push) Has been cancelled
CI / Publish cortex to COPR (push) Has been cancelled
CI / Publish neuron to COPR (push) Has been cancelled
CI / Bump version in source (push) Has been cancelled
Stage 8d (new): port the Gated DeltaNet CUDA kernels from
EricLBuehler/mistral.rs to close the ~500x decode performance gap
we measured on Qwen3.6-27B TP-2 (~12s/token in our pure-candle path
vs ~37 T/s in mistralrs on the same hardware).

This commit lays the build infrastructure with zero behavioural
change. Subsequent commits (8d-2 .. 8d-5) wire each kernel into the
qwen3_5 architecture and TP variant.

Added:
- `crates/neuron/build.rs` — uses `cudaforge::KernelBuilder` to compile
  every `src/cuda/*.cu` file into `libneuroncuda.a` under the `cuda`
  feature, then links it + `cudart`. Mirrors mistralrs's
  `mistralrs-core/build.rs` setup verbatim (same NVCC flag set, same
  sm_<80 bf16 gate).
- `crates/neuron/src/cuda/gdn.cu` — five kernels ported verbatim from
  upstream:
    * `gated_delta_rule_recurrence` (V-tiled per-token decode)
    * `chunked_gated_delta_rule_recurrence` (BT=64 chunked prefill)
    * `causal_conv1d_update` (single-token conv decode)
    * `causal_conv1d_full` (multi-token conv prefill)
    * `fused_gdn_gating` (beta = sigmoid(b); g = -exp(A_log) *
      softplus(a + dt_bias))
- `crates/neuron/src/cuda/gdn.rs` — Rust wrappers around the kernels,
  cudarc::CudaSlice::device_ptr boilerplate identical to upstream.
- `crates/neuron/src/cuda/ffi.rs` — `extern "C"` decls (subset of
  upstream's ffi.rs covering only the five GDN kernels; MoE / SSM /
  top-k decls land here when we absorb those too).
- `crates/neuron/src/cuda/mod.rs` — re-exports + module docs.
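
As a scalar reference for the `fused_gdn_gating` formula listed above, the following plain-Rust sketch computes `beta` and `g` for one head at one token. The function names and the softplus overflow cutoff of 20 are assumptions for illustration, not taken from the kernel source.

```rust
fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

/// softplus(x) = ln(1 + e^x), short-circuited for large x to avoid overflow.
fn softplus(x: f32) -> f32 {
    if x > 20.0 {
        x
    } else {
        x.exp().ln_1p()
    }
}

/// Returns `(beta, g)` where beta = sigmoid(b) and
/// g = -exp(A_log) * softplus(a + dt_bias).
fn gdn_gating(b: f32, a: f32, a_log: f32, dt_bias: f32) -> (f32, f32) {
    let beta = sigmoid(b);
    let g = -a_log.exp() * softplus(a + dt_bias);
    (beta, g)
}
```

Since `g` is never positive, `exp(g)` stays in `(0, 1]`, which is what makes it usable as a per-step decay factor on the recurrent state.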

Cargo wiring: `cudaforge` added as an optional build-dep, activated
by the `cuda` feature. CPU build is unchanged (the `cuda/` module is
fully `#[cfg(feature = "cuda")]`). The cuda feature build inside the
patched container compiles `gdn.cu` (1 of 1 kernels) and links
clean.

Licensing: upstream files preserve their MIT origin via per-file
comment banners pointing to the mistralrs path. No behaviour-relevant
edits to the .cu kernels — local diff against upstream is just the
banner. The `.rs` wrappers and `ffi.rs` subset are also from upstream;
their structure (module path `crate::cuda::ffi::*`) matches upstream
exactly so future kernel imports drop in unchanged.

CPU clippy + 32 lib tests pass; `cargo clippy --features cuda` clean
inside the runner container.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-21 11:34:11 +03:00
70eb6af42b feat(tp): cancellation-safe inference + structured tracing
All checks were successful
CI / Format (push) Successful in 30s
build-prerelease / Resolve version stamps (push) Successful in 35s
CI / Clippy (push) Successful in 2m14s
build-prerelease / Build neuron-blackwell (push) Successful in 3m44s
build-prerelease / Build cortex binary (push) Successful in 4m13s
CI / Test (push) Successful in 4m38s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Package cortex RPM (push) Successful in 1m23s
build-prerelease / Build neuron-ampere (push) Successful in 5m13s
build-prerelease / Build neuron-ada (push) Successful in 4m47s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m54s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 3m1s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m41s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m1s
Two changes addressing operator visibility into TP inference + the
HTTP-cancellation poisoning chain:

1. `chat_completion_tp` now runs its body inside `tokio::spawn`. When
   the HTTP client disconnects (curl --max-time, browser nav, etc.)
   the future returned from `chat_completion_tp` gets dropped, but
   the spawned task keeps running to completion — finishing every
   `pool.generate_step` / `pool.clear_kv_cache` to drain the worker
   pipes. The next inference request then finds a clean pool.

   Previously, a dropped future left workers still processing the
   in-flight request, so the next call's `ClearKvCache` recv would
   read the stale `GenerateStepOk` from the abandoned step ("rank N
   expected KvCacheCleared, got GenerateStepOk"). The
   drain-on-leader-error fix from d1a4aad covered Rust-side leader
   failures but not HTTP-layer cancellation, which is what we
   actually hit on the user's Qwen3.6 test.

2. Tracing throughout the TP path so journalctl shows where an
   inference spends its time without needing to surface harness
   internals via the HTTP error body:

   - `chat_completion_tp_inner` (now a free fn so it can run inside
     spawn): `info` at request start (prompt_len, max_new, temp,
     top_p, eos_id), `info` per major phase (prefill complete with
     elapsed_ms, decode complete with elapsed_ms + token count),
     `info` at completion (total_ms, finish_reason). `debug` for
     pool-lock acquisition + kv-cache clear timing. `trace` per
     decode step (next_token, step_ms).

   - `WorkerPool::generate_step` (leader side): `debug` at fan-out,
     `debug` after leader forward returns with elapsed_ms + ok flag,
     `debug` after drain with errors count + total_ms.

   - `WorkerPool::clear_kv_cache`: matching `debug` at fan-out + drain.

   - `worker::handle_generate_step`: `debug` at forward start + done
     with elapsed_ms, `warn` on forward failure with the full error.

The default log filter is already `info,neuron=debug` so the
operator gets every `info` and `debug` line by default; `trace`
needs RUST_LOG=trace for per-step decode timing.

Stage 7c-ii crash-detection is still future work; this is the
minimum that makes the "where did the 120s go" question answerable
from the logs.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-21 08:22:00 +03:00
d1a4aad91d fix(tp): always drain worker responses on leader failure
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 34s
CI / Format (push) Successful in 1m6s
CI / Clippy (push) Successful in 2m56s
build-prerelease / Build neuron-blackwell (push) Successful in 3m40s
CI / Test (push) Successful in 5m1s
build-prerelease / Build cortex binary (push) Successful in 4m36s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Package cortex RPM (push) Successful in 1m19s
build-prerelease / Build neuron-ampere (push) Successful in 4m29s
build-prerelease / Build neuron-ada (push) Successful in 4m51s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m55s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 3m9s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m44s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m4s
The TP-2 inference probe against Qwen3.6-27B surfaced:
    worker rank 1 ClearKvCache: expected KvCacheCleared, got
    GenerateStepOk

Caused by pipe poisoning. The previous shape of `generate_step`:

  for w in workers { w.send_only(GenerateStep) }   // 1. fan-out
  let logits = spawn_blocking(leader.forward)??;   // 2. early return on err
  for w in workers { w.recv_only() }               // 3. drain (skipped on 2's err)

When step 2 returned `Err` (e.g. a dtype mismatch we hadn't seen
before, an OOM, a downstream squeeze that didn't match the shape),
the function bailed before step 3 — but workers had already written
`GenerateStepOk` to their stdout pipes, since their forwards (and
the NCCL collectives inside) completed independently of the leader's
post-collective Rust-side work.

The next call (typically `ClearKvCache` at the start of the *next*
inference request) would then send a fresh request and read those
stale replies as if they were the new operation's. Once a pipe is
poisoned, every subsequent call surfaces the same shape of error
even though nothing's actually broken.

Fix: introduce two helpers in `tp/mod.rs`:

- `drain_workers(workers, check)` — reads exactly one response from
  every worker regardless of individual outcomes. Returns
  `Vec<String>` of `rank N: detail` strings for any non-OK reply.

- `combine_leader_workers(leader, worker_errs, op)` — folds the
  leader's `Result<Result<T>>` (the spawn_blocking shape) with the
  worker drain into a single `Result<T>`. Leader failure takes
  precedence but worker errors get appended so both halves surface.

`generate_step` and `clear_kv_cache` now use this pattern. Worst case:
both halves fail and the operator sees a combined error message;
either way the pipes are always drained so the next call's recv
matches the request it sent.
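
The two helpers can be sketched with std channels standing in for the worker pipes. This is an illustration under stated assumptions, not the code in `tp/mod.rs`: the `Reply` enum, the `String` error type (the real outer error comes from `spawn_blocking`), the rank numbering from 1, and the message wording are all hypothetical.

```rust
use std::sync::mpsc::Receiver;

#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
enum Reply {
    GenerateStepOk,
    KvCacheCleared,
    Error(String),
}

/// Reads exactly one reply from every worker and never returns early,
/// so no stale reply is left behind for the next request.
fn drain_workers(
    workers: &[Receiver<Reply>],
    check: impl Fn(&Reply) -> Result<(), String>,
) -> Vec<String> {
    let mut errs = Vec::new();
    for (i, rx) in workers.iter().enumerate() {
        let rank = i + 1; // assumption: rank 0 is the leader
        match rx.recv() {
            Ok(reply) => {
                if let Err(detail) = check(&reply) {
                    errs.push(format!("rank {rank}: {detail}"));
                }
            }
            Err(_) => errs.push(format!("rank {rank}: pipe closed")),
        }
    }
    errs
}

/// Folds the leader's nested result with the worker errors. A leader
/// failure comes first in the message; worker errors are appended.
fn combine_leader_workers<T>(
    leader: Result<Result<T, String>, String>,
    worker_errs: Vec<String>,
    op: &str,
) -> Result<T, String> {
    let workers = worker_errs.join("; ");
    match leader {
        Ok(Ok(value)) if worker_errs.is_empty() => Ok(value),
        Ok(Ok(_)) => Err(format!("{op}: worker errors: {workers}")),
        Ok(Err(e)) | Err(e) => {
            if worker_errs.is_empty() {
                Err(format!("{op}: leader failed: {e}"))
            } else {
                Err(format!("{op}: leader failed: {e}; worker errors: {workers}"))
            }
        }
    }
}
```

The key design point is ordering: the drain runs unconditionally before the leader's result is inspected, so an early `?` on the leader path can no longer skip it.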

Note: the model is still poisoned in the current state — the
operator needs to either `POST /models/unload` + reload, or
`systemctl restart neuron`, to recover. The fix prevents *future*
desync; it doesn't repair existing stale pipe state.

Stage 7c-ii crash detection was tracked as the canonical solution to
this class of issue; this is the minimum-viable subset.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-21 07:39:36 +03:00
95dc8745eb feat(stage-8c): TP-aware Qwen3-Next (tp_qwen3_5)
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 36s
CI / Format (push) Successful in 39s
CI / Clippy (push) Successful in 2m13s
build-prerelease / Build neuron-blackwell (push) Successful in 3m37s
CI / Test (push) Successful in 4m49s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Successful in 4m26s
build-prerelease / Build neuron-ampere (push) Successful in 5m18s
build-prerelease / Package cortex RPM (push) Successful in 7m6s
build-prerelease / Build neuron-ada (push) Successful in 5m13s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 3m2s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m55s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 5m39s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m1s
Adds `harness/tp/tp_qwen3_5.rs` — the tensor-parallel variant of the
Qwen3-Next architecture — plus the dispatch wiring needed to route a
load through it on both the leader and the workers.

Architecture pieces (all per-rank, follow `tp_qwen3.rs` patterns for
the full-attention layers + a new pattern for linear-attention):

- TpQwen3_5GatedDeltaNet: V-head-dim sharded. `num_v_heads / world_size`
  V-heads per rank, `num_k_heads / world_size` K-heads. `in_proj_z`,
  `in_proj_b`, `in_proj_a`, `A_log`, `dt_bias` shard uniformly along
  the V-head dim. `out_proj` is row-parallel + AllReduce (the only
  collective inside the block). The recurrent state shards 1:1 with
  V-heads — no cross-rank sync inside the delta-rule loop.

  `in_proj_qkv` and `conv1d.weight` are FUSED tensors with three
  regions along dim 0 (`[first key_dim, second key_dim, value_dim]`).
  Standard uniform-slicing doesn't align with the head boundaries —
  rank 0 would end up with `[first half of K_0, full K_1, first half
  of V]`. New `load_fused_qkv_slice_{2d,3d}` helpers load the full
  tensor, narrow per-region per-rank, and `Tensor::cat` the three
  slices into a per-rank fused weight. Transient peak of one full
  tensor per layer during construction; net memory is per-rank once
  the full tensor is dropped.

- TpQwen3_5Attention: column-parallel `q_proj` (the widened
  `2 * num_heads * head_dim` output, including the gate half — shards
  along the head axis so both query AND gate halves stay consistent
  per rank), `k_proj`, `v_proj`; row-parallel `o_proj` with AllReduce.
  Otherwise mirrors `tp_qwen3.rs`'s attention.

- TpQwen3_5MLP, TpQwen3_5DecoderLayer (dispatches on layer_types),
  TpQwen3_5Model (with `model.language_model.*` prefix), and
  TpQwen3_5ForCausalLM (with tied or separate `lm_head` at top level).
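
The per-region slicing for the fused `in_proj_qkv` / `conv1d.weight` tensors can be illustrated with index arithmetic alone. This sketch is plain Rust over row indices rather than candle tensors; `fused_qkv_row_ranges` and `gather_rows` are hypothetical names, and even divisibility of both dims by `world_size` is assumed.

```rust
use std::ops::Range;

/// Row ranges a rank owns inside a fused tensor laid out along dim 0 as
/// `[first key_dim, second key_dim, value_dim]`.
fn fused_qkv_row_ranges(
    key_dim: usize,
    value_dim: usize,
    rank: usize,
    world_size: usize,
) -> [Range<usize>; 3] {
    assert!(rank < world_size);
    assert!(key_dim % world_size == 0 && value_dim % world_size == 0);
    let k = key_dim / world_size;
    let v = value_dim / world_size;
    let first = rank * k; // inside the first key_dim region
    let second = key_dim + rank * k; // inside the second key_dim region
    let third = 2 * key_dim + rank * v; // inside the value_dim region
    [first..first + k, second..second + k, third..third + v]
}

/// Concatenates the three per-region slices into one per-rank fused block
/// (the role `Tensor::cat` plays in the commit).
fn gather_rows<T: Clone>(rows: &[T], ranges: &[Range<usize>; 3]) -> Vec<T> {
    ranges
        .iter()
        .flat_map(|r| rows[r.clone()].iter().cloned())
        .collect()
}
```

For example, with `key_dim = 4`, `value_dim = 8` and two ranks, rank 1 takes rows 2..4, 6..8 and 12..16, so each rank ends up with a consistent slice of all three regions.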

Dispatch wiring:

- New `tp::TpLeaderModel` enum holds either Qwen3 or Qwen3_5 variant.
  `WorkerPool::load_dense_shard` now dispatches on `model_type` from
  the config JSON and returns `Arc<Mutex<TpLeaderModel>>`. The two
  downstream methods (`generate_step`, `clear_kv_cache`) thread this
  enum through — the inner forward+clear_kv_cache dispatch happens
  via the enum's pub methods. Adding another TP architecture later is
  one more enum variant + match arms.

- Worker side gets a parallel `WorkerModel` enum + dispatch in
  `handle_load_dense_shard`, branching on the same `model_type`.

- Harness gate `TP_SUPPORTED_MODEL_TYPES` now `["qwen3", "qwen3_5"]`.
  `TpLoadedModel.leader_model` retyped to the enum.

Helpers in `arch/qwen3_5/linear_attn.rs`:
- `softplus` and `repeat_interleave` made `pub(crate)` so the TP
  module reuses them rather than duplicating.

Reuses unchanged: `Qwen3_5RmsNorm` (replicated weight), the gated
`Qwen3_5RmsNormGated` tail, `l2norm`, the `RotaryEmbedding` (partial
RoPE with `partial_rotary_factor` already correct).

CPU build + clippy + 32 lib tests pass; `cargo clippy --features cuda`
also clean inside the patched runner container.

One outstanding risk to call out: tensor names. For full-attention
layers the per-layer prefix is `model.language_model.layers.<i>.self_attn.*`
and for linear-attention layers `model.language_model.layers.<i>.linear_attn.*`
— the same as the single-GPU path. lm_head sits at the top level (not
under `language_model`) — consistent with the single-GPU path that
validated against Qwen3.5-0.8B.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 22:02:42 +03:00
495d3f7c05 fix(qwen3_5): promote beta to F32 alongside q/k/v in delta rule
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 40s
CI / Format (push) Successful in 43s
CI / Clippy (push) Successful in 2m20s
CI / Test (push) Successful in 4m33s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Successful in 4m19s
build-prerelease / Package cortex RPM (push) Successful in 1m25s
build-prerelease / Build neuron-blackwell (push) Successful in 3m39s
build-prerelease / Build neuron-ampere (push) Successful in 4m46s
build-prerelease / Build neuron-ada (push) Successful in 5m9s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m58s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 3m6s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m44s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m9s
The single-GPU dense load of Qwen/Qwen3.5-0.8B succeeded but the first
inference forward bombed with `dtype mismatch in mul, lhs: F32, rhs:
BF16`. Trace through the recurrent delta-rule loop:

  let q = (q.to_dtype(F32)? * scale)?;        // F32
  let k = k.to_dtype(F32)?;                    // F32
  let v = v.to_dtype(F32)?;                    // F32
  // g built from A_log/dt_bias                 // F32
  // beta = sigmoid(b)                          // BF16 (sigmoid preserves dtype)
  ...
  let delta = (v_t - kv_mem)?.broadcast_mul(&beta_col)?;
                ^^^^^^^^^^^^^                    ^^^^^^^^^
                F32                              BF16   ← mismatch

`g` was already F32 because it was constructed from `a_log.to_dtype(F32)`
+ `dt_bias.to_dtype(F32)` earlier in the function. `beta` came from
`sigmoid(b)` where `b` was the model dtype (BF16), so beta stayed BF16
and the multiplication tripped candle's dtype-mismatch check.

Promote beta to F32 at the same point we promote q/k/v.

Caught by the validate-neuron.sh probe against Qwen/Qwen3.5-0.8B on
beast — load returned 200, then `POST /v1/chat/completions` returned
the dtype error.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 21:13:19 +03:00
5c4c8e0eba fix(qwen3_5): tensor names are under model.language_model.*, not model.*
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 33s
CI / Format (push) Successful in 35s
CI / Clippy (push) Successful in 2m12s
build-prerelease / Build neuron-blackwell (push) Successful in 3m49s
CI / Test (push) Successful in 4m27s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-ampere (push) Successful in 4m50s
build-prerelease / Build neuron-ada (push) Successful in 5m12s
build-prerelease / Build cortex binary (push) Successful in 4m14s
build-prerelease / Package cortex RPM (push) Successful in 1m17s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m50s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m52s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m43s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 59s
Qwen3-Next is a multimodal architecture whose text core sits under
`model.language_model.*` — sibling to `model.visual.*` (vision tower)
and to top-level `lm_head` / `mtp.*`. Every text-side tensor in the
safetensors files carries that prefix:

  model.language_model.embed_tokens.weight
  model.language_model.layers.{i}.{input,post_attention}_layernorm.weight
  model.language_model.layers.{i}.linear_attn.{in_proj_*, conv1d.weight, A_log, dt_bias, norm.weight, out_proj.weight}
  model.language_model.layers.{i}.self_attn.{q,k,v,o}_proj.weight + {q,k}_norm.weight
  model.language_model.layers.{i}.mlp.{gate,up,down}_proj.weight
  model.language_model.norm.weight
  lm_head.weight              (top-level; not under language_model)

The single pre-emptive fix is in Qwen3_5Model::load — derive a
`text_vb = vb.pp("model.language_model")` once and walk
embed_tokens / layers / norm from there. `lm_head` stays at the
top-level VB; that path was already correct.

The non-text tensors (`model.visual.*`, `mtp.*`) are ignored: we
don't reference them, so the safetensors mmap is fine even though
the bytes are loaded into the address space.

After this, the load that was failing at
"cannot find tensor model.embed_tokens.weight" should proceed to
materialising the actual layer weights — where any further bugs
will be substantive architecture issues rather than naming ones.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 16:48:16 +03:00
07c44d5db1 fix(qwen3_5): nested rope_parameters + partial_rotary_factor=0.25
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 34s
CI / Format (push) Successful in 36s
CI / Clippy (push) Successful in 2m16s
CI / Test (push) Successful in 4m37s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Successful in 4m21s
build-prerelease / Build neuron-blackwell (push) Successful in 3m51s
build-prerelease / Package cortex RPM (push) Successful in 1m21s
build-prerelease / Build neuron-ampere (push) Successful in 5m2s
build-prerelease / Build neuron-ada (push) Successful in 5m8s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m55s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 3m0s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m40s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m11s
Two interlocked bugs surfaced trying to load Qwen/Qwen3.5-0.8B (and
the same applies to Qwen/Qwen3.6-27B):

1. Qwen3-Next config.json does NOT have a top-level `rope_theta`.
   It lives inside `rope_parameters: { rope_theta, partial_rotary_factor,
   rope_type, mrope_section, mrope_interleaved }`. Our TextConfig
   declared `rope_theta` as a non-optional top-level field, so the
   deserializer bailed with the misleading "missing field
   `rope_theta` at line 74 col 5".

   Replaced with a nested `RopeParameters` struct that mirrors the
   upstream shape. Defaults are conservative (rope_theta=10000,
   partial_rotary_factor=1.0) so a missing or partial block degrades
   to standard full-rotation RoPE rather than failing.

2. `partial_rotary_factor: 0.25` means only `head_dim * 0.25 = 64` of
   the 256 head_dim values get RoPE applied — the rest pass through
   unchanged. Our RotaryEmbedding was building the inv_freq table
   for the full head_dim and rotating everything. Silently wrong
   for every full-attention layer.

   `RotaryEmbedding` now derives `rotary_dim` from
   `head_dim * partial_rotary_factor`, builds its cos/sin tables at
   that smaller size, and in `apply()` splits q/k into (rotate, pass)
   on the last dim, applies `rope_slow` to the rotate slice only, and
   re-concatenates. Mirrors the reference Python's
   `apply_rotary_pos_emb` exactly for the non-trivial
   `partial_rotary_factor` case.
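
Both fixes can be sketched in plain Rust for a single head vector. This is not the candle-based `RotaryEmbedding`: there is no serde here, the struct carries only the two fields the defaults mention, and the rotate-half pairing `(i, i + rotary_dim/2)` is an assumption about the convention in use.

```rust
/// Subset of the nested `rope_parameters` block, with the conservative
/// defaults described above.
struct RopeParameters {
    rope_theta: f32,
    partial_rotary_factor: f32,
}

impl Default for RopeParameters {
    fn default() -> Self {
        Self { rope_theta: 10000.0, partial_rotary_factor: 1.0 }
    }
}

fn rotary_dim(head_dim: usize, p: &RopeParameters) -> usize {
    (head_dim as f32 * p.partial_rotary_factor) as usize
}

/// Rotates the first `rotary_dim` entries of one head vector at position
/// `pos` and passes the remaining entries through unchanged.
fn apply_partial_rope(x: &[f32], pos: usize, p: &RopeParameters) -> Vec<f32> {
    let rd = rotary_dim(x.len(), p);
    let half = rd / 2;
    let mut out = x.to_vec();
    for i in 0..half {
        let inv_freq = 1.0 / p.rope_theta.powf(2.0 * i as f32 / rd as f32);
        let (sin, cos) = (pos as f32 * inv_freq).sin_cos();
        let (a, b) = (x[i], x[i + half]);
        out[i] = a * cos - b * sin;
        out[i + half] = b * cos + a * sin;
    }
    out
}
```

With `head_dim = 256` and a factor of 0.25, only indices 0..64 are touched; with the default factor of 1.0 the same function degrades to full-rotation RoPE.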

Tests updated: config-deserialise fixture uses the real `rope_parameters`
shape (matching the Qwen3.6-27B and Qwen3.5-0.8B configs). The
linear-attention forward-smoke test was already using full rotation
which still works; just shifted to the nested struct.

After this, the load that previously failed at "parse Qwen3-Next
(qwen3_5) config.json: missing field rope_theta" should reach the
actual safetensors materialisation step.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 16:18:52 +03:00
e7eb3dab6a feat(stage-8c): full-attention layer + decoder + Model + ForCausalLM for qwen3_5
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 37s
CI / Format (push) Successful in 39s
CI / Clippy (push) Successful in 2m19s
CI / Test (push) Successful in 4m50s
build-prerelease / Build cortex binary (push) Successful in 4m21s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m41s
build-prerelease / Package cortex RPM (push) Successful in 1m27s
build-prerelease / Build neuron-ampere (push) Successful in 4m58s
build-prerelease / Build neuron-ada (push) Successful in 5m8s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m53s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m52s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m44s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 58s
Completes the single-GPU dense path for Qwen3-Next (Qwen3.6's
architecture). The four new modules wrap the substantive
`linear_attn.rs` (landed previously) with the rest of the
transformer:

- `arch/qwen3_5/rope.rs` — text-side rotary embedding. MRoPE is
  simplified to plain RoPE (the three position grids collapse to one
  for text-only inference); uses candle's `rope_slow` for the
  GLM-style rotate-half rotation.
- `arch/qwen3_5/mlp.rs` — Qwen3_5MLP (SwiGLU: gate/up/down, bias=False).
- `arch/qwen3_5/full_attn.rs` — Qwen3_5Attention with the two
  Qwen3-Next quirks:
  - `q_proj` widened to `2 * num_heads * head_dim`; second half
    sigmoid'd and multiplied into the attention output before `o_proj`.
  - q_norm/k_norm use the `(1+w)*x` RmsNorm variant.
- `arch/qwen3_5/decoder.rs` — Qwen3_5DecoderLayer dispatching on
  `layer_types[i]` to either Full attention or GatedDeltaNet.
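
The two full-attention quirks reduce to small element-wise operations. A plain-Rust sketch over slices follows; the `eps` parameter and the way the gate half is handed in as a separate slice are assumptions, since the commit does not spell out how the widened `q_proj` output is split.

```rust
fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

/// The `(1 + w) * x` RmsNorm variant: normalise by the root mean square,
/// then scale by one plus the learned weight.
fn rms_norm_one_plus_w(x: &[f32], w: &[f32], eps: f32) -> Vec<f32> {
    let mean_sq = x.iter().map(|v| *v * *v).sum::<f32>() / x.len() as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    x.iter()
        .zip(w)
        .map(|(v, wi)| *v * inv_rms * (1.0 + *wi))
        .collect()
}

/// Multiplies the attention output by the sigmoid of the gate half of
/// `q_proj`, as done before `o_proj`.
fn gate_attention_output(attn_out: &[f32], gate: &[f32]) -> Vec<f32> {
    attn_out
        .iter()
        .zip(gate)
        .map(|(o, g)| *o * sigmoid(*g))
        .collect()
}
```

One consequence of the `(1 + w)` form is that a zero weight behaves as an identity scale, unlike the plain `w * x` RmsNorm, where a zero weight would zero the output.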

`arch/qwen3_5/mod.rs` gets the real `Qwen3_5Model` (embedding + layer
stack + final norm) and `Qwen3_5ForCausalLM` (model + lm_head). The
forward returns `[B, 1, vocab]` to match `qwen3_dense`; the harness's
`squeeze_to_vocab` handles either shape.

Switch: `candle.rs::load_arch_dense` for `model_type=qwen3_5` now
builds a `ShardedVarBuilder` instead of a plain VarBuilder. The
sharded backend falls through to the unsharded path when
`world_size=1`, so single-GPU load is zero-cost; this lets the
forthcoming `tp_qwen3_5.rs` reuse the same load functions without a
second copy.

Verified: cargo build CPU + --features cuda inside the patched
container; clippy clean on both; 32 lib tests still pass. The
ForCausalLM forward no longer bails — but numerical correctness vs
the Python reference hasn't been validated yet (that's the next
step, with the Tbilisi probe).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 15:52:33 +03:00
180274548d feat(stage-8c): linear-attention layer (Qwen3-Next GatedDeltaNet)
All checks were successful
Implements the recurrent-path Gated DeltaNet block that occupies 48 of
Qwen3.6's 64 decoder layers (`layer_types[i] == "linear_attention"`).
Ported from `huggingface/transformers/models/qwen3_5/modeling_qwen3_5.py`
(`Qwen3_5GatedDeltaNet`, `torch_recurrent_gated_delta_rule`,
`Qwen3_5RMSNormGated`, `l2norm`).

Layout: `arch/qwen3_5.rs` becomes `arch/qwen3_5/` with submodules
- `mod.rs`         — Config + (still-stub) ForCausalLM
- `linear_attn.rs` — GatedDeltaNet + GatedDeltaNetState
- `rmsnorm.rs`     — Qwen3_5RmsNorm `(1+w)*x`, Qwen3_5RmsNormGated, l2norm

Architecture pieces in this commit:
- Block: in_proj_qkv + in_proj_z + in_proj_b + in_proj_a + out_proj
  (all bias=False); depthwise causal Conv1d (k=4) with state-aware
  prepend; SiLU; per-head reshape; L2norm on q,k.
- Discretisation: g = -exp(A_log) * softplus(a + dt_bias); beta = σ(b).
  All computed in f32 to avoid the -inf underflow in fp16 that the
  reference notes.
- Delta rule (recurrent, per-token):
    state *= exp(g_t)
    kv_mem = state^T · k_t
    delta  = (v_t - kv_mem) * beta_t
    state += outer(k_t, delta)
    out_t  = state^T · q_t
- Output: RMSNormGated(core_attn_out, z) → reshape → out_proj.
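
The per-token delta rule above can be written out for a single head on plain `Vec<f32>` storage. This is a sketch under stated assumptions: `DeltaState`, its row-major layout, and the scalar `g` / `beta` inputs are hypothetical simplifications, and the real layer works on batched tensors across all heads.

```rust
/// Recurrent memory for one head: a `dk x dv` matrix, stored row-major.
/// (Hypothetical type for illustration only.)
struct DeltaState {
    dk: usize,
    dv: usize,
    s: Vec<f32>,
}

impl DeltaState {
    fn new(dk: usize, dv: usize) -> Self {
        Self { dk, dv, s: vec![0.0; dk * dv] }
    }

    /// Computes `state^T · x`, contracting over the key dimension.
    fn read(&self, x: &[f32]) -> Vec<f32> {
        let mut out = vec![0.0; self.dv];
        for i in 0..self.dk {
            for j in 0..self.dv {
                out[j] += self.s[i * self.dv + j] * x[i];
            }
        }
        out
    }

    /// One token of the gated delta rule; `g` is the (non-positive) log decay.
    fn step(&mut self, q: &[f32], k: &[f32], v: &[f32], g: f32, beta: f32) -> Vec<f32> {
        let decay = g.exp();
        for s in self.s.iter_mut() {
            *s *= decay; // state *= exp(g_t)
        }
        let kv_mem = self.read(k); // state^T · k_t
        let delta: Vec<f32> = v
            .iter()
            .zip(&kv_mem)
            .map(|(vt, m)| (vt - m) * beta) // (v_t - kv_mem) * beta_t
            .collect();
        for i in 0..self.dk {
            for j in 0..self.dv {
                self.s[i * self.dv + j] += k[i] * delta[j]; // state += outer(k_t, delta)
            }
        }
        self.read(q) // out_t = state^T · q_t
    }
}
```

With a unit-norm key, `beta = 1` and `g = 0`, writing a new value under the same key replaces the old one, which is the "delta" behaviour that distinguishes this rule from a plain additive outer-product memory.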

State (`GatedDeltaNetState`) lives inline on the layer:
- conv_state: (B, conv_dim, conv_kernel_size) — left-padded tail.
- recurrent_state: (B, num_v_heads, head_k_dim, head_v_dim) — the
  delta-rule outer-product memory.
Cleared via `clear_kv_cache` at the start of every new request.

Config extended with the qwen3_5-specific fields:
- linear_num_value_heads (48 in Qwen3.6-27B)
- linear_num_key_heads   (16)
- linear_key_head_dim    (128)
- linear_value_head_dim  (128)
- linear_conv_kernel_dim (4)
- hidden_act             ("silu")

Performance note: this is the **recurrent** delta rule (the reference's
`torch_recurrent_gated_delta_rule`), correct for any seq_len but
needing O(L) sequential steps during prefill. The chunked algorithm
(`torch_chunk_gated_delta_rule`, chunk_size=64) is a follow-up perf
optimisation; the public surface stays the same.

8 unit tests:
- softplus small/large branches
- l2norm hand-calc + zero-vector stability
- repeat_interleave round-trip
- forward_smoke on tiny dims (4-head fixture) — verifies shape +
  no NaN/Inf propagation through the f32-promotion pipeline. Doesn't
  validate numerical correctness against the Python reference; that
  requires a fixed-weight fixture and is the next step.
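
The two numerically sensitive helpers covered by these tests can be sketched as follows. This is an illustration only: the threshold of 20 (borrowed from PyTorch's `softplus` default) and the `eps` parameter are assumptions, and the signatures do not come from the commit.

```rust
/// softplus(x) = ln(1 + e^x). For large x the result equals x to within
/// f32 precision, so the large branch returns x and avoids overflow in exp.
fn softplus(x: f32) -> f32 {
    if x > 20.0 {
        x
    } else {
        x.exp().ln_1p()
    }
}

/// L2-normalise a vector. `eps` keeps the zero vector finite.
fn l2norm(x: &[f32], eps: f32) -> Vec<f32> {
    let inv = 1.0 / (x.iter().map(|v| v * v).sum::<f32>() + eps).sqrt();
    x.iter().map(|v| v * inv).collect()
}
```

A hand-calculated case such as `[3, 4] -> [0.6, 0.8]` plus an all-zero input is enough to exercise both the normal path and the stability guard.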

cargo clippy CPU + --features cuda both clean; 32 lib tests pass.
The ForCausalLM stub still bails on forward — wrapping
attention/MLP/decoder layer + lm_head is the next sub-stage.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 09:29:52 +03:00
a70f317729 feat(stage-8c): scaffold qwen3_5 (Qwen3.6) — dispatch + stubs + TP gate
All checks were successful
Lays the wiring for the top-priority TP-2 target without doing the
substantive architecture work yet. After this commit, attempting to
load a Qwen3.6 (`model_type = "qwen3_5"`) model:
- Passes config.json parse — the real upstream shape (text_config
  wrapper, layer_types, attn_output_gate, head_dim=256, etc.) round-
  trips through a typed Config (unit test included).
- Constructs a placeholder Qwen3_5ForCausalLM, attaches it to a
  ModelArch::Qwen3_5Dense variant, registers it in the loaded set.
- Fails on the first inference forward with a clear "Qwen3-Next
  forward not implemented yet (Stage 8c, TP-2 motivator)" — the
  point where the real architecture work begins.

New layout:
- `harness/arch/` for custom architectures candle-transformers doesn't
  ship. Each architecture is one module: Config + ForCausalLM + impl.
- `harness/arch/qwen3_5.rs` — the scaffold. Heavy doc comments on the
  open work: layer_types dispatch (full_attention vs linear_attention,
  the latter being the hard part with no candle precedent),
  attn_output_gate, text_config nesting, recurrent state lifecycle.
- DENSE_SUPPORTED_MODEL_TYPES adds "qwen3_5"; load_arch_dense gains a
  branch that constructs the stub.

TP-side gate:
- New `check_tp_arch_supported`: Llama / Qwen3 MoE pass the
  single-GPU dense check (DENSE_SUPPORTED_MODEL_TYPES), but the
  worker pool's `load_dense_shard` reconstructs the config as Qwen3
  on every rank, so routing a non-Qwen3 dense load through it would
  surface only as a cryptic per-rank deserialise error.
- TP_SUPPORTED_MODEL_TYPES = ["qwen3"] (cuda-gated). Anything else
  bails *before* the worker pool spawns and NCCL handshake costs are
  paid, with a marker pointing at the `tp_<family>.rs` module a
  contributor would need to add. qwen3_5 specifically lands here
  until its architecture is real.

The naming choice: keep "qwen3_5" from the model's own config.json
rather than mistralrs's "qwen3_next" — the latter ages poorly the
moment Qwen ship another architecture revision.

Unit tests: 2 new for qwen3_5 (config deserialise + dispatch gate);
the previously-rejecting test for qwen3_5 swapped to a fictional
arch so it stays meaningful as the supported set grows. 26 lib tests
pass; cargo clippy CPU + --features cuda both clean.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 08:58:01 +03:00
c6022aa6b9 feat(stage-8b): Llama + Qwen3 MoE families on the candle harness
All checks were successful
Broadens the single-GPU dense and quantized paths to cover two further
architecture families already shipped by candle-transformers (Llama
and Qwen3 MoE). TP for these is a separate stage (each family would
need its own tp_*.rs mirroring tp_qwen3.rs).

`ModelArch` gains four variants:
- LlamaDense (boxed; wraps Llama + an inline Cache + the config it
  takes to rebuild the cache, since
  candle_transformers::models::llama::Cache has no reset)
- LlamaQuantized (candle_transformers::models::quantized_llama)
- Qwen3MoeDense
  (candle_transformers::models::qwen3_moe::ModelForCausalLM)
- Qwen3MoeQuantized
  (candle_transformers::models::quantized_qwen3_moe::GGUFQWenMoE;
  takes an explicit compute dtype, F16 by default for best
  consumer-GPU throughput)

The dispatch is method-based now:
- `ModelArch::forward(&mut self, input, offset) -> Result<Tensor>`
  with a shared `squeeze_to_vocab` normalising shape differences
  (qwen3 returns [B,1,V]; quantized_qwen3 returns [B,V]; new families
  may differ again — the helper handles all of them).
- `ModelArch::clear_kv_cache(&mut self) -> Result<()>`. Llama needs
  a Cache rebuild because its Cache has no in-place reset; the new
  `LlamaDense` wrapper holds the bits needed to do it.

`run_inference` / `run_inference_streaming` collapse to a single
dispatch path: no more per-variant match arms in the hot loop, and
new architectures pick up streaming + non-streaming for free with
zero changes outside `ModelArch`.
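
The method-based dispatch can be sketched with a shape-only stand-in for tensors. Everything here is hypothetical scaffolding (the `Logits` struct, the two variant names, the `String` error type); it only illustrates how one `forward` plus a shared `squeeze_to_vocab` keeps the hot loop free of per-variant match arms.

```rust
/// Shape-only stand-in for a logits tensor (not a candle type).
#[derive(Debug, Clone, PartialEq)]
struct Logits {
    shape: Vec<usize>,
    data: Vec<f32>,
}

/// Normalise `[B, 1, V]` and `[B, V]` to `[B, V]`; reject anything else.
fn squeeze_to_vocab(l: Logits) -> Result<Logits, String> {
    let shape = l.shape.clone();
    match shape.as_slice() {
        [_, _] => Ok(l),
        [b, 1, v] => Ok(Logits { shape: vec![*b, *v], data: l.data }),
        other => Err(format!("unexpected logits shape {:?}", other)),
    }
}

/// Two illustrative variants that disagree on output rank.
enum ModelArch {
    Dense3d { vocab: usize },
    Quantized2d { vocab: usize },
}

impl ModelArch {
    /// Single entry point used by both streaming and non-streaming callers.
    fn forward(&mut self, _input: &[u32], _offset: usize) -> Result<Logits, String> {
        let raw = match self {
            ModelArch::Dense3d { vocab } => Logits {
                shape: vec![1, 1, *vocab],
                data: vec![0.0; *vocab],
            },
            ModelArch::Quantized2d { vocab } => Logits {
                shape: vec![1, *vocab],
                data: vec![0.0; *vocab],
            },
        };
        squeeze_to_vocab(raw)
    }
}
```

Callers only ever see `[B, V]`, so adding a variant means extending the enum and its `forward` arm, with no changes to the inference loop.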

DENSE_SUPPORTED_MODEL_TYPES is now ["llama", "qwen3", "qwen3_moe"].
GGUF arch switch grows "qwen3moe" + "llama" branches (qwen3moe with
no underscore matches llama.cpp's general.architecture convention).
Stage 8a's diagnostic auto-reports the new supported set.

The `LlamaDense` variant is boxed because the wrapper's inline Cache
+ Config makes it 544 bytes vs ~300 for everything else
(clippy::large_enum_variant).
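
The effect of boxing the large variant can be checked with `size_of`. The payload types below are hypothetical byte arrays sized after the figures quoted above; they are not the real wrapper types.

```rust
/// Stand-in for the 544-byte Llama wrapper (hypothetical payload).
#[allow(dead_code)]
struct LargePayload([u8; 544]);

/// Stand-in for a ~300-byte variant payload (hypothetical).
#[allow(dead_code)]
struct TypicalPayload([u8; 300]);

/// Without boxing, every value of the enum is as large as the largest variant.
#[allow(dead_code)]
enum Unboxed {
    Large(LargePayload),
    Typical(TypicalPayload),
}

/// Boxing moves the large payload to the heap; the variant holds a pointer.
#[allow(dead_code)]
enum Boxed {
    Large(Box<LargePayload>),
    Typical(TypicalPayload),
}
```

Comparing `std::mem::size_of::<Unboxed>()` with `std::mem::size_of::<Boxed>()` shows the saving: the boxed enum is sized by the next-largest variant instead.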

Verified: cargo test --workspace passes 66 tests; cargo clippy CPU
and `--features cuda` both clean (the cuda check ran inside the
locally-built `neuron-build-local` container with the math_functions.h
patch applied).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 08:36:22 +03:00
9e31d8deca feat(stage-8a): pre-flight architecture check for dense model loads
Some checks failed
CI / Format (push) Successful in 32s
build-prerelease / Resolve version stamps (push) Successful in 34s
CI / Clippy (push) Successful in 2m21s
CI / Test (push) Successful in 4m27s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m50s
build-prerelease / Build cortex binary (push) Successful in 4m28s
build-prerelease / Package cortex RPM (push) Successful in 1m24s
build-prerelease / Build neuron-ada (push) Has been cancelled
build-prerelease / Package helexa-neuron-ada RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
build-prerelease / Build neuron-ampere (push) Has been cancelled
A request to load Qwen/Qwen3.6-27B (model_type "qwen3_5") on the
dense path was failing deep inside serde with:
    missing field `vocab_size` at line 140 column 1
…because Qwen3.6 wraps its actual hyperparameters under `text_config`,
so none of `qwen3::Config`'s expected top-level fields are present.
The error gave no hint that the *architecture* was the problem.

`check_dense_config_supported` parses `config.json` as an untyped
JSON Value, inspects `model_type` (with `architectures` as bonus
context), and bails cleanly when it's not in the supported set
(currently `["qwen3"]`). The error names the rejected type, the
supported set, and points at the files a contributor needs to touch
to extend coverage — both the single-process `ModelArch` variants in
`candle.rs` and the TP analogue in `tp_qwen3.rs`.
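
The accept/reject core of such a pre-flight check might look like the sketch below. It assumes the `model_type` string has already been pulled out of the untyped JSON; the function name `check_model_type` and the `String` error are hypothetical, and the real check also reports `architectures` and the files to touch.

```rust
/// Supported set at the time of this commit, per the message above.
const DENSE_SUPPORTED_MODEL_TYPES: &[&str] = &["qwen3"];

/// Reject unsupported or missing `model_type` values before any typed
/// deserialisation is attempted. The error names the rejected type and
/// the supported set.
fn check_model_type(model_type: Option<&str>, supported: &[&str]) -> Result<(), String> {
    match model_type {
        None => Err("config.json has no `model_type` field".to_string()),
        Some(t) if supported.contains(&t) => Ok(()),
        Some(t) => Err(format!(
            "unsupported model_type `{}`; supported: {:?}",
            t, supported
        )),
    }
}
```

Because the supported set is passed in and formatted into the error, extending the constant automatically updates the diagnostic.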

Wired into both load paths:
- `load_arch_dense` (single-GPU), before the typed deserialize.
- `load_tp`, before spawning the worker pool — TP loads of an
  unsupported arch now fail before NCCL/init costs are paid.

4 unit tests cover the accept/reject/missing-field/malformed cases.
Bonus: makes Stage 8b/8c work easier. Adding a new architecture is
now a `DENSE_SUPPORTED_MODEL_TYPES` edit + ModelArch variant + load
branch, with the diagnostic automatically listing the supported set.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 08:27:29 +03:00
b400e8b704 feat(neuron): honour HF_HUB_CACHE / HF_HOME for the candle harness cache
Some checks failed
build-prerelease / Resolve version stamps (push) Successful in 31s
build-prerelease / Build neuron-blackwell (push) Successful in 3m39s
build-prerelease / Build cortex binary (push) Successful in 4m17s
build-prerelease / Package cortex RPM (push) Successful in 1m22s
CI / Format (push) Successful in 32s
CI / Test (push) Failing after 51s
CI / Clippy (push) Successful in 2m17s
build-prerelease / Build neuron-ampere (push) Successful in 4m58s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-ada (push) Successful in 5m1s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 3m0s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 3m4s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m37s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m3s
Resolves the candle harness's HuggingFace cache directory with the
following precedence (first hit wins):

1. Explicit `hf_cache` in `[harness.candle]` from neuron.toml.
2. `HF_HUB_CACHE` env var — the Python `huggingface_hub` convention.
   The Rust hf-hub crate doesn't read this natively, so we bridge here.
3. `HF_HOME` env var (`$HF_HOME/hub` per the canonical layout).
4. None — falls through to hf-hub's own default.
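
The precedence chain above can be sketched as a pure function. The name `resolve_hf_cache` and the injected `env` lookup are assumptions made so the example is testable without touching the process environment; how the real code treats empty values is not stated in the commit and is not modelled here.

```rust
use std::path::PathBuf;

/// First hit wins: explicit config, then HF_HUB_CACHE, then $HF_HOME/hub.
/// `None` means "let hf-hub use its own default".
fn resolve_hf_cache(
    explicit: Option<&str>,
    env: impl Fn(&str) -> Option<String>,
) -> Option<PathBuf> {
    if let Some(p) = explicit {
        return Some(PathBuf::from(p));
    }
    if let Some(p) = env("HF_HUB_CACHE") {
        return Some(PathBuf::from(p));
    }
    if let Some(home) = env("HF_HOME") {
        return Some(PathBuf::from(home).join("hub"));
    }
    None
}
```

In a daemon the lookup would typically be `|k| std::env::var(k).ok()`, while tests can pass a closure over a fixed map.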

Honouring HF_HUB_CACHE lets a neuron host reuse an existing cache
directory shared with Python tooling or other harnesses on the same
host without per-tool config. The canonical per-host setup is a
systemd drop-in:

    /etc/systemd/system/neuron.service.d/local.conf
    [Service]
    Environment=HF_HUB_CACHE=/archive/hf-cache

neuron.example.toml documents the resolution chain inline.

script/validate-neuron.sh: bump LOAD_TIMEOUT from 600s to 3600s and
expose both load/infer timeouts via env (NEURON_LOAD_TIMEOUT,
NEURON_INFER_TIMEOUT). A Qwen3.6-class dense model is ~54 GB and was
hitting the 10-min ceiling cold-downloading on a residential link.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 07:52:50 +03:00
62ca125a68 chore: keep models.example.toml generic; deploy.sh syncs local models.toml
Some checks failed
build-prerelease / Resolve version stamps (push) Successful in 34s
CI / Format (push) Successful in 40s
CI / Clippy (push) Successful in 2m22s
CI / Test (push) Successful in 4m31s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Successful in 4m28s
build-prerelease / Build neuron-ampere (push) Has been cancelled
build-prerelease / Build neuron-ada (push) Has been cancelled
build-prerelease / Package cortex RPM (push) Has started running
build-prerelease / Package helexa-neuron-ada RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
build-prerelease / Build neuron-blackwell (push) Has been cancelled
Reverts the previous commit's naming of specific helexa neuron hosts
in the shipped example catalogue (`models.example.toml`) — the example
is supposed to be a generic starting point that any operator copies
and adapts, not a record of one particular fleet's layout.

- `pinned_on` in the TP example uses the placeholder
  `"your-multi-gpu-neuron"`. Other entries keep the model ids
  (since those are HuggingFace-canonical, not fleet-specific).
- New `models.toml` at repo root holds the helexa-fleet catalogue
  (beast / benjy / quadbrat). Added to `.gitignore` alongside
  `cortex.toml` — both are operator-owned, gitignored, RPM-marked
  `%config(noreplace)`, and synced by `deploy.sh`.
- `deploy.sh` now rsyncs `models.toml` to `/etc/cortex/models.toml`
  on the gateway host on the same lifecycle as `cortex.toml`. Skips
  cleanly when no local file exists, so users without a catalogue
  aren't surprised by silent overwrites.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 07:47:08 +03:00
735945ee81 feat(cortex): unified /v1/models — catalogue × topology feasibility + cold-load
Some checks failed
build-prerelease / Resolve version stamps (push) Successful in 45s
CI / Format (push) Successful in 48s
CI / Clippy (push) Successful in 2m12s
CI / Test (push) Successful in 4m42s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Successful in 5m10s
build-prerelease / Build neuron-blackwell (push) Successful in 3m35s
build-prerelease / Package cortex RPM (push) Successful in 1m19s
build-prerelease / Build neuron-ada (push) Has been cancelled
build-prerelease / Package helexa-neuron-ada RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
build-prerelease / Build neuron-ampere (push) Has been cancelled
Realises [project-unified-models-endpoint]: cortex now surfaces every
model the operator has provisioned in the catalogue, transparently
cold-loads on the first request, and routes the request once the load
is done — without per-node configuration or client awareness of which
neuron hosts what.

cortex-core changes:
- NodeState gains `discovery: Option<DiscoveryResponse>` — populated
  once per neuron on first successful poll, cached forever after
  (topology is invariant for a neuron process).
- ModelProfile gains `is_feasible_on(neuron, devices)` with the
  pinned_on / min_devices / min_device_vram_mb logic + 5 unit tests.
- CortexModelEntry expanded with the OpenAI-compatible fields (`id`,
  `object`, `created`, `owned_by`) plus helexa-specific extension
  fields (`loaded`, `feasible_on`, `locations`).

cortex-gateway changes:
- poller.rs: `maybe_poll_discovery` fetches `GET /discovery` once per
  neuron and caches on NodeState.
- handlers.rs::list_models rewritten as union of (catalogue × topology
  feasibility) + (currently loaded somewhere). Catalogue-defined models
  surface even when not yet loaded.
- router.rs::resolve gains priority 3 (catalogue cold-load):
    1. loaded somewhere → route there
    2. unloaded somewhere → route + lazy load via neuron
    3. in catalogue → pick feasible neuron, POST /models/load, wait,
       route. Cache the new entry locally so subsequent requests skip
       the poll wait.
    4. else 404
- pick_feasible_neuron prefers pinned_on neurons, falls back to any
  feasible one (stable by name).
- profile_to_spec translates ModelProfile → ModelSpec, picking devices
  by VRAM floor and setting tensor_parallel = min_devices for multi-
  device profiles.
- "already loaded" responses from neuron are tolerated (two concurrent
  requests racing the same cold-load is a benign outcome).
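
The four-step priority in `router.rs::resolve` can be summarised as a small decision function. The `View` struct and `Resolution` enum are hypothetical: they flatten the router's real lookups into three pre-computed options so that only the ordering is shown.

```rust
/// Outcome of routing a request for one model id (illustrative type).
#[derive(Debug, PartialEq)]
enum Resolution {
    RouteLoaded(String), // 1. already loaded on this neuron
    LazyLoad(String),    // 2. known but unloaded; neuron loads lazily
    ColdLoad(String),    // 3. catalogue entry; load on a feasible neuron
    NotFound,            // 4. 404
}

/// Pre-computed lookups for a single model id (hypothetical shape).
struct View<'a> {
    loaded_on: Option<&'a str>,
    unloaded_on: Option<&'a str>,
    catalogue_pick: Option<&'a str>,
}

/// Apply the priorities in order; the first match wins.
fn resolve(v: &View<'_>) -> Resolution {
    if let Some(n) = v.loaded_on {
        return Resolution::RouteLoaded(n.to_string());
    }
    if let Some(n) = v.unloaded_on {
        return Resolution::LazyLoad(n.to_string());
    }
    if let Some(n) = v.catalogue_pick {
        return Resolution::ColdLoad(n.to_string());
    }
    Resolution::NotFound
}
```

Keeping the cold-load branch last means a model that is already resident anywhere is never reloaded just because the catalogue also lists it.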

models.example.toml rewritten to reflect the canonical helexa fleet
(beast = 2x RTX 5090, benjy = RTX 4090, quadbrat = RTX 3060) with a
working TP example (Qwen3.6-27B pinned on beast) plus single-GPU
profiles for the smaller models.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 07:39:04 +03:00
f72dee094f feat(tp): Stage 7c-i — streaming SSE through TP
Some checks failed
build-prerelease / Package cortex RPM (push) Blocked by required conditions
build-prerelease / Resolve version stamps (push) Successful in 35s
CI / Format (push) Successful in 37s
CI / Clippy (push) Successful in 2m12s
CI / Test (push) Successful in 5m3s
build-prerelease / Build neuron-blackwell (push) Successful in 3m39s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Successful in 5m7s
build-prerelease / Build neuron-ada (push) Has been cancelled
build-prerelease / Package helexa-neuron-ada RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
build-prerelease / Build neuron-ampere (push) Has been cancelled
`chat_completion_stream` no longer returns an error for TP loads. The
new `chat_completion_tp_stream` mirrors the non-streaming TP path
(clear_kv_cache, prefill, sample, decode loop) but emits one
`ChatCompletionChunk` per generated token over an mpsc channel so the
handler can write a streaming SSE response.

Unlike the single-GPU streaming path (which runs candle's forward
inside `spawn_blocking` and uses `blocking_send`), the TP loop is
itself async — every `pool.generate_step` already awaits the leader's
own spawn_blocking forward plus every worker's recv_only. So the
orchestration runs as a plain `tokio::spawn` task using `Sender::send`.

The shared `emit_chunk` helper tracks the cumulative decoded prefix and
emits the delta — same UTF-8-safe BPE boundary handling as the
single-GPU streaming path.
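
The prefix-tracking idea behind `emit_chunk` can be illustrated at the byte level. This is an analogue and not the helper itself: `ChunkEmitter` is a hypothetical type that accumulates raw token bytes, whereas the real helper tracks the cumulative decoded string from the tokenizer.

```rust
/// Accumulates generated bytes and emits only complete UTF-8 deltas.
struct ChunkEmitter {
    bytes: Vec<u8>,
    emitted: usize, // number of bytes already sent to the client
}

impl ChunkEmitter {
    fn new() -> Self {
        Self { bytes: Vec::new(), emitted: 0 }
    }

    /// Append one token's bytes; return the new text, if any is complete.
    /// Note: a genuinely invalid byte (not merely an incomplete sequence)
    /// would stall this sketch; a real implementation must handle that.
    fn push(&mut self, token_bytes: &[u8]) -> Option<String> {
        self.bytes.extend_from_slice(token_bytes);
        let pending = &self.bytes[self.emitted..];
        let valid_len = match std::str::from_utf8(pending) {
            Ok(s) => s.len(),
            Err(e) => e.valid_up_to(),
        };
        if valid_len == 0 {
            return None; // multi-byte character still incomplete
        }
        let delta = String::from_utf8_lossy(&pending[..valid_len]).into_owned();
        self.emitted += valid_len;
        Some(delta)
    }
}
```

A multi-byte character split across two tokens is held back until its final byte arrives, so the client never receives a broken code point.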

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 07:32:46 +03:00
d46d8d4f6c feat(tp): Stage 7b-iv — RPC + orchestration for TP load/inference
All checks were successful
Wires the in-flight TP machinery (Stage 7a workers, 7b-iii sharded
Qwen3) end to end so a non-streaming chat completion can run across
multiple GPUs via NCCL.

RPC additions (tp/rpc.rs):
- LoadDenseShard{model_id, config_json, safetensors_paths}
- GenerateStep{model_id, tokens, offset}
- ClearKvCache{model_id}
- UnloadModel{model_id}
- LoadDenseShardOk / GenerateStepOk / KvCacheCleared / Unloaded

Worker side (tp/worker.rs):
- WorkerState gains a `models: HashMap<String, TpQwen3ForCausalLM>`
  keyed by model_id. LoadDenseShard mmaps safetensors via
  ShardedVarBuilder (only this rank's slice materialises), builds the
  TP model with the rank's NCCL Comm cloned from NcclState.
- GenerateStep runs the rank-local forward; the resulting logits are
  dropped (only the leader's are used for sampling). The forward's
  value here is the NCCL collectives inside the row-parallel layers
  letting the leader's rank-0 forward make progress.

Pool side (tp/mod.rs):
- WorkerPool::load_dense_shard fans LoadDenseShard out to every worker,
  builds rank 0's shard on the leader via spawn_blocking with a fresh
  SendComm wrapper at the move boundary (Comm is !Send at the type
  level), collects per-rank LoadDenseShardOk. Returns the leader's
  Arc<Mutex<TpQwen3ForCausalLM>>.
- WorkerPool::generate_step fans GenerateStep out, runs the leader's
  rank-0 forward in spawn_blocking (the AllReduce CustomOps inside
  row-parallel layers block until every worker issues the matching
  collective), returns the leader's last-position logits Tensor.
- WorkerPool::clear_kv_cache + unload_model follow the same pattern.

NcclState refactor (tp/nccl_state.rs):
- comm field becomes Option<Arc<Comm>> (was Option<Comm>) so callers
  can share a clone with TpQwen3ForCausalLM::load.
- new `comm()` accessor + `SendComm` wrapper for spawn_blocking moves.
- single allow(clippy::arc_with_non_send_sync) at the canonical
  construction site (Comm is !Send by type but the runtime invariant
  is enforced by SendComm + the pool's Mutex).

Harness side (candle.rs):
- LoadedHandle enum (Single | Tp) replaces the bare Arc<LoadedModel>
  in the harness's registry. list_models / unload_model /
  inference_endpoint walk the enum uniformly.
- TpLoadedModel holds the pool + leader_model + tokenizer + devices.
- load_model dispatches on `spec.tensor_parallel > 1` to a new
  cuda-gated load_tp path: resolve dense files via hf-hub, spawn the
  pool, init_nccl, load_dense_shard.
- chat_completion branches on the handle variant. The TP path mirrors
  run_inference: clear_kv_cache, prefill, sample, decode loop,
  detokenize. Acquires the pool Mutex for the whole request.
- Streaming through TP is deferred to Stage 7c (returns Other(err)).

Script (script/validate-neuron.sh):
- 4th positional arg `tp_size` (default 1). When >1, switches to the
  dense path (TP and GGUF are mutually exclusive, so that combination
  bails) and adds `tensor_parallel` + `devices` to the load payload.
  The NEURON_DEVICES env var overrides the default 0..N-1 device list.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 06:38:33 +03:00
9b8bd146f6 feat(tp): --tp-smoke CLI subcommand + remote validation script
All checks were successful
Adds a one-shot diagnostic that exercises the lower half of the TP
stack — WorkerPool::spawn, init_nccl, nccl_sanity_check — in isolation
from model load and inference. Runs N-1 worker subprocesses (rank 0
stays in this process), joins them in an NCCL communicator on the
specified CUDA devices, all_reduces a sentinel 1u32 per rank, verifies
the observed_sum equals world_size on every rank, then shuts down.

Output is `status=ok` on stdout (plus key=value lines for tp_size and
cuda_devices) when every check passes, non-zero exit + tracing on
stderr otherwise. The smoke command is diagnostic-only and not exposed
through the daemon HTTP API.

script/tp-smoke.sh wraps it with an ssh invocation against a fleet
host (default beast — the only host with 2 GPUs) and asserts the
status line, mirroring the validate-neuron.sh ergonomics.

This is step 1 of the TP test plan. A failure here means TP cannot
work on the host at all; step 2 (Stage 7b-iv) wires real model load
and inference through the same WorkerPool primitives.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 19:40:25 +03:00
96d8755245 fix(tp): add half dep + drop double-wrapped .w() on CudaDevice::alloc
All checks were successful
Two follow-up cuda-only fixes surfaced by `cargo build --features cuda`
inside the cuda-13.0 runner container:

1. `half::{bf16, f16}` was an undeclared dep. Added `half = "2.5"`
   (matching candle-core's pinned major) under the cuda feature flag.
2. `dev.alloc::<T>(n)` already returns `candle_core::Result` (it calls
   `.w()` internally on the cudarc error). Calling `.w()?` on top of
   that needs `From<candle_core::Error> for CudaError`, which doesn't
   exist — collapse to `?`. Removed the now-unused
   `cuda_backend::WrapErr` import.

Verified by `cargo build -p neuron --features cuda` and
`cargo clippy -p neuron --all-targets --features cuda -- -D warnings`
inside `git.lair.cafe/gongfoo/runner-cuda-13.0` with the local
glibc/CUDA-13.0 math_functions.h noexcept patch. CPU clippy/tests stay
green.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 19:11:59 +03:00
12549c9aed fix(tp): import BackendStorage trait for CudaStorage methods
Some checks failed
build-prerelease / Resolve version stamps (push) Successful in 32s
CI / Format (push) Successful in 37s
CI / Clippy (push) Successful in 3m9s
CI / Test (push) Successful in 4m28s
build-prerelease / Build neuron-blackwell (push) Failing after 3m41s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Successful in 4m32s
build-prerelease / Package cortex RPM (push) Successful in 1m23s
build-prerelease / Build neuron-ampere (push) Failing after 4m45s
build-prerelease / Build neuron-ada (push) Failing after 5m13s
build-prerelease / Package helexa-neuron-ada RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been skipped
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been skipped
Stage 7b-iii (1/2) introduced AllReduce with `s.device()` and
`s.dtype()` calls on `&CudaStorage`. Both come from the
`candle_core::backend::BackendStorage` trait, which wasn't imported —
fine on CPU builds (the cuda_fwd block was cfg-gated out) but the
prerelease cuda build hit E0599.
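
A minimal sketch of the failure mode, using hypothetical names (`backend::Storage`, `backend::CudaLike`, `describe`) rather than the real candle types: trait methods are only callable when the trait is in scope, so the missing `use` surfaces as E0599 only in builds that actually compile the call site.

```rust
mod backend {
    // Hypothetical trait playing the role of a storage-backend trait.
    pub trait Storage {
        fn dtype(&self) -> &'static str;
    }

    pub struct CudaLike;

    impl Storage for CudaLike {
        fn dtype(&self) -> &'static str {
            "bf16"
        }
    }
}

// Without this `use`, `s.dtype()` below fails with E0599
// ("no method named `dtype` found"), even though the impl exists.
use backend::Storage;

fn describe(s: &backend::CudaLike) -> &'static str {
    s.dtype()
}
```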

Also drop the unused `cudarc::driver::DeviceSlice` import inside
cuda_fwd — `CudaSlice::len()` is an inherent method on cudarc 0.19,
not a trait method.

Caught by run 2894 (build-neuron-{blackwell,ampere}); CPU clippy +
tests stay green.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 18:32:05 +03:00
46527d7804 feat(tp): TP-aware Qwen3 dense model (Stage 7b-iii 2/2)
Mirrors candle_transformers::models::qwen3 structurally with column-
parallel q/k/v + gate/up projections, row-parallel o + down projections,
and replicated embedding/norms/lm_head. Per-rank head counts come from
dividing num_attention_heads / num_key_value_heads by world_size at load
time; intermediate_size split likewise. Load bails on any non-divisible
shape — the safetensors slice would lose data otherwise.
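
The per-rank split described above amounts to an exact division with a hard failure on any remainder. A minimal sketch, assuming a hypothetical helper `per_rank` and an illustrative error message (neither is taken from the commit):

```rust
/// Split a model dimension evenly across ranks, refusing shapes that do
/// not divide. The error text is illustrative only.
fn per_rank(name: &str, total: usize, world_size: usize) -> Result<usize, String> {
    if world_size == 0 || total % world_size != 0 {
        return Err(format!(
            "{name}={total} is not divisible by world_size={world_size}"
        ));
    }
    Ok(total / world_size)
}
```

With illustrative values, 16 attention heads over 2 ranks gives 8 heads per rank, while 8 key/value heads over 3 ranks is rejected at load time instead of silently dropping data.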

KV cache holds the rank-local slice since K/V come out of column-parallel
projections; no cache resharding across ranks. Causal mask is computed
on rank 0 shape and broadcasts over the head dim so per-rank H differs
without rework.

Replicated tensors (embedding, all RmsNorms, untied lm_head) load via
vb.get(shape, name), which uses the default Shard { world_size: 1 } and
falls through to the unsharded backend path on ShardedSafeTensors.

The cuda / non-cuda load splits track the existing tp_linear pattern:
RowParallelLinear takes an Arc<Comm> only under cuda, and the higher-
level composers (TpQwen3MLP, TpQwen3Attention, TpDecoderLayer,
TpQwen3Model, TpQwen3ForCausalLM) thread it through accordingly.

7b-iv wires RPC + dispatch in CandleHarness::load_model.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 18:24:20 +03:00
8d3194f992 Stage 7b-iii (1/2): AllReduce CustomOp + ShardedVarBuilder-backed TP linears
Some checks failed
build-prerelease / Resolve version stamps (push) Successful in 35s
CI / Format (push) Successful in 38s
CI / Clippy (push) Successful in 2m16s
build-prerelease / Build neuron-blackwell (push) Failing after 3m19s
CI / Test (push) Successful in 4m26s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Successful in 4m22s
build-prerelease / Package cortex RPM (push) Successful in 1m23s
build-prerelease / Build neuron-ampere (push) Failing after 4m58s
build-prerelease / Build neuron-ada (push) Failing after 4m53s
build-prerelease / Package helexa-neuron-ada RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been skipped
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been skipped
Ports the canonical
candle-examples/examples/llama_multiprocess/model.rs pattern into
the harness. Two new files, one deletion:

- harness/tp/all_reduce.rs — AllReduce wraps Arc<cudarc::nccl::Comm>
  and implements candle's CustomOp1 trait. cuda_fwd extracts the
  rank's CudaSlice<dtype> from a CudaStorage, asserts the input is
  contiguous (a strided activation hitting all_reduce is almost
  always a model construction bug), allocates an output CudaSlice
  on the same device, calls Comm::all_reduce(Sum), and wraps the
  result back as a CudaStorage. Handles BF16, F16, F32. NcclError
  surfaces via {e:?} (no Display impl in cudarc 0.19.x). Send/Sync
  hand-impl'd with the same NCCL-thread-safety caveat candle's
  example documents.

- harness/tp/tp_linear.rs — ColumnParallelLinear and
  RowParallelLinear, both built on candle's ShardedVarBuilder +
  Shard hints. `vb.get_with_hints((), "weight", shard(dim, rank, ws))`
  reads JUST the rank's slice from the safetensors view; no full-
  tensor host materialisation. ColumnParallel.forward is a plain
  local matmul (output is naturally sharded). RowParallel.forward =
  local matmul + apply_op1_no_bwd(&self.all_reduce). On CPU /
  world_size == 1, the AllReduce is skipped and the partial output
  is returned as-is. Both layers are no-bias — every Qwen3-family
  target sets attention_bias=false; bias-aware sharding is a
  future-model concern.

- Deletes harness/tp/sharded_linear.rs from 7b-ii. That commit's
  hand-rolled "load full + narrow" approach was useful exploration
  but candle's ShardedVarBuilder does the same work without
  materialising the full tensor on host. The 5 unit tests there
  verified the slicing math against an unsharded reference; that
  math now lives inside candle and is covered by candle's own tests.
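
The row-parallel identity behind this design can be checked on the CPU with plain vectors. The sketch below is an assumption-laden model, not the harness code: `matvec`, `row_shard`, `all_reduce_sum` and `row_parallel_forward` are hypothetical names, the "all-reduce" is a local element-wise sum standing in for NCCL, and the weight is assumed non-empty with an input dimension divisible by the world size.

```rust
/// y = W x for a row-major `out x in` matrix.
fn matvec(w: &[Vec<f32>], x: &[f32]) -> Vec<f32> {
    w.iter()
        .map(|row| row.iter().zip(x).map(|(a, b)| a * b).sum::<f32>())
        .collect()
}

/// Row-parallel shard: rank `r` keeps input columns [r*in/N, (r+1)*in/N).
fn row_shard(w: &[Vec<f32>], rank: usize, world_size: usize) -> Vec<Vec<f32>> {
    let chunk = w[0].len() / world_size;
    w.iter()
        .map(|row| row[rank * chunk..(rank + 1) * chunk].to_vec())
        .collect()
}

/// Stand-in for all_reduce(Sum): element-wise sum of the per-rank partials.
fn all_reduce_sum(partials: &[Vec<f32>]) -> Vec<f32> {
    let mut out = vec![0.0; partials[0].len()];
    for p in partials {
        for (o, v) in out.iter_mut().zip(p) {
            *o += v;
        }
    }
    out
}

/// Each rank multiplies its weight slice by its input slice; the partial
/// outputs are then summed across ranks.
fn row_parallel_forward(w: &[Vec<f32>], x: &[f32], world_size: usize) -> Vec<f32> {
    let chunk = x.len() / world_size;
    let partials: Vec<Vec<f32>> = (0..world_size)
        .map(|r| matvec(&row_shard(w, r, world_size), &x[r * chunk..(r + 1) * chunk]))
        .collect();
    all_reduce_sum(&partials)
}
```

With `world_size == 1` the sum has a single term, which matches the commit's note that the reduction can be skipped in that case.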

Next (7b-iii 2/2): TpQwen3Attention + TpQwen3MLP composing the
column/row pair, then a TpQwen3Model that runs the full forward.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 18:14:54 +03:00
5436af9c73 fix(neuron/candle): dense Qwen3 returns rank-3 logits, double-squeeze
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 33s
CI / Format (push) Successful in 38s
CI / Clippy (push) Successful in 2m19s
build-prerelease / Build neuron-blackwell (push) Successful in 3m32s
CI / Test (push) Successful in 4m34s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Successful in 4m16s
build-prerelease / Package cortex RPM (push) Successful in 1m18s
build-prerelease / Build neuron-ampere (push) Successful in 4m55s
build-prerelease / Build neuron-ada (push) Successful in 5m11s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m50s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m52s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m35s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m0s
Caught by live validation against Qwen/Qwen3-1.7B on beast:
  HTTP 500 "unexpected rank, expected: 1, got: 2 ([1, 151936])"

Candle's qwen3::ModelForCausalLM::forward returns shape [B, 1, V]
(no final squeeze) while quantized_qwen3::ModelWeights::forward
returns [B, V] (with squeeze(1) at the end). My match arms applied
a single squeeze(0) uniformly, which is correct for the quantized
[1, V] → [V] but leaves the dense [1, 1, V] at [1, V], which then
trips apply_repeat_penalty::to_vec1(), which expects rank 1.

Dense match arms now strip both batch and seq dims:
  model.forward(&input, offset)?.squeeze(0)?.squeeze(0)?
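
The shape arithmetic can be traced with a shape-only model. `squeeze` here is a hypothetical helper operating on a `Vec<usize>`; it rejects non-unit dimensions, which is a simplification and may differ from how candle's own `squeeze` treats them.

```rust
/// Drop dimension `dim` from a shape, provided it has size 1.
/// Shape-only sketch; no tensor data is involved.
fn squeeze(shape: &[usize], dim: usize) -> Result<Vec<usize>, String> {
    if shape.get(dim) != Some(&1) {
        return Err(format!("cannot squeeze dim {dim} of shape {shape:?}"));
    }
    let mut out = shape.to_vec();
    out.remove(dim);
    Ok(out)
}
```

For a batch of one, the dense `[1, 1, V]` output needs two `squeeze(0)` calls to reach `[V]`, whereas the quantized `[1, V]` output needs one.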

Also fixes validate-neuron.sh's `${3:-Q4_K_M}` → `${3-Q4_K_M}`
(no colon) so passing an explicit empty third arg now drives the
dense path instead of falling back to Q4_K_M.
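
The difference between the two shell expansions can be modelled with an `Option<&str>`, where `None` means the positional argument is unset and `Some("")` means it was passed explicitly as empty. The function names `colon_dash` and `dash` are hypothetical; this is a sketch of the expansion semantics, not of the script.

```rust
/// Models `${3:-default}`: fall back when the argument is unset OR empty.
fn colon_dash<'a>(arg: Option<&'a str>, default: &'a str) -> &'a str {
    match arg {
        Some(s) if !s.is_empty() => s,
        _ => default,
    }
}

/// Models `${3-default}`: fall back only when the argument is unset.
fn dash<'a>(arg: Option<&'a str>, default: &'a str) -> &'a str {
    arg.unwrap_or(default)
}
```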

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 17:49:43 +03:00
8e882c0757 fix(neuron/tp): NcclError {e:?} + cudarc 0.19 deprecation cleanup
All checks were successful
CI / Format (push) Successful in 38s
build-prerelease / Resolve version stamps (push) Successful in 40s
CI / Clippy (push) Successful in 2m15s
build-prerelease / Build neuron-blackwell (push) Successful in 3m35s
CI / Test (push) Successful in 5m0s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Successful in 4m51s
build-prerelease / Package cortex RPM (push) Successful in 1m27s
build-prerelease / Build neuron-ampere (push) Successful in 4m55s
build-prerelease / Build neuron-ada (push) Successful in 4m57s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m50s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m50s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m37s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m2s
Two cuda-feature-only build errors only the CI runner catches:

1. cudarc::nccl::NcclError doesn't impl Display in 0.19.x, so the
   `format!("...: {e}")` map_err calls fail to compile when the cuda
   feature actually wires them up. Switch every NcclError-typed `{e}`
   in nccl_state.rs to `{e:?}` — surfaces variant + ncclResult code
   in the same diagnostic shape just via Debug instead of Display.
2. cudarc::CudaStream::memcpy_stod / memcpy_dtov are deprecated in
   0.19.7 in favour of clone_htod / clone_dtoh. The replacements
   take/return the same types, so the swap is mechanical.
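
The first error is ordinary Rust formatting behaviour and can be reproduced without cudarc. `FakeNcclError` and `describe` are hypothetical; the type simply derives `Debug` and has no `Display` impl, which is the situation the commit describes.

```rust
// Hypothetical error type: Debug is derived, Display is not implemented.
#[derive(Debug)]
struct FakeNcclError {
    code: i32,
}

fn describe(e: FakeNcclError) -> String {
    // `format!("all_reduce failed: {e}")` would not compile here, because
    // `{e}` requires `Display`; `{e:?}` only requires `Debug`.
    format!("all_reduce failed: {e:?}")
}
```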

Dev box can't compile with --features cuda (no nvcc), so these only
surface in the build-prerelease CUDA matrix jobs.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 17:24:13 +03:00
93421f48e2 Stage 7b-ii: ColumnParallel + RowParallel sharded linear primitives
Some checks failed
build-prerelease / Resolve version stamps (push) Successful in 30s
CI / Format (push) Successful in 31s
CI / Clippy (push) Failing after 49s
build-prerelease / Build neuron-blackwell (push) Failing after 3m29s
build-prerelease / Build cortex binary (push) Successful in 4m41s
CI / Test (push) Successful in 5m6s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Package cortex RPM (push) Successful in 1m20s
build-prerelease / Build neuron-ampere (push) Failing after 5m1s
build-prerelease / Build neuron-ada (push) Failing after 4m53s
build-prerelease / Package helexa-neuron-ada RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been skipped
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been skipped
Adds harness/tp/sharded_linear.rs with ShardedLinear — a Megatron-LM
style sharded wrapper over candle_nn::Linear. Two constructors:

- load_column: splits the output dimension. Each rank holds rows
  [r*out/N .. (r+1)*out/N] of the weight, plus its slice of the bias.
  Forward = local matmul; output is naturally sharded; downstream
  consumer either accepts the shard (next layer is column-parallel)
  or merges via all-gather later.
- load_row: splits the input dimension. Each rank holds cols
  [r*in/N .. (r+1)*in/N] of the weight; bias lives only on rank 0
  so the post-all_reduce sum carries it exactly once. Forward
  produces a partial output that the caller reduces via NCCL.

Both constructors bail with a clear error when divisibility doesn't
hold — the precondition mistral.rs's first qwen3-next-tp commit
made explicit. The path included in the error is the VarBuilder
prefix, so the operator sees exactly which projection failed
("column-parallel 'model.layers.0.self_attn.q_proj': out_features=...").

5 unit tests on CPU verify the math against an unsharded reference:
- column shard produces the expected slice of the full matmul
- row partials sum to the unsharded result
- row bias appears only on rank 0
- divisibility violations bail (column + row)
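
The slicing rules that these tests exercise can be written down in a few lines. This is a sketch under assumptions: `column_shard_rows` and `row_bias_for_rank` are hypothetical names, and the error text is illustrative rather than the message the harness emits.

```rust
use std::ops::Range;

/// Rows [r*out/N, (r+1)*out/N) of the weight held by `rank` under
/// column parallelism. Bails when the output dimension does not divide.
fn column_shard_rows(
    prefix: &str,
    out_features: usize,
    rank: usize,
    world_size: usize,
) -> Result<Range<usize>, String> {
    if world_size == 0 || out_features % world_size != 0 {
        return Err(format!(
            "column-parallel '{prefix}': {out_features} rows do not split across {world_size} ranks"
        ));
    }
    let chunk = out_features / world_size;
    Ok(rank * chunk..(rank + 1) * chunk)
}

/// Row-parallel bias: only rank 0 contributes it, so the sum over ranks
/// counts the bias exactly once.
fn row_bias_for_rank(bias: &[f32], rank: usize) -> Vec<f32> {
    if rank == 0 {
        bias.to_vec()
    } else {
        vec![0.0; bias.len()]
    }
}
```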

forward_with_comm() is stubbed for row-parallel (CUDA-only) — wiring
the actual cudarc::nccl all_reduce against candle's Tensor lands in
7b-iii alongside the model assembly, where the model holds the Comm
in scope. ColumnParallel's forward_with_comm just delegates to the
local matmul (no collective needed).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 17:07:19 +03:00
05e15f3597 Stage 7b-i: dense safetensors Qwen3 load path
Some checks failed
build-prerelease / Build cortex binary (push) Blocked by required conditions
CI / Test (push) Waiting to run
CI / Format (push) Successful in 43s
build-prerelease / Resolve version stamps (push) Successful in 44s
CI / Clippy (push) Successful in 2m4s
build-prerelease / Build neuron-ampere (push) Has been cancelled
build-prerelease / Build neuron-ada (push) Has been cancelled
build-prerelease / Package cortex RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ada RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been cancelled
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been cancelled
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been cancelled
CI / Build cortex SRPM (push) Has been cancelled
CI / Build neuron SRPM (push) Has been cancelled
CI / Publish cortex to COPR (push) Has been cancelled
CI / Publish neuron to COPR (push) Has been cancelled
CI / Bump version in source (push) Has been cancelled
build-prerelease / Build neuron-blackwell (push) Has been cancelled
Adds the bf16/fp16 safetensors path alongside the existing GGUF
quantized one. The harness now dispatches by ModelSpec.quant:
- Some(_) → GGUF (pre-quantized, single-GPU only path, unchanged).
- None    → safetensors dense (new).
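
The dispatch rule is a plain match on an optional field. In this sketch `ModelSpec` is reduced to the single `quant` field mentioned above, and `LoadPath` and `dispatch` are hypothetical names introduced only for illustration.

```rust
#[derive(Debug, PartialEq)]
enum LoadPath {
    /// Pre-quantized GGUF, carrying the requested quantization label.
    Gguf(String),
    /// Dense bf16/fp16 safetensors.
    DenseSafetensors,
}

/// Reduced stand-in for the real spec: only the field that drives dispatch.
struct ModelSpec {
    quant: Option<String>,
}

fn dispatch(spec: &ModelSpec) -> LoadPath {
    match &spec.quant {
        Some(q) => LoadPath::Gguf(q.clone()),
        None => LoadPath::DenseSafetensors,
    }
}
```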

The dense path uses candle-transformers::models::qwen3::ModelForCausalLM
verbatim, fed via VarBuilder::from_mmaped_safetensors over the files
listed in `model.safetensors.index.json` (sharded layout) or the
single `model.safetensors` fallback. dtype is bf16 to match the
canonical Qwen3 HF distribution dtype. tokenizer.json is fetched from
the same repo (no -GGUF suffix to strip).

ModelArch gains a Qwen3Dense variant; the forward signature mirrors
QuantizedQwen3Weights (same `forward(&Tensor, offset)` → last-position
logits), so run_inference / run_inference_streaming just add a parallel
match arm — no shape changes downstream.

This is the foundation 7b-ii (ColumnParallel/RowParallel) builds on:
because the source is dense safetensors that can be byte-sliced per
rank, the TP work avoids the GGUF super-block alignment problem
entirely. Vanilla GGUF inference keeps working unchanged.

validate-neuron.sh learns the dense path: pass an empty third arg
(quant) and the script omits the `quant` field from the load
payload, triggering the dense dispatch. Example:
  script/validate-neuron.sh beast.hanzalova.internal Qwen/Qwen3-0.6B ''

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 17:03:59 +03:00
da068ded6d Stage 7a-ii: real NCCL handshake behind the worker pool
Some checks failed
CI / Format (push) Failing after 38s
build-prerelease / Resolve version stamps (push) Successful in 42s
CI / Clippy (push) Successful in 2m18s
build-prerelease / Build neuron-blackwell (push) Failing after 3m33s
CI / Test (push) Successful in 4m27s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Successful in 4m31s
build-prerelease / Package cortex RPM (push) Successful in 1m21s
build-prerelease / Build neuron-ampere (push) Failing after 4m19s
build-prerelease / Build neuron-ada (push) Failing after 4m56s
build-prerelease / Package helexa-neuron-ada RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been skipped
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been skipped
Wires cudarc::nccl into the TP worker lifecycle introduced in 7a-i.
With --features cuda the leader and its workers now establish a live
NCCL communicator end-to-end; without the feature the same code paths
return Error{kind="cuda_feature_not_enabled"} so a misconfigured
build fails loudly instead of acting as a silent no-op.

NCCL state machine (harness/tp/nccl_state.rs) is shared between the
worker process and the leader's pool:
- generate_comm_id_hex() mints an Id::new() on the leader.
- NcclState::init parses 256 hex chars → [c_char; 128] → Id::uninit,
  opens a CudaContext on the configured device, calls Comm::from_rank
  with the supplied (rank, world_size, id). NCCL blocks until every
  rank has joined.
- NcclState::sanity_check runs one all_reduce(1u32, Sum); the leader
  asserts every rank reports observed_sum == world_size.
- NCCL handles serialised under Mutex; unsafe impl Send/Sync gates
  the Comm across spawn_blocking boundaries (NCCL is move-safe; only
  concurrent op issuance is unsafe).
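
The wire encoding of the communicator id (128 bytes as 256 hex characters) can be sketched with the standard library alone. `id_to_hex` and `hex_to_id` are hypothetical names, and the error strings are illustrative; the real code hands the parsed array to cudarc, which this sketch does not attempt.

```rust
use std::ffi::c_char;

const ID_BYTES: usize = 128;

/// Encode the 128-byte id as 256 lowercase hex characters.
fn id_to_hex(id: &[c_char; ID_BYTES]) -> String {
    id.iter().map(|&b| format!("{:02x}", b as u8)).collect()
}

/// Decode 256 hex characters back into the 128-byte id.
fn hex_to_id(hex: &str) -> Result<[c_char; ID_BYTES], String> {
    if hex.len() != ID_BYTES * 2 || !hex.is_ascii() {
        return Err(format!(
            "expected {} hex chars, got {}",
            ID_BYTES * 2,
            hex.len()
        ));
    }
    let mut id = [0 as c_char; ID_BYTES];
    for (i, slot) in id.iter_mut().enumerate() {
        // Two hex digits per byte; ASCII was checked above, so slicing
        // at byte offsets is safe.
        let byte = u8::from_str_radix(&hex[2 * i..2 * i + 2], 16)
            .map_err(|e| format!("bad hex at byte {i}: {e}"))?;
        *slot = byte as c_char;
    }
    Ok(id)
}
```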

WorkerPool::init_nccl orchestrates the rendezvous:
1. Write Init { comm_id } to every worker's stdin (no await yet).
2. Leader rank 0 calls its own Comm::from_rank in spawn_blocking,
   concurrently with workers.
3. NCCL handshake completes for all ranks simultaneously.
4. Leader collects InitOk responses.
WorkerPool::nccl_sanity_check follows the same pattern over
all_reduce, validating world_size == observed_sum on every rank.

Worker.send_only / Worker.recv_only split out from the previous
monolithic Worker.request so the leader can interleave its own NCCL
work with the worker calls — required because NCCL blocks during
init.
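
The ordering constraint can be simulated with threads, where `Barrier::wait` stands in for the blocking communicator init (it returns only once every rank has joined). Everything here is a hypothetical model: `init_all`, the channel types and the `0xC0FFEE` id are illustrative, and no NCCL or subprocess is involved.

```rust
use std::sync::{mpsc, Arc, Barrier};
use std::thread;

/// Simulated rendezvous for `world_size` ranks; returns the ranks that
/// reported success, sorted. Rank 0 is the calling (leader) thread.
fn init_all(world_size: usize) -> Vec<usize> {
    let barrier = Arc::new(Barrier::new(world_size));
    let (ok_tx, ok_rx) = mpsc::channel();
    let mut init_txs = Vec::new();
    let mut handles = Vec::new();

    for rank in 1..world_size {
        let (init_tx, init_rx) = mpsc::channel::<u64>();
        let (barrier, ok_tx) = (Arc::clone(&barrier), ok_tx.clone());
        init_txs.push(init_tx);
        handles.push(thread::spawn(move || {
            let _comm_id = init_rx.recv().unwrap(); // wait for Init
            barrier.wait(); // blocking handshake
            ok_tx.send(rank).unwrap(); // report InitOk
        }));
    }

    // 1. Send Init to every worker without waiting for replies.
    for tx in &init_txs {
        tx.send(0xC0FFEE).unwrap();
    }
    // 2-3. The leader joins the handshake itself, concurrently with workers.
    barrier.wait();
    // 4. Collect the InitOk responses.
    drop(ok_tx);
    let mut ranks: Vec<usize> = ok_rx.iter().collect();
    for h in handles {
        h.join().unwrap();
    }
    ranks.sort();
    ranks
}
```

If the leader waited for the replies before calling `barrier.wait()`, every worker would sit in the barrier and the leader would never receive an answer, which is why the send and receive halves have to be separable.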

Tests:
- 4 hex roundtrip unit tests for the wire encoding.
- The 7a-i "not implemented" expectation now reads
  "cuda_feature_not_enabled" on the local dev box (no CUDA), or
  accepts InitOk on a cuda-built test binary.
- New cuda-integration test in tp_worker_lifecycle_cuda.rs covers
  the real init + sanity round-trip; gated on the cuda-integration
  feature so default CI doesn't try to NCCL.

Verifiable on beast (2× RTX 5090):
  cargo test -p neuron --features cuda-integration \
        --test tp_worker_lifecycle_cuda

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 16:40:01 +03:00
2a7ede0232 Stage 7a-i: TP worker lifecycle scaffolding
All checks were successful
CI / Format (push) Successful in 36s
build-prerelease / Resolve version stamps (push) Successful in 39s
CI / Clippy (push) Successful in 2m12s
CI / Test (push) Successful in 4m25s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m49s
build-prerelease / Build cortex binary (push) Successful in 4m22s
build-prerelease / Package cortex RPM (push) Successful in 1m23s
build-prerelease / Build neuron-ampere (push) Successful in 5m9s
build-prerelease / Build neuron-ada (push) Successful in 4m59s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m53s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m59s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m38s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m8s
Leader → worker process plumbing for tensor parallelism. The neuron
binary picks up two modes: default (the existing daemon, axum + HTTP)
and `--worker` (a bare RPC loop driven over stdin/stdout). The leader
spawns one worker per non-zero NCCL rank via tokio::process::Command
on the same binary path (production: /proc/self/exe; tests:
env!("CARGO_BIN_EXE_neuron")) and talks to each over newline-
delimited JSON.

Protocol (harness/tp/rpc.rs) is serde-tagged from the start —
WorkerRequest::{Ping, Init, NcclSanityCheck, Shutdown} and
WorkerResponse::{Pong, InitOk, NcclSanityResult, Bye, Error}, both
`#[serde(tag = "op", rename_all = "snake_case")]`. Adding ops in 7b/7c
is purely additive; unknown ops on the wire fail to parse (verified
in unit tests).
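
As a rough picture of the wire shape, the sketch below hand-rolls what an internally tagged, snake_case encoding produces for field-less ops. It deliberately avoids serde so that it stays std-only, covers just `Ping` and `Shutdown`, and assumes those two requests carry no fields; `to_line` and `from_line` are hypothetical names and the error strings are illustrative.

```rust
#[derive(Debug, PartialEq)]
enum WorkerRequest {
    Ping,
    Shutdown,
}

impl WorkerRequest {
    /// One JSON object per line, tagged by an "op" field.
    fn to_line(&self) -> String {
        let op = match self {
            WorkerRequest::Ping => "ping",
            WorkerRequest::Shutdown => "shutdown",
        };
        format!("{{\"op\":\"{op}\"}}\n")
    }

    /// Strict parser for the exact shape emitted by `to_line`.
    /// Unknown ops are rejected instead of being ignored.
    fn from_line(line: &str) -> Result<Self, String> {
        let body = line.trim();
        let op = body
            .strip_prefix("{\"op\":\"")
            .and_then(|rest| rest.strip_suffix("\"}"))
            .ok_or_else(|| format!("malformed request: {body}"))?;
        match op {
            "ping" => Ok(WorkerRequest::Ping),
            "shutdown" => Ok(WorkerRequest::Shutdown),
            other => Err(format!("unknown op: {other}")),
        }
    }
}
```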

7a-i scope:
- WorkerPool::spawn(binary, world_size, devices) forks ranks 1..N as
  subprocesses, captures stdin/stdout, kills on drop.
- ping_all() round-trips a Ping to every worker and validates the
  returned rank.
- shutdown() sends Shutdown to each worker, awaits Bye, reaps.
- Worker mode: parse Ping/Shutdown, return Pong/Bye; Init and
  NcclSanityCheck return Error{kind="not_implemented_7a_i"} so a 7a-ii
  binary speaking the same wire is a drop-in replacement (the kind
  field signals "real NCCL lands in the next commit").
- CandleHarness::load_model refuses tensor_parallel > 1 with a clear
  message until 7b is in.

Three integration tests in tests/tp_worker_lifecycle.rs cover spawn/
ping/shutdown for 2- and 3-worker pools, plus the
not_implemented_7a_i contract test for Init. Seven rpc serde unit
tests assert the wire shape (op tags, field names, unknown-op
rejection). All pass on the dev host; no CUDA required.

Stage 7a-ii (next): the real NCCL Comm::from_rank wiring behind the
existing Init/NcclSanityCheck op surface, CUDA-gated. Verifiable on
beast's 2×5090.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 15:53:00 +03:00
18ae3c30ee post-validation cleanup: cuDNN runtime + repetition penalty
All checks were successful
CI / Format (push) Successful in 34s
build-prerelease / Resolve version stamps (push) Successful in 35s
CI / Clippy (push) Successful in 2m17s
CI / Test (push) Successful in 4m16s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Successful in 4m28s
build-prerelease / Build neuron-blackwell (push) Successful in 3m42s
build-prerelease / Package cortex RPM (push) Successful in 1m25s
build-prerelease / Build neuron-ampere (push) Successful in 4m27s
build-prerelease / Build neuron-ada (push) Successful in 4m51s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m50s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m40s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 6m52s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 2m32s
Two followups from the live single-GPU validation pass.

1. deploy.sh now ensures libcudnn.so.9 is available on each neuron
   host before installing/upgrading the package. Probes ldconfig first
   so hosts with a manual (tar/runfile) cuDNN install are untouched,
   then adds NVIDIA's RHEL9 CUDA repo (the Fedora 43 CUDA repo doesn't
   ship cuDNN; only the RHEL9 one does) and installs libcudnn9-cuda-13.
   benjy hit "cannot open shared object file: libcudnn.so.9" during
   validation; this prevents that recurring.

2. candle.rs applies a 1.1 repetition penalty over the last 64
   generated tokens before sampling, in both the non-streaming
   chat_completion path and the streaming chat_completion_stream
   path. Without it small Q4_K_M models degenerate into "Wait, no,
   no..." loops once they hit a confident-but-wrong path; with it
   sampling stays coherent. Defaults match mistral.rs and llama.cpp;
   exposing the value via the OpenAI request (frequency/presence
   penalty mapping) is Stage 8 territory.

Both paths route through a new sample_with_penalty() helper so future
sampling tweaks land in one place.
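
A windowed repetition penalty of this kind is commonly implemented by shrinking the logits of recently generated tokens. The sketch below assumes the widespread convention (divide non-negative logits by the penalty, multiply negative ones) and a hypothetical signature; it is not the helper from candle.rs.

```rust
use std::collections::HashSet;

/// Penalise, in place, the logits of tokens seen in the last `last_n`
/// generated tokens. Each distinct token is penalised once.
fn apply_repeat_penalty(logits: &mut [f32], penalty: f32, generated: &[u32], last_n: usize) {
    let start = generated.len().saturating_sub(last_n);
    let mut seen = HashSet::new();
    for &tok in &generated[start..] {
        if !seen.insert(tok) {
            continue; // already penalised in this pass
        }
        if let Some(l) = logits.get_mut(tok as usize) {
            // Both branches push the logit towards "less likely".
            *l = if *l >= 0.0 { *l / penalty } else { *l * penalty };
        }
    }
}
```

With the values from the commit, the call would look like `apply_repeat_penalty(&mut logits, 1.1, &generated, 64)` before sampling.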

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 14:48:08 +03:00
1a0400131e fix(deploy): use dnf upgrade for stale installs, install only when absent
All checks were successful
CI / Format (push) Successful in 35s
build-prerelease / Resolve version stamps (push) Successful in 39s
CI / Clippy (push) Successful in 2m27s
CI / Test (push) Successful in 4m30s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m29s
build-prerelease / Build cortex binary (push) Successful in 4m32s
build-prerelease / Package cortex RPM (push) Successful in 1m20s
build-prerelease / Build neuron-ampere (push) Successful in 5m15s
build-prerelease / Build neuron-ada (push) Successful in 4m51s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m48s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m47s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m38s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 57s
dnf5's `dnf install <pkg>` is a no-op when the package is already
installed at ANY version — it does NOT auto-upgrade to the latest
available. The deploy script's install branch was therefore silently
leaving hosts on older builds even though needs_update correctly
reported an upgrade was available.

Add an is_installed() probe and an install_or_upgrade() helper that
picks the right verb: `dnf install` when fresh, `dnf upgrade` when
stale. Captured combined-stream output is exposed via __DNF_OUTPUT__
for the existing failure-diagnostic path.

Verified end-to-end against the live fleet: hanzalova/beast/benjy/
quadbrat all upgraded cleanly from prior prerelease NVRs to
0.1.16-0.1.20260519134302.git1866b99.fc43, validation script returned
"Paris" from all three neurons.

Followup (not in this commit): all hosts running helexa-neuron-*
need libcudnn.so.9 available at runtime. Currently:
  - quadbrat: libcudnn9-cuda-13 RPM (rhel9 CUDA repo)
  - beast:    /usr/lib64/libcudnn.so.9 (manual install)
  - benjy:    needed rhel9 CUDA repo added + libcudnn9-cuda-13 installed
              as part of this validation pass.
The spec currently excludes cuDNN from auto-detected deps. Should
add a Recommends:libcudnn9-cuda-13 (soft) and ensure the rhel9 CUDA
repo is configured on each neuron host, similar to how ensure_lair_repo
handles the unstable channel.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 14:10:48 +03:00
1866b99a89 fix(validate-neuron): jq for JSON, say→stderr, sane max_tokens
All checks were successful
CI / Format (push) Successful in 35s
build-prerelease / Resolve version stamps (push) Successful in 38s
CI / Clippy (push) Successful in 2m13s
CI / Test (push) Successful in 4m22s
build-prerelease / Build neuron-blackwell (push) Successful in 3m25s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build cortex binary (push) Successful in 4m21s
build-prerelease / Package cortex RPM (push) Successful in 1m17s
build-prerelease / Build neuron-ampere (push) Successful in 4m39s
build-prerelease / Build neuron-ada (push) Successful in 4m57s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m50s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m58s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m34s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m3s
Four issues caught while exercising the script end-to-end against
the live quadbrat node:

1. say() printed status to stdout. Inside run_probe(), the
   "POST /v1/chat/completions (probe: ...)" line was being captured
   by `raw=$(run_probe)` along with the JSON body, so jq saw
   "[host] POST..." as the first line and choked at column 29 with
   "Invalid numeric literal" (it tried to parse the `[` as the start
   of a JSON array). Redirect say() to stderr so command
   substitutions capture only the intended return value.

2. The pretty-print step `echo "${raw}" | yq -r '.'` re-emitted the
   JSON as YAML, which fails on response content that looks like YAML
   markers (chatcmpl ids that parse as aliases, escaped quotes inside
   <think>...</think> blocks). Drop the pretty-print; just echo the
   raw JSON.

3. JSON response parsing now uses jq (always JSON) instead of yq
   (parses input as YAML by default). yq remains in use only for the
   genuinely-YAML asset/manifest.yml elsewhere.

4. max_tokens bumped 32 → 256. Qwen3 prepends a <think>...</think>
   reasoning block before its final answer when the chat template
   enables thinking mode, and that eats most of a small budget — the
   "Paris" answer was being truncated mid-thought. 256 leaves enough
   room for both.

Verified pipeline end-to-end on quadbrat (RTX 3060, helexa-neuron-ampere
git602e8e1): /health OK → /models/load (unsloth/Qwen3-0.6B-GGUF Q4_K_M)
→ /v1/chat/completions → response content contains "Paris".

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 13:43:02 +03:00
60176e7c2e ci: monotonic prerelease versions + serialize CI on shared runner
Two CI hygiene fixes uncovered while validating against the live fleet.

1. Same-day prerelease packages were being ordered by RPM-vercmp's
   alpha-vs-digit precedence on the git SHA fragment, not by commit
   chronology. With release stamps like "0.1.${YYYYMMDD}git${SHA}",
   two commits on the same day produce the same numeric prefix and
   rpmvercmp falls back to comparing the alphanumeric SHA suffixes,
   where digit-leading SHAs are ranked above alpha-leading ones —
   completely unrelated to which commit landed first. Verified with
   rpmdev-vercmp:
     gitabc1234 < gitdef5678   (old scheme — purely lexicographic)
   Bumping the timestamp prefix to second-precision (%Y%m%d%H%M%S)
   makes the numeric prefix strictly monotonic for any chronologically-
   ordered commits, so the SHA fragment becomes a debug identifier
   only — never participates in version ordering.

2. ci.yml and build-prerelease.yml both target the `rust` runner label
   and both auto-trigger on push to main. The act-based runner reuses
   /root/.cache/act/<hash>/hostexecutor/ across concurrent jobs, so
   ci.yml's clippy and build-prerelease.yml's build-cortex were racing
   each other's checkout/cleanup steps and corrupting in-flight
   compile artifacts. Real fix is in gongfoo; workflow-level workaround
   is a shared concurrency group with cancel-in-progress=false so the
   two workflows queue sequentially on the same ref.
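
The ordering problem in the first fix can be reproduced with a simplified segment comparison. This sketch is an approximation written for illustration: `rpmvercmp_simplified` and `take_run` are hypothetical names, tilde and caret handling is omitted, and the timestamps and SHAs in the usage are made up.

```rust
use std::cmp::Ordering;

/// Take the leading run of bytes matching `pred`, returning (run, rest).
fn take_run(s: &str, pred: fn(&u8) -> bool) -> (&str, &str) {
    let n = s.bytes().take_while(|b| pred(b)).count();
    s.split_at(n)
}

/// Simplified RPM-style comparison: split into maximal digit or letter
/// runs, compare run by run, and treat everything else as a separator.
fn rpmvercmp_simplified(a: &str, b: &str) -> Ordering {
    let (mut a, mut b) = (a, b);
    loop {
        a = a.trim_start_matches(|c: char| !c.is_ascii_alphanumeric());
        b = b.trim_start_matches(|c: char| !c.is_ascii_alphanumeric());
        if a.is_empty() || b.is_empty() {
            // Whichever side still has segments left is newer.
            return a.len().cmp(&b.len());
        }
        let numeric = a.as_bytes()[0].is_ascii_digit();
        let pred: fn(&u8) -> bool = if numeric {
            u8::is_ascii_digit
        } else {
            u8::is_ascii_alphabetic
        };
        let (seg_a, rest_a) = take_run(a, pred);
        let (seg_b, rest_b) = take_run(b, pred);
        if seg_b.is_empty() {
            // Segment kinds differ: a numeric segment beats an alphabetic one.
            return if numeric { Ordering::Greater } else { Ordering::Less };
        }
        let ord = if numeric {
            // Compare as integers of arbitrary length.
            let (x, y) = (seg_a.trim_start_matches('0'), seg_b.trim_start_matches('0'));
            x.len().cmp(&y.len()).then_with(|| x.cmp(y))
        } else {
            seg_a.cmp(seg_b)
        };
        if ord != Ordering::Equal {
            return ord;
        }
        a = rest_a;
        b = rest_b;
    }
}
```

Under the date-only stamp, two same-day releases tie on the numeric prefix and the SHA text decides; with a second-precision prefix the numeric segment differs first, so the SHA never takes part in the comparison.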

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 13:36:53 +03:00
49 changed files with 10157 additions and 219 deletions


@@ -8,8 +8,13 @@ name: build-prerelease
# Optionally provide a `ref` to build from a non-default branch. # Optionally provide a `ref` to build from a non-default branch.
# #
# The published packages are versioned as e.g. # The published packages are versioned as e.g.
# helexa-neuron-blackwell-0.1.16-0.1.20260518gitabcdef0.fc43.x86_64 # helexa-neuron-blackwell-0.1.16-0.1.20260518T140530.gitabcdef0.fc43.x86_64
# so they sort BELOW the eventual 0.1.16-1 stable release. # ^^^^^^^^^^^^^^^^^^ ^^^^^^^^
# commit time (s) commit sha
# so they sort BELOW the eventual 0.1.16-1 stable release, and so two
# commits on the same day are still strictly ordered by their commit
# timestamps (rather than by RPM-vercmp's alpha-vs-digit precedence
# on the SHA fragment).
on: on:
# Auto-build on every push to main so the unstable channel tracks # Auto-build on every push to main so the unstable channel tracks
@@ -25,10 +30,14 @@ on:
  default: ""
  concurrency:
- # Coalesce on branch+event so successive pushes don't pile up; the
- # latest push wins.
- group: prerelease-build-${{ github.ref }}
- cancel-in-progress: true
+ # Share the group with ci.yml so the two workflows can't run
+ # concurrently on the same `rust` runner (act reuses the workspace
+ # cache and races destroy each other's build files mid-compile).
+ # cancel-in-progress=false → workflows queue; if a newer push lands,
+ # the older run is still picked up by ci.yml's own ref-keyed
+ # concurrency (same group, queued).
+ group: cortex-runner-pool-${{ github.ref }}
+ cancel-in-progress: false
  env:
  CARGO_INCREMENTAL: "0"
@@ -41,7 +50,7 @@ jobs:
  version: ${{ steps.info.outputs.version }}
  release: ${{ steps.info.outputs.release }}
  short_sha: ${{ steps.info.outputs.short_sha }}
- commit_date: ${{ steps.info.outputs.commit_date }}
+ commit_timestamp: ${{ steps.info.outputs.commit_timestamp }}
  steps:
  - uses: actions/checkout@v4
  with:
@@ -53,13 +62,20 @@ jobs:
  set -eux
  VERSION=$(awk -F\" '/^version[[:space:]]*=/ { print $2; exit }' Cargo.toml)
  SHORT_SHA=$(git rev-parse --short=7 HEAD)
- COMMIT_DATE=$(git log -1 --format=%cd --date=format:%Y%m%d HEAD)
- # Prerelease release stamp sorts before "1" (the stable release).
- RELEASE="0.1.${COMMIT_DATE}git${SHORT_SHA}"
+ # Second-precise commit timestamp gives the release stamp a
+ # strictly monotonic numeric prefix. The earlier %Y%m%d-only
+ # form let same-day builds be ordered by RPM's rpmvercmp
+ # rules over the SHA, which is non-chronological — e.g.
+ # "git602e8e1" sorts newer than "gitf9f5fa4" purely because
+ # rpmvercmp ranks digit-prefixed segments above alpha ones.
+ # The SHA stays only as a debug identifier; sort order is
+ # decided entirely by the timestamp.
+ COMMIT_TIMESTAMP=$(git log -1 --format=%cd --date=format:%Y%m%d%H%M%S HEAD)
+ RELEASE="0.1.${COMMIT_TIMESTAMP}.git${SHORT_SHA}"
  echo "version=${VERSION}" >> "$GITHUB_OUTPUT"
  echo "release=${RELEASE}" >> "$GITHUB_OUTPUT"
  echo "short_sha=${SHORT_SHA}" >> "$GITHUB_OUTPUT"
- echo "commit_date=${COMMIT_DATE}" >> "$GITHUB_OUTPUT"
+ echo "commit_timestamp=${COMMIT_TIMESTAMP}" >> "$GITHUB_OUTPUT"
  build-cortex:
  name: Build cortex binary


@@ -7,6 +7,16 @@ on:
  pull_request:
  branches: [main]
+ # Share a concurrency group with build-prerelease.yml so the two
+ # workflows don't race on the same `rust` runner workspace (act's
+ # /root/.cache/act/<hash>/hostexecutor/ is shared across concurrent
+ # jobs and one job's checkout step nukes another's in-flight build
+ # files). cancel-in-progress=false → they queue; same-ref pushes
+ # coalesce per workflow via cancel-in-progress on each.
+ concurrency:
+ group: cortex-runner-pool-${{ github.ref }}
+ cancel-in-progress: false
  env:
  CARGO_INCREMENTAL: "0"
  RUSTC_WRAPPER: sccache

.gitignore vendored

@@ -4,4 +4,6 @@
  .idea/
  .vscode/
  cortex.toml
+ models.toml
  doc/plan/*
+ /target-cuda/

Cargo.lock generated

@@ -596,6 +596,7 @@ dependencies = [
  "tower",
  "tower-http",
  "tracing",
+ "url",
  "urlencoding",
  ]
@@ -2113,10 +2114,14 @@ dependencies = [
  "candle-transformers",
  "clap",
  "cortex-core",
+ "cudaforge",
+ "cudarc 0.19.7",
  "figment",
  "futures",
+ "half",
  "hf-hub",
  "reqwest",
+ "safetensors 0.7.0",
  "serde",
  "serde_json",
  "thiserror 2.0.18",


@@ -74,6 +74,32 @@ install -Dm644 models.example.toml %{buildroot}%{_sysconfdir}/cortex/models.toml
  %postun
  %systemd_postun_with_restart cortex.service
+ %posttrans
+ # Migration: older cortex packages shipped the firewalld service as
+ # `helexa-cortex` and (in some build streams) with wrong port numbers
+ # (9301/9302/9304). Operators who enabled that legacy service in their
+ # zone end up with the wrong-port override taking precedence over the
+ # vendor `cortex.xml` now in /usr/lib/firewalld/services/. Clean up the
+ # stale /etc/ override here and migrate any zone bindings to the new
+ # service name.
+ if [ -f /etc/firewalld/services/helexa-cortex.xml ]; then
+ rm -f /etc/firewalld/services/helexa-cortex.xml
+ fi
+ if [ -x /usr/bin/firewall-cmd ] && /usr/bin/firewall-cmd --state >/dev/null 2>&1; then
+ # Drop the legacy service name from every zone where it was enabled
+ # and add the new `cortex` service in its place. Operators who never
+ # ran firewall-cmd against either name see no zone change.
+ for zone in $(/usr/bin/firewall-cmd --get-active-zones 2>/dev/null \
+ | awk '!/^[[:space:]]/ {print $1}'); do
+ if /usr/bin/firewall-cmd --permanent --zone="$zone" --query-service=helexa-cortex >/dev/null 2>&1; then
+ /usr/bin/firewall-cmd --permanent --zone="$zone" --remove-service=helexa-cortex >/dev/null 2>&1 || :
+ /usr/bin/firewall-cmd --permanent --zone="$zone" --add-service=cortex >/dev/null 2>&1 || :
+ fi
+ done
+ /usr/bin/firewall-cmd --reload >/dev/null 2>&1 || :
+ fi
+ :
  %files
  %license LICENSE
  %doc README.md


@@ -1,5 +1,6 @@
  //! Model catalogue — profiles describing how to serve each model.
+ use crate::discovery::DeviceInfo;
  use serde::{Deserialize, Serialize};
  use std::path::Path;
@@ -64,4 +65,103 @@ impl ModelCatalogue {
  .iter()
  .any(|p| p.id == model_id && p.pinned_on.contains(&neuron_name.to_string()))
  }
+ /// Find a profile by model id.
+ pub fn get(&self, model_id: &str) -> Option<&ModelProfile> {
+ self.models.iter().find(|p| p.id == model_id)
+ }
+ }
+ impl ModelProfile {
+ /// True iff this profile's placement constraints can be satisfied
+ /// by the named neuron with the given device topology.
+ ///
+ /// Constraints checked:
+ /// - `pinned_on`: non-empty → neuron must be on the list.
+ /// - `min_devices`: neuron must have at least this many devices.
+ /// - `min_device_vram_mb`: at least `min_devices` of the neuron's
+ /// devices must each meet this VRAM floor.
+ pub fn is_feasible_on(&self, neuron_name: &str, devices: &[DeviceInfo]) -> bool {
+ if !self.pinned_on.is_empty() && !self.pinned_on.iter().any(|n| n == neuron_name) {
+ return false;
+ }
+ if (devices.len() as u32) < self.min_devices {
+ return false;
+ }
+ if let Some(min_vram) = self.min_device_vram_mb {
+ let big_enough = devices
+ .iter()
+ .filter(|d| d.vram_total_mb >= min_vram)
+ .count() as u32;
+ if big_enough < self.min_devices {
+ return false;
+ }
+ }
+ true
+ }
+ }
+ #[cfg(test)]
+ mod tests {
+ use super::*;
+ use crate::discovery::DeviceInfo;
+ fn device(idx: u32, vram_mb: u64) -> DeviceInfo {
+ DeviceInfo {
+ index: idx,
+ name: format!("DEV-{idx}"),
+ vram_total_mb: vram_mb,
+ compute_capability: "8.6".into(),
+ }
+ }
+ fn profile() -> ModelProfile {
+ ModelProfile {
+ id: "Qwen/Qwen3.6-27B".into(),
+ harness: "candle".into(),
+ quant: None,
+ vram_mb: Some(45_000),
+ min_devices: 2,
+ min_device_vram_mb: Some(24_000),
+ pinned_on: vec![],
+ }
+ }
+ #[test]
+ fn feasible_when_two_devices_meet_vram_floor() {
+ let p = profile();
+ let devices = [device(0, 32_000), device(1, 32_000)];
+ assert!(p.is_feasible_on("beast", &devices));
+ }
+ #[test]
+ fn infeasible_when_only_one_device() {
+ let p = profile();
+ let devices = [device(0, 64_000)];
+ assert!(!p.is_feasible_on("benjy", &devices));
+ }
+ #[test]
+ fn infeasible_when_one_device_underspec() {
+ let p = profile();
+ let devices = [device(0, 32_000), device(1, 12_000)];
+ assert!(!p.is_feasible_on("mixed", &devices));
+ }
+ #[test]
+ fn pinned_on_excludes_other_neurons() {
+ let mut p = profile();
+ p.pinned_on = vec!["beast".into()];
+ let devices = [device(0, 32_000), device(1, 32_000)];
+ assert!(p.is_feasible_on("beast", &devices));
+ assert!(!p.is_feasible_on("benjy", &devices));
+ }
+ #[test]
+ fn no_vram_floor_just_needs_min_devices() {
+ let mut p = profile();
+ p.min_device_vram_mb = None;
+ let devices = [device(0, 1_000), device(1, 1_000)];
+ assert!(p.is_feasible_on("anywhere", &devices));
+ }
  }


@@ -1,3 +1,4 @@
+ use crate::discovery::DiscoveryResponse;
  use chrono::{DateTime, Utc};
  use serde::{Deserialize, Serialize};
  use std::collections::HashMap;
@@ -13,6 +14,12 @@ pub struct NodeState {
  /// Number of load/unload cycles since last process restart.
  pub lifecycle_cycles: u32,
  pub last_poll: Option<DateTime<Utc>>,
+ /// Result of the most recent successful `GET /discovery` against
+ /// this neuron. Cached forever once obtained — device topology is
+ /// invariant for a given neuron process. `None` until the first
+ /// successful poll. Used by the router and `/v1/models` to do
+ /// catalogue × topology feasibility checks.
+ pub discovery: Option<DiscoveryResponse>,
  }
  /// A model registered on a node, with its runtime status.
@@ -36,12 +43,32 @@ pub enum ModelStatus {
  }
  /// Unified model entry as exposed by the gateway's `/v1/models` endpoint.
- /// Includes which node(s) host this model and their status.
+ ///
+ /// The first four fields (`id`, `object`, `created`, `owned_by`) match
+ /// OpenAI's `/v1/models` shape verbatim, so existing OpenAI-aware
+ /// tooling deserialises this without custom code. The remaining fields
+ /// are helexa-specific extensions — OpenAI clients ignore unknown
+ /// fields and other consumers can read them for placement / debugging.
  #[derive(Debug, Clone, Serialize, Deserialize)]
  pub struct CortexModelEntry {
  pub id: String,
+ /// Always `"model"` per OpenAI's contract.
  pub object: String,
- /// Which nodes have this model (and their status).
+ /// Unix-second timestamp; cortex stamps this at response time.
+ pub created: u64,
+ /// OpenAI's "publisher" field — `"helexa"` for everything we serve.
+ pub owned_by: String,
+ /// True if any neuron currently has this model loaded. False for
+ /// catalogue entries that are feasible but not yet loaded.
+ pub loaded: bool,
+ /// Neurons whose discovered topology can satisfy this model's
+ /// catalogue placement constraints. Empty for models that are
+ /// loaded somewhere but not present in the catalogue (cortex has
+ /// no feasibility opinion on those).
+ pub feasible_on: Vec<String>,
+ /// Where this model is actually loaded right now. Subset of (or
+ /// disjoint from) `feasible_on` depending on whether the catalogue
+ /// covers this model.
  pub locations: Vec<ModelLocation>,
  }


@@ -24,6 +24,7 @@ tokio-stream.workspace = true
  eventsource-stream.workspace = true
  bytes = "1"
  urlencoding = "2"
+ url = "2"
  [dev-dependencies]
  tokio = { workspace = true, features = ["test-util"] }


@@ -34,12 +34,30 @@ async fn chat_completions(
  ) -> Response {
  let model_id = match extract_model(&body) {
  Some(m) => m,
- None => return error_response(400, "missing 'model' field in request body"),
+ None => {
+ tracing::warn!(
+ handler = "chat_completions",
+ "rejected: missing 'model' field in request body"
+ );
+ return error_response(400, "missing 'model' field in request body");
+ }
  };
  let route = match router::resolve(&fleet, &model_id).await {
  Ok(r) => r,
- Err(e) => return error_response(404, &e.to_string()),
+ Err(e) => {
+ tracing::warn!(
+ handler = "chat_completions",
+ model = %model_id,
+ error = %e,
+ "route resolve failed"
+ );
+ // RouteError's Display strings are short and informative
+ // ("model 'X' not found...", "no healthy nodes available")
+ // — fine to surface to the caller. The warn above carries
+ // any extra context for operators.
+ return error_response(404, &e.to_string());
+ }
  };
  touch_model(&fleet, &route.node_name, &model_id).await;
@@ -63,12 +81,30 @@ async fn completions(
  ) -> Response {
  let model_id = match extract_model(&body) {
  Some(m) => m,
- None => return error_response(400, "missing 'model' field in request body"),
+ None => {
+ tracing::warn!(
+ handler = "completions",
+ "rejected: missing 'model' field in request body"
+ );
+ return error_response(400, "missing 'model' field in request body");
+ }
  };
  let route = match router::resolve(&fleet, &model_id).await {
  Ok(r) => r,
- Err(e) => return error_response(404, &e.to_string()),
+ Err(e) => {
+ tracing::warn!(
+ handler = "completions",
+ model = %model_id,
+ error = %e,
+ "route resolve failed"
+ );
+ // RouteError's Display strings are short and informative
+ // ("model 'X' not found...", "no healthy nodes available")
+ // — fine to surface to the caller. The warn above carries
+ // any extra context for operators.
+ return error_response(404, &e.to_string());
+ }
  };
  touch_model(&fleet, &route.node_name, &model_id).await;
@@ -85,7 +121,14 @@ async fn anthropic_messages(
  // Parse as Anthropic request.
  let anth_req: cortex_core::anthropic::MessagesRequest = match serde_json::from_slice(&body) {
  Ok(r) => r,
- Err(e) => return error_response(400, &format!("invalid Anthropic request: {e}")),
+ Err(e) => {
+ tracing::warn!(
+ handler = "anthropic_messages",
+ error = %e,
+ "rejected: invalid Anthropic request body"
+ );
+ return error_response(400, "invalid Anthropic request body");
+ }
  };
  let model_id = anth_req.model.clone();
@@ -95,12 +138,32 @@ async fn anthropic_messages(
  let openai_req = cortex_core::translate::anthropic_to_openai(anth_req);
  let openai_body = match serde_json::to_vec(&openai_req) {
  Ok(b) => Bytes::from(b),
- Err(e) => return error_response(500, &format!("translation error: {e}")),
+ Err(e) => {
+ tracing::error!(
+ handler = "anthropic_messages",
+ model = %model_id,
+ error = %e,
+ "internal: failed to serialise translated OpenAI request"
+ );
+ return error_response(500, "internal translation error");
+ }
  };
  let route = match router::resolve(&fleet, &model_id).await {
  Ok(r) => r,
- Err(e) => return error_response(404, &e.to_string()),
+ Err(e) => {
+ tracing::warn!(
+ handler = "anthropic_messages",
+ model = %model_id,
+ error = %e,
+ "route resolve failed"
+ );
+ // RouteError's Display strings are short and informative
+ // ("model 'X' not found...", "no healthy nodes available")
+ // — fine to surface to the caller. The warn above carries
+ // any extra context for operators.
+ return error_response(404, &e.to_string());
+ }
  };
  touch_model(&fleet, &route.node_name, &model_id).await;
@@ -133,14 +196,25 @@ async fn anthropic_messages(
  Ok(resp) => resp,
  Err(e) => {
  metrics::counter!("cortex_request_errors_total", &labels).increment(1);
+ // forward_request already warn'd with the wire-level
+ // detail; no need to log again here.
  e.into_response()
  }
  }
  } else {
  // Non-streaming: proxy, buffer full response, translate back to Anthropic.
+ let target_url = format!("{}/v1/chat/completions", route.endpoint);
+ tracing::info!(
+ handler = "anthropic_messages",
+ model = %model_id,
+ node = %route.node_name,
+ url = %target_url,
+ cold_start = route.cold_start,
+ "proxying request"
+ );
  let upstream_resp = fleet
  .http_client
- .post(format!("{}/v1/chat/completions", route.endpoint))
+ .post(&target_url)
  .body(openai_body)
  .header("content-type", "application/json")
  .send()
@@ -150,22 +224,49 @@ async fn anthropic_messages(
  Ok(r) => r,
  Err(e) => {
  metrics::counter!("cortex_request_errors_total", &labels).increment(1);
- return error_response(502, &format!("upstream request failed: {e}"));
+ tracing::warn!(
+ handler = "anthropic_messages",
+ model = %model_id,
+ node = %route.node_name,
+ url = %target_url,
+ error = %e,
+ "upstream request failed (network)"
+ );
+ return error_response(502, "upstream request failed");
  }
  };
- if !upstream_resp.status().is_success() {
+ let upstream_status = upstream_resp.status();
+ if !upstream_status.is_success() {
  metrics::counter!("cortex_request_errors_total", &labels).increment(1);
- let status = upstream_resp.status().as_u16();
+ let status = upstream_status.as_u16();
  let body = upstream_resp.text().await.unwrap_or_default();
- return error_response(status, &format!("upstream error: {body}"));
+ let body_snippet = body.chars().take(512).collect::<String>();
+ tracing::warn!(
+ handler = "anthropic_messages",
+ model = %model_id,
+ node = %route.node_name,
+ url = %target_url,
+ status,
+ body = %body_snippet,
+ "upstream returned non-2xx"
+ );
+ return error_response(status, &format!("upstream returned {status}"));
  }
  let body_bytes = match upstream_resp.bytes().await {
  Ok(b) => b,
  Err(e) => {
  metrics::counter!("cortex_request_errors_total", &labels).increment(1);
- return error_response(502, &format!("failed to read upstream response: {e}"));
+ tracing::warn!(
+ handler = "anthropic_messages",
+ model = %model_id,
+ node = %route.node_name,
+ url = %target_url,
+ error = %e,
+ "failed to read upstream response body"
+ );
+ return error_response(502, "failed to read upstream response");
  }
  };
@@ -174,7 +275,20 @@ async fn anthropic_messages(
  Ok(r) => r,
  Err(e) => {
  metrics::counter!("cortex_request_errors_total", &labels).increment(1);
- return error_response(502, &format!("failed to parse upstream response: {e}"));
+ let body_snippet = String::from_utf8_lossy(&body_bytes)
+ .chars()
+ .take(512)
+ .collect::<String>();
+ tracing::warn!(
+ handler = "anthropic_messages",
+ model = %model_id,
+ node = %route.node_name,
+ url = %target_url,
+ error = %e,
+ body = %body_snippet,
+ "failed to parse upstream response as OpenAI ChatCompletionResponse"
+ );
+ return error_response(502, "malformed upstream response");
  }
  };
@@ -185,12 +299,62 @@ async fn anthropic_messages(
  }
  }
- /// `GET /v1/models` — aggregate models from all nodes.
+ /// `GET /v1/models` — union of (catalogue × topology feasibility) and
+ /// (currently loaded somewhere). The result is what the fleet *could*
+ /// serve, not just what's already loaded — so OpenAI-compatible tools
+ /// see every model the operator has provisioned, and cortex
+ /// transparently cold-loads the first time one is requested.
  async fn list_models(State(fleet): State<Arc<CortexState>>) -> Json<Value> {
+ use std::collections::HashMap;
+ let now = Utc::now().timestamp() as u64;
  let nodes = fleet.nodes.read().await;
- let mut model_map: std::collections::HashMap<String, CortexModelEntry> =
- std::collections::HashMap::new();
+ let catalogue = &fleet.catalogue;
+ let mut entries: HashMap<String, CortexModelEntry> = HashMap::new();
+ // Pass 1: catalogue × topology. For every catalogue profile, find
+ // healthy neurons whose discovered devices satisfy the profile.
+ // Catalogue-defined models surface here even if nothing has loaded
+ // them yet — that's the point of the unified endpoint.
+ for profile in &catalogue.models {
+ let mut feasible_on = Vec::new();
+ for node in nodes.values() {
+ if !node.healthy {
+ continue;
+ }
+ let Some(disc) = node.discovery.as_ref() else {
+ continue;
+ };
+ if profile.is_feasible_on(&node.name, &disc.devices) {
+ feasible_on.push(node.name.clone());
+ }
+ }
+ if feasible_on.is_empty() {
+ // The catalogue lists this model but no neuron's topology
+ // matches — surface it as not-loaded with no feasible
+ // location. Hides nothing; lets operators see why a
+ // configured model isn't reachable.
+ feasible_on.clear();
+ }
+ entries.insert(
+ profile.id.clone(),
+ CortexModelEntry {
+ id: profile.id.clone(),
+ object: "model".into(),
+ created: now,
+ owned_by: "helexa".into(),
+ loaded: false,
+ feasible_on,
+ locations: Vec::new(),
+ },
+ );
+ }
+ // Pass 2: layer the actually-loaded state on top. For each
+ // (node, model) entry, attach a ModelLocation. If the model isn't
+ // in the catalogue, create a new CortexModelEntry from scratch —
+ // cortex doesn't refuse to surface a manually-loaded model just
+ // because the operator didn't enumerate it in models.toml.
  for node in nodes.values() {
  for (model_id, entry) in &node.models {
  let location = ModelLocation {
@@ -198,19 +362,30 @@ async fn list_models(State(fleet): State<Arc<CortexState>>) -> Json<Value> {
  status: entry.status,
  vram_estimate_mb: entry.vram_estimate_mb,
  };
- model_map
+ let was_loaded = matches!(entry.status, cortex_core::node::ModelStatus::Loaded);
+ entries
  .entry(model_id.clone())
- .and_modify(|e| e.locations.push(location.clone()))
+ .and_modify(|e| {
+ e.locations.push(location.clone());
+ if was_loaded {
+ e.loaded = true;
+ }
+ })
  .or_insert_with(|| CortexModelEntry {
  id: model_id.clone(),
  object: "model".into(),
+ created: now,
+ owned_by: "helexa".into(),
+ loaded: was_loaded,
+ // Not in catalogue — cortex has no opinion on
+ // feasibility; leave empty.
+ feasible_on: Vec::new(),
  locations: vec![location],
  });
  }
  }
- let data: Vec<Value> = model_map.values().map(|e| json!(e)).collect();
+ let data: Vec<Value> = entries.values().map(|e| json!(e)).collect();
  Json(json!({
  "object": "list",
  "data": data,
@@ -265,6 +440,9 @@ async fn proxy_with_metrics(
  }
  Err(e) => {
  metrics::counter!("cortex_request_errors_total", &labels).increment(1);
+ // proxy::forward_request already warn'd with wire-level
+ // detail (target URL, error, status). ProxyError::into_response
+ // now returns a generic message — no body leak.
  e.into_response()
  }
  }


@@ -3,6 +3,7 @@
  use crate::state::CortexState;
  use chrono::Utc;
+ use cortex_core::discovery::DiscoveryResponse;
  use cortex_core::harness::ModelInfo;
  use cortex_core::node::{ModelEntry, ModelStatus};
  use std::sync::Arc;
@@ -25,7 +26,59 @@ pub async fn poll_once(fleet: &CortexState) {
  }
  }
+ /// One-shot fetch of `GET /discovery`. Cached on the NodeState forever
+ /// after the first success — topology is invariant for a given neuron
+ /// process. Skipped when the cache is already populated.
+ async fn maybe_poll_discovery(fleet: &CortexState, name: &str, endpoint: &str) {
+ {
+ let nodes = fleet.nodes.read().await;
+ match nodes.get(name) {
+ Some(n) if n.discovery.is_some() => return,
+ _ => {}
+ }
+ }
+ let url = format!("{endpoint}/discovery");
+ let resp = match fleet
+ .http_client
+ .get(&url)
+ .timeout(Duration::from_secs(5))
+ .send()
+ .await
+ {
+ Ok(r) if r.status().is_success() => r,
+ Ok(r) => {
+ tracing::debug!(node = name, status = %r.status(), "discovery probe non-success");
+ return;
+ }
+ Err(e) => {
+ tracing::debug!(node = name, error = %e, "discovery probe unreachable");
+ return;
+ }
+ };
+ match resp.json::<DiscoveryResponse>().await {
+ Ok(d) => {
+ let mut nodes = fleet.nodes.write().await;
+ if let Some(node) = nodes.get_mut(name) {
+ tracing::info!(
+ node = name,
+ hostname = %d.hostname,
+ devices = d.devices.len(),
+ "discovery cached"
+ );
+ node.discovery = Some(d);
+ }
+ }
+ Err(e) => {
+ tracing::warn!(node = name, error = %e, "failed to parse /discovery response");
+ }
+ }
+ }
  async fn poll_neuron(fleet: &CortexState, name: &str, endpoint: &str) {
+ // Topology first — cheap once cached, and the router needs it to
+ // route requests against catalogue entries that aren't loaded yet.
+ maybe_poll_discovery(fleet, name, endpoint).await;
  let url = format!("{endpoint}/models");
  let result = fleet


@@ -12,6 +12,13 @@ use axum::response::{IntoResponse, Response};
  use reqwest::Client;
  /// Proxy a request body to the resolved backend node and stream the response.
+ ///
+ /// Logging contract: every call emits exactly one structured event at
+ /// info / warn level for operator visibility, regardless of outcome.
+ /// Network-level failures and non-2xx upstream statuses are warn'd here
+ /// (closest to the wire); the user-facing response carries only the
+ /// status code and a generic message — implementation detail (body,
+ /// error chain) lives in the log, never in the API surface.
  pub async fn forward_request(
  client: &Client,
  route: &RouteDecision,
@@ -37,10 +44,33 @@ pub async fn forward_request(
  req_builder = req_builder.header(key, value);
  }
- let upstream_resp = req_builder.send().await.map_err(ProxyError::Upstream)?;
+ let upstream_resp = match req_builder.send().await {
+ Ok(r) => r,
+ Err(e) => {
+ tracing::warn!(
+ node = %route.node_name,
+ url = %url,
+ error = %e,
+ "proxy: upstream request failed (network)"
+ );
+ return Err(ProxyError::Upstream(e));
+ }
+ };
- let status =
- StatusCode::from_u16(upstream_resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
+ let upstream_status = upstream_resp.status();
+ if !upstream_status.is_success() {
+ // Streaming body — can't snippet without breaking the stream
+ // pass-through. Log status + URL; the client still gets the
+ // upstream status, just without the leaked body.
+ tracing::warn!(
+ node = %route.node_name,
+ url = %url,
+ status = upstream_status.as_u16(),
+ "proxy: upstream returned non-2xx"
+ );
+ }
+ let status = StatusCode::from_u16(upstream_status.as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
  let resp_headers = upstream_resp.headers().clone();
  let stream = upstream_resp.bytes_stream();
@@ -52,28 +82,37 @@ pub async fn forward_request(
  response = response.header(key, value);
  }
- response
- .body(body)
- .map_err(|e| ProxyError::ResponseBuild(e.to_string()))
+ response.body(body).map_err(|e| {
+ tracing::warn!(
+ node = %route.node_name,
+ url = %url,
+ error = %e,
+ "proxy: failed to build response"
+ );
+ ProxyError::ResponseBuild(e.to_string())
+ })
  }
  #[derive(Debug, thiserror::Error)]
  pub enum ProxyError {
- #[error("upstream request failed: {0}")]
+ #[error("upstream request failed")]
  Upstream(reqwest::Error),
- #[error("failed to build response: {0}")]
+ #[error("failed to build response")]
  ResponseBuild(String),
  }
  impl IntoResponse for ProxyError {
  fn into_response(self) -> Response {
- let status = match &self {
- ProxyError::Upstream(_) => StatusCode::BAD_GATEWAY,
- ProxyError::ResponseBuild(_) => StatusCode::INTERNAL_SERVER_ERROR,
+ let (status, message) = match &self {
+ ProxyError::Upstream(_) => (StatusCode::BAD_GATEWAY, "upstream request failed"),
+ ProxyError::ResponseBuild(_) => (
+ StatusCode::INTERNAL_SERVER_ERROR,
+ "failed to build response",
+ ),
  };
  let body = serde_json::json!({
  "error": {
- "message": self.to_string(),
+ "message": message,
  "type": "proxy_error",
  }
  });


@@ -2,13 +2,21 @@
  //!
  //! Given a model ID from an inbound request, determine which node should
  //! handle it. Priority:
- //! 1. Node where the model is currently `Loaded`
- //! 2. Node where the model is `Unloaded` (will lazy-load on request)
- //! 3. Error: model not found on any node
+ //! 1. Node where the model is currently `Loaded` → use it.
+ //! 2. Node where the model is `Unloaded` → use it; neuron's existing
+ //! lazy-load behaviour will reload before serving the request.
+ //! 3. Model is in the catalogue → pick a feasible neuron, call
+ //! `POST /models/load`, wait for the load to complete, then
+ //! proxy. First-request cold-load latency is acceptable per the
+ //! unified-endpoint contract.
+ //! 4. Not in catalogue, not loaded anywhere → 404.
  use crate::state::CortexState;
+ use cortex_core::catalogue::ModelProfile;
+ use cortex_core::harness::ModelSpec;
  use cortex_core::node::ModelStatus;
  use std::sync::Arc;
+ use std::time::Duration;
  /// The routing decision: which node endpoint to proxy the request to.
  #[derive(Debug, Clone)]
@@ -16,18 +24,31 @@ pub struct RouteDecision {
  pub node_name: String,
  /// The inference endpoint to proxy to (from neuron's /models/{id}/endpoint).
  pub endpoint: String,
- /// Whether the model will need to load (cold start).
+ /// Whether the model will need to load (cold start). Set to true
+ /// when we proxied to an `Unloaded` node (lazy load on neuron) or
+ /// when we just triggered an explicit cold-load via the catalogue
+ /// path.
  pub cold_start: bool,
  }
  #[derive(Debug, thiserror::Error)]
  pub enum RouteError {
- #[error("model '{0}' not found on any node")]
+ #[error("model '{0}' not found on any node and not in catalogue")]
  ModelNotFound(String),
  #[error("no healthy nodes available")]
  NoHealthyNodes,
  #[error("failed to resolve inference endpoint for model '{0}' on node '{1}'")]
  EndpointResolveFailed(String, String),
+ #[error(
+ "model '{model_id}' is in the catalogue but no healthy neuron's topology satisfies its constraints"
+ )]
+ NoFeasibleNeuron { model_id: String },
+ #[error("cold-load of '{model_id}' on '{node}' failed: {message}")]
+ ColdLoadFailed {
+ model_id: String,
+ node: String,
+ message: String,
+ },
  }
  /// Resolve which node should serve a request for the given model.
@@ -36,42 +57,231 @@ pub async fn resolve(
fleet: &Arc<CortexState>, fleet: &Arc<CortexState>,
model_id: &str, model_id: &str,
) -> Result<RouteDecision, RouteError> { ) -> Result<RouteDecision, RouteError> {
let (node_name, neuron_endpoint, cold_start) = { // Snapshot loaded / unloaded state from the poller cache.
let (loaded_route, unloaded_route, any_healthy) = {
let nodes = fleet.nodes.read().await; let nodes = fleet.nodes.read().await;
let mut loaded_route = None;
let mut loaded_candidate = None; let mut unloaded_route = None;
let mut unloaded_candidate = None; let mut any_healthy = false;
for node in nodes.values() { for node in nodes.values() {
if !node.healthy { if !node.healthy {
continue; continue;
} }
any_healthy = true;
if let Some(entry) = node.models.get(model_id) { if let Some(entry) = node.models.get(model_id) {
match entry.status { match entry.status {
ModelStatus::Loaded | ModelStatus::Reloading => { ModelStatus::Loaded | ModelStatus::Reloading => {
loaded_candidate = Some((node.name.clone(), node.endpoint.clone(), false)); loaded_route = Some((node.name.clone(), node.endpoint.clone(), false));
break; break;
} }
ModelStatus::Unloaded => { ModelStatus::Unloaded => {
if unloaded_candidate.is_none() { if unloaded_route.is_none() {
unloaded_candidate = unloaded_route = Some((node.name.clone(), node.endpoint.clone(), true));
Some((node.name.clone(), node.endpoint.clone(), true));
} }
} }
} }
} }
} }
(loaded_route, unloaded_route, any_healthy)
loaded_candidate.or(unloaded_candidate).ok_or_else(|| {
if nodes.values().any(|n| n.healthy) {
RouteError::ModelNotFound(model_id.to_string())
} else {
RouteError::NoHealthyNodes
}
})?
}; };
// Ask the neuron for the inference endpoint for this model. if !any_healthy {
return Err(RouteError::NoHealthyNodes);
}
// Priority 1: already loaded.
if let Some((node_name, neuron_endpoint, cold_start)) = loaded_route {
return finish(fleet, &node_name, &neuron_endpoint, model_id, cold_start).await;
}
// Priority 2: known to neuron but unloaded (neuron's lazy load).
if let Some((node_name, neuron_endpoint, cold_start)) = unloaded_route {
return finish(fleet, &node_name, &neuron_endpoint, model_id, cold_start).await;
}
// Priority 3: catalogue × topology cold-load.
if let Some(profile) = fleet.catalogue.get(model_id) {
let (node_name, neuron_endpoint) = pick_feasible_neuron(fleet, profile).await?;
cold_load(fleet, &node_name, &neuron_endpoint, profile).await?;
return finish(fleet, &node_name, &neuron_endpoint, model_id, true).await;
}
Err(RouteError::ModelNotFound(model_id.to_string()))
}
/// Pick a healthy neuron whose discovered topology satisfies the
/// profile. Preference order:
/// 1. A neuron from `profile.pinned_on` that is healthy + feasible.
/// 2. Otherwise, any healthy + feasible neuron, stable by name.
async fn pick_feasible_neuron(
fleet: &Arc<CortexState>,
profile: &ModelProfile,
) -> Result<(String, String), RouteError> {
let nodes = fleet.nodes.read().await;
let mut candidates: Vec<(String, String, bool)> = Vec::new();
for node in nodes.values() {
if !node.healthy {
continue;
}
let Some(disc) = node.discovery.as_ref() else {
continue;
};
if !profile.is_feasible_on(&node.name, &disc.devices) {
continue;
}
let pinned = profile.pinned_on.iter().any(|n| n == &node.name);
candidates.push((node.name.clone(), node.endpoint.clone(), pinned));
}
candidates.sort_by(|a, b| {
b.2.cmp(&a.2) // pinned first (true > false)
.then(a.0.cmp(&b.0))
});
let pick = candidates.into_iter().next();
pick.map(|(n, e, _)| (n, e))
.ok_or_else(|| RouteError::NoFeasibleNeuron {
model_id: profile.id.clone(),
})
}
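
Review note: a minimal, standalone sketch of the ordering `pick_feasible_neuron` applies once the feasibility filter has run. The `Candidate` struct and `order_candidates` function are hypothetical stand-ins for the `(name, endpoint, pinned)` tuples above, not part of the commit.

```rust
/// Hypothetical stand-in for the `(name, endpoint, pinned)` tuples.
#[derive(Debug, Clone, PartialEq)]
struct Candidate {
    name: String,
    endpoint: String,
    pinned: bool,
}

/// Order candidates so pinned neurons come first, ties broken by name.
fn order_candidates(mut candidates: Vec<Candidate>) -> Vec<Candidate> {
    candidates.sort_by(|a, b| {
        b.pinned
            .cmp(&a.pinned) // `true > false`, so comparing b to a puts pinned first
            .then_with(|| a.name.cmp(&b.name)) // stable, name-ordered tie-break
    });
    candidates
}

/// Convenience constructor used only by this sketch.
fn candidate(name: &str, pinned: bool) -> Candidate {
    Candidate {
        name: name.to_string(),
        endpoint: format!("http://{name}:13131"),
        pinned,
    }
}
```

Because the comparison key is total (pinned flag, then name), the chosen neuron does not depend on `HashMap` iteration order, which is presumably the point of "stable by name".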
/// Issue `POST {endpoint}/models/load` for this profile on this neuron,
/// blocking until the load completes (neuron's load endpoint is
/// synchronous — it returns 200 once VRAM is materialised). On success
/// also inserts a `Loaded` entry into the local NodeState cache so the
/// caller's subsequent endpoint lookup sees the new model without
/// waiting for the next poll cycle.
async fn cold_load(
fleet: &Arc<CortexState>,
node_name: &str,
neuron_endpoint: &str,
profile: &ModelProfile,
) -> Result<(), RouteError> {
let spec = profile_to_spec(fleet, node_name, profile).await;
let url = format!("{neuron_endpoint}/models/load");
tracing::info!(model = %profile.id, node = node_name, "cold-loading via /models/load");
// Generous timeout: a fresh download + safetensors mmap + device
// copy for a 30B-class dense model can comfortably exceed 5 min on
// a slow link. The HTTP client's own default already covers most
// of this; pin a longer per-request bound just here.
let resp = match fleet
.http_client
.post(&url)
.timeout(Duration::from_secs(1800))
.json(&spec)
.send()
.await
{
Ok(r) => r,
Err(e) => {
return Err(RouteError::ColdLoadFailed {
model_id: profile.id.clone(),
node: node_name.to_string(),
message: format!("HTTP request failed: {e}"),
});
}
};
let status = resp.status();
if !status.is_success() {
let body = resp.text().await.unwrap_or_default();
// Neuron returns 400 "already loaded" when two concurrent
// requests race the same model. Treat that as success — both
// requests effectively achieved the same end state.
if body.contains("already loaded") {
tracing::info!(
model = %profile.id,
node = node_name,
"cold-load saw 'already loaded' — treating as success"
);
} else {
return Err(RouteError::ColdLoadFailed {
model_id: profile.id.clone(),
node: node_name.to_string(),
message: format!("HTTP {status}: {body}"),
});
}
} else {
tracing::info!(model = %profile.id, node = node_name, "cold-load returned 200");
}
// Warm the cache: insert a Loaded ModelEntry so the next
// resolve() finds the model without waiting for the poll loop.
{
let mut nodes = fleet.nodes.write().await;
if let Some(node) = nodes.get_mut(node_name) {
node.models.insert(
profile.id.clone(),
cortex_core::node::ModelEntry {
id: profile.id.clone(),
status: ModelStatus::Loaded,
last_accessed: Some(chrono::Utc::now()),
vram_estimate_mb: profile.vram_mb,
},
);
}
}
Ok(())
}
/// Translate a `ModelProfile` to a `ModelSpec` neuron's /models/load
/// accepts. Devices are picked from the neuron's discovered topology —
/// the first `min_devices` indices that meet `min_device_vram_mb`.
async fn profile_to_spec(
fleet: &Arc<CortexState>,
node_name: &str,
profile: &ModelProfile,
) -> ModelSpec {
let devices = {
let nodes = fleet.nodes.read().await;
let mut picked: Vec<u32> = Vec::new();
if let Some(node) = nodes.get(node_name)
&& let Some(disc) = &node.discovery
{
let min_vram = profile.min_device_vram_mb.unwrap_or(0);
for d in &disc.devices {
if d.vram_total_mb >= min_vram {
picked.push(d.index);
if picked.len() as u32 >= profile.min_devices {
break;
}
}
}
}
if picked.is_empty() {
// Fall back to a 0..min_devices default; pick_feasible_neuron
// already verified the topology satisfies the constraints,
// so this only fires if discovery raced or was lost.
(0..profile.min_devices).collect()
} else {
picked
}
};
let tensor_parallel = if profile.min_devices > 1 {
Some(profile.min_devices)
} else {
None
};
ModelSpec {
model_id: profile.id.clone(),
harness: profile.harness.clone(),
quant: profile.quant.clone(),
tensor_parallel,
devices: Some(devices),
}
}
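
Review note: the device-picking rule in `profile_to_spec` can be read as a small pure function. The sketch below assumes a trimmed-down `Device` type with a `u64` VRAM field (the real field type is not visible in this diff) and assumes `min_devices >= 1`; both are illustrative assumptions.

```rust
/// Hypothetical, trimmed-down view of one discovered device.
struct Device {
    index: u32,
    vram_total_mb: u64, // assumed type; the real struct is not shown in the diff
}

/// Pick the first `min_devices` device indices that meet the VRAM floor.
/// Falls back to `0..min_devices` when nothing qualifies, mirroring the
/// "discovery raced or was lost" fallback described in the commit.
fn pick_devices(devices: &[Device], min_devices: u32, min_vram_mb: Option<u64>) -> Vec<u32> {
    let min_vram = min_vram_mb.unwrap_or(0);
    let picked: Vec<u32> = devices
        .iter()
        .filter(|d| d.vram_total_mb >= min_vram)
        .map(|d| d.index)
        .take(min_devices as usize)
        .collect();
    if picked.is_empty() {
        (0..min_devices).collect()
    } else {
        picked
    }
}
```

Note that, like the loop above, this can return fewer than `min_devices` indices when only some devices qualify; the feasibility check in `pick_feasible_neuron` is what is expected to rule that case out.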
/// Resolve neuron's `/models/{id}/endpoint` to its inference URL and
/// build the final `RouteDecision`. Shared by all three priority
/// branches above.
async fn finish(
fleet: &Arc<CortexState>,
node_name: &str,
neuron_endpoint: &str,
model_id: &str,
cold_start: bool,
) -> Result<RouteDecision, RouteError> {
let endpoint_url = format!( let endpoint_url = format!(
"{}/models/{}/endpoint", "{}/models/{}/endpoint",
neuron_endpoint, neuron_endpoint,
@@ -89,13 +299,82 @@ pub async fn resolve(
_ => None, _ => None,
}; };
let endpoint = inference_endpoint.ok_or_else(|| { let raw = inference_endpoint.ok_or_else(|| {
RouteError::EndpointResolveFailed(model_id.to_string(), node_name.clone()) RouteError::EndpointResolveFailed(model_id.to_string(), node_name.to_string())
})?; })?;
// Rewrite loopback inference URLs to use the configured neuron host.
// Neuron's default bind_url is `http://localhost:13131` (it can't
// reliably know its own externally-resolvable name). Cortex sees a
// URL that's only meaningful from the neuron host's own perspective;
// proxying directly to localhost from a different cortex host would
// hit nothing. Keep neuron's port and path (a future harness could
// serve inference on a different port than the management API), but
// swap the host for the one in cortex.toml.
let endpoint = rewrite_loopback_host(&raw, neuron_endpoint).unwrap_or(raw);
Ok(RouteDecision { Ok(RouteDecision {
node_name, node_name: node_name.to_string(),
endpoint, endpoint,
cold_start, cold_start,
}) })
} }
/// If `inference_url`'s host is a loopback name (localhost / 127.0.0.1 /
/// 0.0.0.0 / ::1), return a copy with the host replaced by
/// `neuron_endpoint`'s host. Otherwise return None and the caller falls
/// back to the inference URL as-is.
fn rewrite_loopback_host(inference_url: &str, neuron_endpoint: &str) -> Option<String> {
let inf = url::Url::parse(inference_url).ok()?;
let inf_host = inf.host_str()?;
// `Url::host_str` serialises IPv6 hosts with brackets, so match "[::1]".
let is_loopback = matches!(inf_host, "localhost" | "127.0.0.1" | "0.0.0.0" | "[::1]");
if !is_loopback {
return None;
}
let neuron = url::Url::parse(neuron_endpoint).ok()?;
let new_host = neuron.host_str()?;
let mut out = inf.clone();
out.set_host(Some(new_host)).ok()?;
// url::Url::to_string normalises an empty path to "/", which then
// breaks downstream callers that do format!("{endpoint}/v1/...")
// and produce a double slash. The proxy URL is treated as a base
// string that the caller appends paths to, so strip the trailing
// slash here.
let s = out.to_string();
Some(s.trim_end_matches('/').to_string())
}
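
Review note: one thing to watch in the loopback check is that `url::Url::host_str` returns IPv6 hosts in bracketed form (`[::1]`). A std-only alternative is sketched below; `is_loopback_host` is a hypothetical helper, and it is deliberately broader than the string match above (it accepts all of `127.0.0.0/8` and the unspecified addresses `0.0.0.0` and `::`).

```rust
use std::net::IpAddr;

/// Hypothetical helper: classify a URL host string as "only meaningful on
/// the machine that reported it". Accepts bracketed or bare IPv6 literals.
fn is_loopback_host(host: &str) -> bool {
    if host.eq_ignore_ascii_case("localhost") {
        return true;
    }
    // Strip the brackets that URL serialisation puts around IPv6 literals.
    let bare = host
        .strip_prefix('[')
        .and_then(|h| h.strip_suffix(']'))
        .unwrap_or(host);
    match bare.parse::<IpAddr>() {
        // Loopback covers 127.0.0.0/8 and ::1; unspecified covers 0.0.0.0 and ::.
        Ok(ip) => ip.is_loopback() || ip.is_unspecified(),
        // Anything that is not an IP literal is treated as a real hostname.
        Err(_) => false,
    }
}
```

Parsing into `IpAddr` avoids maintaining a list of literal spellings, at the cost of treating more addresses as loopback than the current match does.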
#[cfg(test)]
mod tests {
use super::rewrite_loopback_host;
#[test]
fn rewrites_localhost_keeps_port_and_path() {
let out = rewrite_loopback_host(
"http://localhost:13131",
"http://beast.hanzalova.internal:13131",
);
assert_eq!(
out.as_deref(),
Some("http://beast.hanzalova.internal:13131")
);
}
#[test]
fn rewrites_loopback_with_distinct_inference_port() {
let out = rewrite_loopback_host("http://127.0.0.1:8080", "http://beast.lan:13131");
assert_eq!(out.as_deref(), Some("http://beast.lan:8080"));
}
#[test]
fn leaves_non_loopback_alone() {
let out = rewrite_loopback_host("http://other.host:1234", "http://beast.lan:13131");
assert_eq!(out, None);
}
#[test]
fn malformed_inference_url_returns_none() {
let out = rewrite_loopback_host("not a url", "http://beast.lan:13131");
assert_eq!(out, None);
}
}


@@ -26,6 +26,7 @@ impl CortexState {
models: HashMap::new(), models: HashMap::new(),
lifecycle_cycles: 0, lifecycle_cycles: 0,
last_poll: None, last_poll: None,
discovery: None,
}, },
); );
} }


@@ -14,12 +14,18 @@ path = "src/main.rs"
[features] [features]
default = [] default = []
# Enables CUDA acceleration in candle. Without this feature, candle # Enables CUDA acceleration in candle and the cudarc/nccl bindings the
# compiles for CPU only and Device::new_cuda calls fall back to CPU. # TP worker pool uses. Without this feature, candle compiles for CPU
# only, Device::new_cuda calls fall back to CPU, and TP Init/sanity
# requests return Error{kind="cuda_feature_not_enabled"}.
cuda = [ cuda = [
"candle-core/cuda", "candle-core/cuda",
"candle-core/nccl",
"candle-nn/cuda", "candle-nn/cuda",
"candle-transformers/cuda", "candle-transformers/cuda",
"dep:cudarc",
"dep:half",
"dep:cudaforge",
] ]
# Use cuDNN for convolution / attention kernels. Requires CUDA. # Use cuDNN for convolution / attention kernels. Requires CUDA.
cudnn = [ cudnn = [
@@ -60,9 +66,32 @@ toml.workspace = true
candle-core = "0.10.2" candle-core = "0.10.2"
candle-nn = "0.10.2" candle-nn = "0.10.2"
candle-transformers = "0.10.2" candle-transformers = "0.10.2"
# Direct dep on cudarc (matching candle's transitive version) so the
# TP worker pool can call cudarc::nccl::{Comm, Id} directly. Gated on
# the `cuda` feature; same toolchain requirement as candle's CUDA path.
cudarc = { version = "0.19", optional = true, default-features = false, features = ["nccl", "cuda-version-from-build-system"] }
# Used by the AllReduce CustomOp1 to type-dispatch on bf16/f16 candle
# storages. Matches candle-core's pinned major version to avoid double-
# compiling the `half` crate at conflicting versions.
half = { version = "2.5", optional = true }
tokenizers = { version = "0.22", default-features = false, features = ["onig"] } tokenizers = { version = "0.22", default-features = false, features = ["onig"] }
hf-hub = { version = "0.4", features = ["tokio"] } hf-hub = { version = "0.4", features = ["tokio"] }
# Direct dep on `safetensors` (re-exported by candle but its `TensorView`
# / `slice::IndexOp` types are public-but-not-re-exported). Used by the
# tp `fused_load` module to read per-rank slices of fused QKV tensors
# without materialising the full tensor on device.
safetensors = "0.7"
[dev-dependencies] [dev-dependencies]
tokio = { workspace = true, features = ["test-util"] } tokio = { workspace = true, features = ["test-util"] }
reqwest.workspace = true reqwest.workspace = true
[build-dependencies]
# Used by `build.rs` to compile `src/cuda/*.cu` into `libneuroncuda.a`
# under the `cuda` feature. Matches mistralrs's upstream build setup
# (their `mistralrs-core/build.rs` uses the same constructor).
cudaforge = { version = "0.1", optional = true }
[package.metadata.docs.rs]
# Skip the CUDA path on docs.rs (it lacks nvcc).
no-default-features = true

crates/neuron/build.rs Normal file

@@ -0,0 +1,66 @@
//! Build script: compile the CUDA kernels in `src/cuda/*.cu` into a
//! static library and link it under the `cuda` feature.
//!
//! Patterned on `EricLBuehler/mistral.rs::mistralrs-core/build.rs` —
//! same `cudaforge::KernelBuilder` invocation, same NVCC flag set.
fn main() {
#[cfg(feature = "cuda")]
{
use std::path::PathBuf;
println!("cargo:rerun-if-changed=build.rs");
println!("cargo:rerun-if-changed=src/cuda/");
let build_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap());
let mut builder = cudaforge::KernelBuilder::new()
.source_glob("src/cuda/*.cu")
.out_dir(&build_dir)
.arg("-std=c++17")
.arg("-O3")
.arg("-U__CUDA_NO_HALF_OPERATORS__")
.arg("-U__CUDA_NO_HALF_CONVERSIONS__")
.arg("-U__CUDA_NO_HALF2_OPERATORS__")
.arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
.arg("--expt-relaxed-constexpr")
.arg("--expt-extended-lambda")
.arg("--use_fast_math")
.arg("--compiler-options")
.arg("-fPIC");
// Devices below sm_80 lack the bf16 intrinsics used for WMMA, so gate
// the bf16-only kernels off in that case. (Mirrors upstream.)
if let Some(compute_cap) = builder.get_compute_cap()
&& compute_cap < 80
{
builder = builder.arg("-DNO_BF16_KERNEL");
}
let target = std::env::var("TARGET").unwrap();
let out_file = if target.contains("msvc") {
build_dir.join("neuroncuda.lib")
} else {
build_dir.join("libneuroncuda.a")
};
builder
.build_lib(out_file)
.expect("neuron cuda build failed");
println!("cargo:rustc-link-search={}", build_dir.display());
println!("cargo:rustc-link-lib=neuroncuda");
println!("cargo:rustc-link-lib=dylib=cudart");
if target.contains("msvc") {
// No extra runtime library needed.
} else if target.contains("apple")
|| target.contains("freebsd")
|| target.contains("openbsd")
{
println!("cargo:rustc-link-lib=dylib=c++");
} else if target.contains("android") {
println!("cargo:rustc-link-lib=dylib=c++_shared");
} else {
println!("cargo:rustc-link-lib=dylib=stdc++");
}
}
}


@@ -0,0 +1,84 @@
//! FFI declarations for the CUDA kernels in `gdn.cu`.
//!
//! Subset of `EricLBuehler/mistral.rs::mistralrs-core/src/cuda/ffi.rs`
//! covering only the Gated DeltaNet kernels we currently use. Other
//! kernels in the upstream file (MoE GEMM, top-k, Mamba selective
//! scan, etc.) would land here too as we absorb them.
//!
//! All function declarations are MIT-licensed from upstream and
//! unchanged apart from this header.
use std::ffi::c_void;
#[allow(dead_code)]
unsafe extern "C" {
// GDN (Gated Delta Net) kernels for qwen3_5 / Qwen3-Next.
pub(crate) fn gated_delta_rule_recurrence(
q: *const f32,
k: *const f32,
v: *const f32,
g: *const f32,
beta: *const f32,
state: *mut f32,
output: *mut f32,
bh: i32,
seq_len: i32,
k_dim: i32,
v_dim: i32,
stream: i64,
);
/// Chunked GDN recurrence for prefill (processes tokens in BT=64 chunks).
pub(crate) fn chunked_gated_delta_rule_recurrence(
q: *const f32,
k: *const f32,
v: *const f32,
g: *const f32,
beta: *const f32,
state: *mut f32,
output: *mut f32,
bh: i32,
seq_len: i32,
k_dim: i32,
v_dim: i32,
stream: i64,
);
pub(crate) fn causal_conv1d_update(
x: *const c_void,
weight: *const c_void,
conv_state: *mut c_void,
output: *mut c_void,
batch_size: i32,
conv_dim: i32,
kernel_size: i32,
dtype: i32,
stream: i64,
);
pub(crate) fn causal_conv1d_full(
x: *const c_void,
weight: *const c_void,
conv_state_out: *mut c_void,
output: *mut c_void,
batch_size: i32,
conv_dim: i32,
seq_len: i32,
kernel_size: i32,
dtype: i32,
stream: i64,
);
pub(crate) fn fused_gdn_gating(
b: *const c_void,
a: *const c_void,
a_log: *const f32,
dt_bias: *const f32,
beta_out: *mut c_void,
g_out: *mut c_void,
total_elements: i32,
num_heads: i32,
dtype: i32,
stream: i64,
);
}


@@ -0,0 +1,711 @@
// Gated DeltaNet CUDA kernels for Qwen3-Next (`model_type = "qwen3_5"`).
//
// Ported verbatim from `EricLBuehler/mistral.rs` under MIT terms.
// Upstream path: mistralrs-core/src/cuda/gdn.cu. Local edits in this
// file are limited to this banner; the kernels are unchanged so a
// diff against upstream stays minimal.
//
// Five kernels exposed via `extern "C"` shims at the bottom:
// - gated_delta_rule_recurrence (per-token decode)
// - chunked_gated_delta_rule_recurrence (BT=64 chunked prefill)
// - causal_conv1d_update (single-token conv decode)
// - causal_conv1d_full (multi-token conv prefill)
// - fused_gdn_gating (beta = sigmoid(b);
// g = -exp(A_log) * softplus(a + dt_bias))
#include "cuda_bf16.h"
#include "cuda_fp16.h"
#include <cmath>
#include <cstdint>
#include <cuda_runtime.h>
// ============================================================================
// Kernel 1: gated_delta_rule_recurrence (optimized)
//
// V-tiled recurrence with compile-time K dimension for register residency.
// Grid: (ceil(V/BV), B*H), Block: (BV,). Each thread owns BK registers of
// state. Shared memory holds k_buf and q_buf (2*BK floats).
//
// Optimizations over naive version:
// - Template BK -> float s[BK] lives in true registers (1 cycle vs ~30)
// - #pragma unroll on all k-loops -> full ILP
// - Fused decay+kv_mem pass and fused state_update+output pass
// - __fmaf_rn intrinsics for guaranteed fused multiply-add
// - BV=64 threads -> 2 warps, 6 blocks/SM on Ampere
//
// q,k: [BH, S, K] v: [BH, S, V] g,beta: [BH, S]
// state: [BH, K, V] (in/out) output: [BH, S, V]
// ============================================================================
// Optimized kernel: BK known at compile time -> registers + full unrolling
template <int BK, int BV>
__global__ void gated_delta_rule_recurrence_kernel_tiled(
const float *__restrict__ q, // [BH, S, K]
const float *__restrict__ k, // [BH, S, K]
const float *__restrict__ v, // [BH, S, V]
const float *__restrict__ g, // [BH, S]
const float *__restrict__ beta, // [BH, S]
float *__restrict__ state, // [BH, K, V]
float *__restrict__ output, // [BH, S, V]
int seq_len, int v_dim) {
const int v_tile = blockIdx.x; // which V-tile
const int bh = blockIdx.y; // batch*head index
const int tid = threadIdx.x; // thread within tile [0, BV)
const int v_idx = v_tile * BV + tid; // global V index
if (v_idx >= v_dim)
return;
// Pointers for this (batch, head)
const float *q_bh = q + bh * seq_len * BK;
const float *k_bh = k + bh * seq_len * BK;
const float *v_bh = v + bh * seq_len * v_dim;
const float *g_bh = g + bh * seq_len;
const float *beta_bh = beta + bh * seq_len;
float *state_bh = state + bh * BK * v_dim;
float *out_bh = output + bh * seq_len * v_dim;
// Shared memory: k_buf[BK] + q_buf[BK]
__shared__ float k_buf[BK];
__shared__ float q_buf[BK];
// Load state column into registers — BK is compile-time, so this is
// a true register array (not spilled to local memory)
float s[BK];
#pragma unroll
for (int j = 0; j < BK; j++) {
s[j] = state_bh[j * v_dim + v_idx];
}
for (int t = 0; t < seq_len; t++) {
// Collaboratively load k_t into shared memory
// BK / BV loads per thread (e.g. 128/64 = 2)
#pragma unroll
for (int j = tid; j < BK; j += BV) {
k_buf[j] = k_bh[t * BK + j];
}
__syncthreads();
// Load scalars for this timestep
float decay = expf(g_bh[t]);
float beta_t = beta_bh[t];
float v_t = v_bh[t * v_dim + v_idx];
// Fused pass 1: decay state + compute kv_mem
float kv_mem = 0.0f;
#pragma unroll
for (int j = 0; j < BK; j++) {
s[j] *= decay;
kv_mem = __fmaf_rn(s[j], k_buf[j], kv_mem);
}
// Delta rule
float delta = (v_t - kv_mem) * beta_t;
// Collaboratively load q_t into shared memory
#pragma unroll
for (int j = tid; j < BK; j += BV) {
q_buf[j] = q_bh[t * BK + j];
}
__syncthreads();
// Fused pass 2: update state + compute output
float y_t = 0.0f;
#pragma unroll
for (int j = 0; j < BK; j++) {
s[j] = __fmaf_rn(k_buf[j], delta, s[j]);
y_t = __fmaf_rn(s[j], q_buf[j], y_t);
}
out_bh[t * v_dim + v_idx] = y_t;
__syncthreads();
}
// Write state back
#pragma unroll
for (int j = 0; j < BK; j++) {
state_bh[j * v_dim + v_idx] = s[j];
}
}
// Fallback kernel: runtime k_dim, still V-tiled for occupancy
template <int BV, int MAX_K>
__global__ void gated_delta_rule_recurrence_kernel_fallback(
const float *__restrict__ q, const float *__restrict__ k,
const float *__restrict__ v, const float *__restrict__ g,
const float *__restrict__ beta, float *__restrict__ state,
float *__restrict__ output, int seq_len, int k_dim, int v_dim) {
const int v_tile = blockIdx.x;
const int bh = blockIdx.y;
const int tid = threadIdx.x;
const int v_idx = v_tile * BV + tid;
if (v_idx >= v_dim)
return;
const float *q_bh = q + bh * seq_len * k_dim;
const float *k_bh = k + bh * seq_len * k_dim;
const float *v_bh = v + bh * seq_len * v_dim;
const float *g_bh = g + bh * seq_len;
const float *beta_bh = beta + bh * seq_len;
float *state_bh = state + bh * k_dim * v_dim;
float *out_bh = output + bh * seq_len * v_dim;
extern __shared__ float shared[];
float *k_buf = shared;
float *q_buf = shared + k_dim;
float s[MAX_K];
for (int j = 0; j < k_dim; j++) {
s[j] = state_bh[j * v_dim + v_idx];
}
for (int t = 0; t < seq_len; t++) {
for (int j = tid; j < k_dim; j += BV) {
k_buf[j] = k_bh[t * k_dim + j];
}
__syncthreads();
float decay = expf(g_bh[t]);
float beta_t = beta_bh[t];
float v_t = v_bh[t * v_dim + v_idx];
float kv_mem = 0.0f;
for (int j = 0; j < k_dim; j++) {
s[j] *= decay;
kv_mem = __fmaf_rn(s[j], k_buf[j], kv_mem);
}
float delta = (v_t - kv_mem) * beta_t;
for (int j = tid; j < k_dim; j += BV) {
q_buf[j] = q_bh[t * k_dim + j];
}
__syncthreads();
float y_t = 0.0f;
for (int j = 0; j < k_dim; j++) {
s[j] = __fmaf_rn(k_buf[j], delta, s[j]);
y_t = __fmaf_rn(s[j], q_buf[j], y_t);
}
out_bh[t * v_dim + v_idx] = y_t;
__syncthreads();
}
for (int j = 0; j < k_dim; j++) {
state_bh[j * v_dim + v_idx] = s[j];
}
}
extern "C" void gated_delta_rule_recurrence(const float *q, const float *k,
const float *v, const float *g,
const float *beta, float *state,
float *output, int bh, int seq_len,
int k_dim, int v_dim,
int64_t stream) {
const cudaStream_t custream = (cudaStream_t)stream;
if (k_dim == 128) {
// Fast path for Qwen3-Next (k_dim=128)
constexpr int BK = 128;
constexpr int BV = 64;
dim3 grid((v_dim + BV - 1) / BV, bh);
dim3 block(BV);
gated_delta_rule_recurrence_kernel_tiled<BK, BV>
<<<grid, block, 0, custream>>>(q, k, v, g, beta, state, output, seq_len,
v_dim);
} else if (k_dim == 64) {
// Fast path for models with k_dim=64
constexpr int BK = 64;
constexpr int BV = 64;
dim3 grid((v_dim + BV - 1) / BV, bh);
dim3 block(BV);
gated_delta_rule_recurrence_kernel_tiled<BK, BV>
<<<grid, block, 0, custream>>>(q, k, v, g, beta, state, output, seq_len,
v_dim);
} else {
// Fallback for other k_dim values (runtime loop, still V-tiled)
constexpr int BV = 64;
constexpr int MAX_K = 256;
dim3 grid((v_dim + BV - 1) / BV, bh);
dim3 block(BV);
size_t smem = 2 * k_dim * sizeof(float);
gated_delta_rule_recurrence_kernel_fallback<BV, MAX_K>
<<<grid, block, smem, custream>>>(q, k, v, g, beta, state, output,
seq_len, k_dim, v_dim);
}
}
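
Review note: for checking the kernel from the Rust side, a plain CPU reference of the same recurrence is handy. The sketch below follows the per-timestep steps in the kernel above (decay, `kv_mem`, delta rule, state update, output) for a single (batch, head) slice. The function name and the slice-based signature are assumptions for illustration; results will not be bit-identical to the GPU path, which uses fused multiply-add.

```rust
/// CPU reference for one (batch, head) slice of the gated delta rule.
/// Layouts (row-major): q, k: [S, K]; v: [S, V]; g, beta: [S];
/// state: [K, V], updated in place. Returns output as [S, V].
fn gated_delta_rule_reference(
    q: &[f32],
    k: &[f32],
    v: &[f32],
    g: &[f32],
    beta: &[f32],
    state: &mut [f32],
    seq_len: usize,
    k_dim: usize,
    v_dim: usize,
) -> Vec<f32> {
    assert_eq!(q.len(), seq_len * k_dim);
    assert_eq!(k.len(), seq_len * k_dim);
    assert_eq!(v.len(), seq_len * v_dim);
    assert_eq!(g.len(), seq_len);
    assert_eq!(beta.len(), seq_len);
    assert_eq!(state.len(), k_dim * v_dim);

    let mut output = vec![0.0f32; seq_len * v_dim];
    for t in 0..seq_len {
        let decay = g[t].exp();
        let k_t = &k[t * k_dim..(t + 1) * k_dim];
        let q_t = &q[t * k_dim..(t + 1) * k_dim];
        for col in 0..v_dim {
            // Pass 1: decay the state column and accumulate state^T k.
            let mut kv_mem = 0.0f32;
            for j in 0..k_dim {
                state[j * v_dim + col] *= decay;
                kv_mem += state[j * v_dim + col] * k_t[j];
            }
            // Delta rule: correct towards v_t, scaled by beta_t.
            let delta = (v[t * v_dim + col] - kv_mem) * beta[t];
            // Pass 2: rank-1 state update and output projection with q_t.
            let mut y = 0.0f32;
            for j in 0..k_dim {
                state[j * v_dim + col] += k_t[j] * delta;
                y += state[j * v_dim + col] * q_t[j];
            }
            output[t * v_dim + col] = y;
        }
    }
    output
}
```

Comparing against this with a small tolerance, for both the `k_dim == 128` fast path and the fallback path, would exercise both template variants.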
// ============================================================================
// Kernel 1b: chunked_gated_delta_rule_recurrence (prefill optimization)
//
// Processes prefill tokens in BT-token chunks instead of one at a time.
// Within each chunk: parallel prefix sum of g, cooperative kk_dot computation,
// forward substitution (triangular solve), output computation, and state
// update.
//
// Same thread model as Kernel 1: one block per (v_tile, batch*head),
// one thread per V-column. Each thread owns BK registers of state.
//
// Shared memory holds:
// k_chunk[BT * BK] -- key vectors for current chunk
// kk_dot[BT * BT] -- dot(k[i], k[j]) lower-triangular matrix
// gcum[BT] -- cumulative sum of g within chunk
// beta_s[BT] -- beta values for chunk
// q_buf[BK] -- q vector (loaded one row at a time)
//
// q,k: [BH, S, K] v: [BH, S, V] g,beta: [BH, S]
// state: [BH, K, V] (in/out) output: [BH, S, V]
// ============================================================================
template <int BT, int BK, int BV>
__global__ void
chunked_gated_delta_rule_kernel(const float *__restrict__ q, // [BH, S, K]
const float *__restrict__ k, // [BH, S, K]
const float *__restrict__ v, // [BH, S, V]
const float *__restrict__ g, // [BH, S]
const float *__restrict__ beta, // [BH, S]
float *__restrict__ state, // [BH, K, V]
float *__restrict__ output, // [BH, S, V]
int seq_len, int v_dim) {
const int v_tile = blockIdx.x;
const int bh = blockIdx.y;
const int tid = threadIdx.x;
const int v_idx = v_tile * BV + tid;
if (v_idx >= v_dim)
return;
const int num_chunks = (seq_len + BT - 1) / BT;
// Pointers for this (batch, head)
const float *q_bh = q + bh * seq_len * BK;
const float *k_bh = k + bh * seq_len * BK;
const float *v_bh = v + bh * seq_len * v_dim;
const float *g_bh = g + bh * seq_len;
const float *beta_bh = beta + bh * seq_len;
float *state_bh = state + bh * BK * v_dim;
float *out_bh = output + bh * seq_len * v_dim;
// Dynamic shared memory layout
extern __shared__ float smem[];
float *k_chunk = smem; // [BT * BK]
float *kk_dot = smem + BT * BK; // [BT * BT]
float *gcum = smem + BT * BK + BT * BT; // [BT]
float *beta_s = gcum + BT; // [BT]
float *q_buf = beta_s + BT; // [BK]
// Load state column into registers
float s[BK];
#pragma unroll
for (int j = 0; j < BK; j++) {
s[j] = state_bh[j * v_dim + v_idx];
}
// Per-thread register array for corrected deltas
float delta[BT];
for (int c = 0; c < num_chunks; c++) {
const int chunk_start = c * BT;
const int chunk_len = min(BT, seq_len - chunk_start);
// === Phase 1: Cooperative load of k, beta, g into shared memory ===
for (int t = 0; t < chunk_len; t++) {
for (int j = tid; j < BK; j += BV) {
k_chunk[t * BK + j] = k_bh[(chunk_start + t) * BK + j];
}
}
if (tid < chunk_len) {
beta_s[tid] = beta_bh[chunk_start + tid];
gcum[tid] = g_bh[chunk_start + tid];
}
__syncthreads();
// === Phase 1b: Parallel prefix sum of g (Hillis-Steele) ===
for (int stride = 1; stride < BT; stride <<= 1) {
float prev = 0.0f;
if (tid < chunk_len && (int)tid >= stride)
prev = gcum[tid - stride];
__syncthreads();
if (tid < chunk_len && (int)tid >= stride)
gcum[tid] += prev;
__syncthreads();
}
// === Phase 2: Compute kk_dot[i][j] = dot(k[i], k[j]) for j < i ===
// Only lower-triangular entries needed (strictly lower)
for (int idx = tid; idx < chunk_len * chunk_len; idx += BV) {
int i = idx / chunk_len;
int j = idx % chunk_len;
if (j < i) {
float dot = 0.0f;
for (int d = 0; d < BK; d++) {
dot = __fmaf_rn(k_chunk[i * BK + d], k_chunk[j * BK + d], dot);
}
kk_dot[i * BT + j] = dot;
}
}
__syncthreads();
// === Phase 3: Forward substitution (per V-column, in registers) ===
// Computes corrected delta values via triangular solve
for (int i = 0; i < chunk_len; i++) {
float v_i = v_bh[(chunk_start + i) * v_dim + v_idx];
float decay_i = expf(gcum[i]);
float beta_i = beta_s[i];
// Inter-chunk contribution: state @ k[i] with decay
float kv_mem = 0.0f;
#pragma unroll
for (int d = 0; d < BK; d++) {
kv_mem = __fmaf_rn(s[d] * decay_i, k_chunk[i * BK + d], kv_mem);
}
float rhs = beta_i * (v_i - kv_mem);
// Subtract lower-triangular contributions (intra-chunk)
for (int j = 0; j < i; j++) {
float a_ij = beta_i * kk_dot[i * BT + j] * expf(gcum[i] - gcum[j]);
rhs -= a_ij * delta[j];
}
delta[i] = rhs;
}
// === Phase 4: Output computation (per V-column) ===
for (int i = 0; i < chunk_len; i++) {
// Cooperatively load q[i] into shared
for (int j = tid; j < BK; j += BV) {
q_buf[j] = q_bh[(chunk_start + i) * BK + j];
}
__syncthreads();
float decay_i = expf(gcum[i]);
// Inter-chunk contribution: q[i] @ (state * decay)
float o_val = 0.0f;
#pragma unroll
for (int d = 0; d < BK; d++) {
o_val = __fmaf_rn(q_buf[d], s[d] * decay_i, o_val);
}
// Intra-chunk contribution: sum_{j<=i} dot(q[i], k[j]) * delta[j] *
// exp(gcum[i] - gcum[j])
for (int j = 0; j <= i; j++) {
float qk_dot = 0.0f;
for (int d = 0; d < BK; d++) {
qk_dot = __fmaf_rn(q_buf[d], k_chunk[j * BK + d], qk_dot);
}
o_val += qk_dot * delta[j] * expf(gcum[i] - gcum[j]);
}
out_bh[(chunk_start + i) * v_dim + v_idx] = o_val;
__syncthreads();
}
// === Phase 5: State update for next chunk ===
float g_total = gcum[chunk_len - 1];
#pragma unroll
for (int d = 0; d < BK; d++) {
float s_new = s[d] * expf(g_total);
for (int t = 0; t < chunk_len; t++) {
s_new += k_chunk[t * BK + d] * delta[t] * expf(g_total - gcum[t]);
}
s[d] = s_new;
}
__syncthreads();
}
// Write final state back
#pragma unroll
for (int j = 0; j < BK; j++) {
state_bh[j * v_dim + v_idx] = s[j];
}
}
extern "C" void chunked_gated_delta_rule_recurrence(
const float *q, const float *k, const float *v, const float *g,
const float *beta, float *state, float *output, int bh, int seq_len,
int k_dim, int v_dim, int64_t stream) {
const cudaStream_t custream = (cudaStream_t)stream;
if (k_dim == 128) {
constexpr int BT = 64;
constexpr int BK = 128;
constexpr int BV = 64;
// Shared memory: BT*BK + BT*BT + BT + BT + BK floats
size_t smem = (BT * BK + BT * BT + 2 * BT + BK) * sizeof(float);
// Request extended shared memory
auto kernel = chunked_gated_delta_rule_kernel<BT, BK, BV>;
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
smem);
dim3 grid((v_dim + BV - 1) / BV, bh);
dim3 block(BV);
kernel<<<grid, block, smem, custream>>>(q, k, v, g, beta, state, output,
seq_len, v_dim);
} else if (k_dim == 64) {
constexpr int BT = 64;
constexpr int BK = 64;
constexpr int BV = 64;
size_t smem = (BT * BK + BT * BT + 2 * BT + BK) * sizeof(float);
auto kernel = chunked_gated_delta_rule_kernel<BT, BK, BV>;
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
smem);
dim3 grid((v_dim + BV - 1) / BV, bh);
dim3 block(BV);
kernel<<<grid, block, smem, custream>>>(q, k, v, g, beta, state, output,
seq_len, v_dim);
} else {
// Fallback: use the sequential kernel for unsupported k_dim
gated_delta_rule_recurrence(q, k, v, g, beta, state, output, bh, seq_len,
k_dim, v_dim, stream);
}
}
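
Review note: the shared-memory arithmetic in the launcher above explains why `cudaFuncSetAttribute` is called. A small sketch of the same formula follows; the constant and function names are hypothetical, and the 48 KiB figure is CUDA's default per-block dynamic shared memory limit before opt-in.

```rust
/// Default per-block dynamic shared memory limit without opt-in (48 KiB).
const DEFAULT_DYNAMIC_SMEM_LIMIT: usize = 48 * 1024;

/// Bytes of dynamic shared memory the chunked kernel requests:
/// k_chunk[BT*BK] + kk_dot[BT*BT] + gcum[BT] + beta_s[BT] + q_buf[BK] floats.
const fn chunked_smem_bytes(bt: usize, bk: usize) -> usize {
    (bt * bk + bt * bt + 2 * bt + bk) * std::mem::size_of::<f32>()
}

/// Whether a launch with this tile shape needs the extended-smem opt-in.
const fn needs_smem_opt_in(bt: usize, bk: usize) -> bool {
    chunked_smem_bytes(bt, bk) > DEFAULT_DYNAMIC_SMEM_LIMIT
}
```

With `BT = 64`, the `BK = 128` path comes to 50176 bytes and crosses the default limit, while the `BK = 64` path comes to 33536 bytes and would launch without the attribute; setting it in both branches is harmless.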
// ============================================================================
// Kernel 2a: causal_conv1d_update (decode path, single step)
//
// Each thread handles one channel: shift conv_state left by 1,
// insert new value, dot product with weight, apply SiLU.
//
// x: [B, conv_dim, 1] weight: [conv_dim, kernel_size]
// conv_state: [B, conv_dim, kernel_size] (in/out)
// output: [B, conv_dim, 1]
// ============================================================================
template <typename T>
__global__ void causal_conv1d_update_kernel(
const T *__restrict__ x, // [B, conv_dim, 1]
const T *__restrict__ weight, // [conv_dim, kernel_size]
T *__restrict__ conv_state, // [B, conv_dim, kernel_size]
T *__restrict__ output, // [B, conv_dim, 1]
int batch_size, int conv_dim, int kernel_size) {
const int ch = blockIdx.x * blockDim.x + threadIdx.x;
const int b = blockIdx.y;
if (ch >= conv_dim || b >= batch_size)
return;
// Pointer to this batch/channel's conv state
T *cs = conv_state + (b * conv_dim + ch) * kernel_size;
const T *w = weight + ch * kernel_size;
// Shift state left by 1
for (int i = 0; i < kernel_size - 1; i++) {
cs[i] = cs[i + 1];
}
// Insert new value
cs[kernel_size - 1] = x[b * conv_dim + ch];
// Dot product with weight
float acc = 0.0f;
for (int i = 0; i < kernel_size; i++) {
acc += (float)cs[i] * (float)w[i];
}
// SiLU activation: x * sigmoid(x)
float sig = 1.0f / (1.0f + expf(-acc));
float result = acc * sig;
output[b * conv_dim + ch] = (T)result;
}
extern "C" void causal_conv1d_update(const void *x, const void *weight,
void *conv_state, void *output,
int batch_size, int conv_dim,
int kernel_size, int dtype,
int64_t stream) {
const cudaStream_t custream = (cudaStream_t)stream;
dim3 block(256);
dim3 grid((conv_dim + 255) / 256, batch_size);
if (dtype == 0) {
// f16
causal_conv1d_update_kernel<__half><<<grid, block, 0, custream>>>(
(const __half *)x, (const __half *)weight, (__half *)conv_state,
(__half *)output, batch_size, conv_dim, kernel_size);
} else {
// bf16
causal_conv1d_update_kernel<__nv_bfloat16><<<grid, block, 0, custream>>>(
(const __nv_bfloat16 *)x, (const __nv_bfloat16 *)weight,
(__nv_bfloat16 *)conv_state, (__nv_bfloat16 *)output, batch_size,
conv_dim, kernel_size);
}
}
// ============================================================================
// Kernel 2b: causal_conv1d_full (prefill path)
//
// Each thread handles one (channel, position): causal window with
// zero-padding, dot product with weight, SiLU.
// A second pass writes the conv_state from the last kernel_size positions.
//
// x: [B, conv_dim, S] weight: [conv_dim, kernel_size]
// conv_state_out: [B, conv_dim, kernel_size] output: [B, conv_dim, S]
// ============================================================================
template <typename T>
__global__ void causal_conv1d_full_kernel(
const T *__restrict__ x, // [B, conv_dim, S]
const T *__restrict__ weight, // [conv_dim, kernel_size]
T *__restrict__ output, // [B, conv_dim, S]
int batch_size, int conv_dim, int seq_len, int kernel_size) {
const int ch = blockIdx.x * blockDim.x + threadIdx.x;
const int pos = blockIdx.y;
const int b = blockIdx.z;
if (ch >= conv_dim || pos >= seq_len || b >= batch_size)
return;
const T *x_bch = x + (b * conv_dim + ch) * seq_len;
const T *w = weight + ch * kernel_size;
// Causal convolution: sum over kernel_size window ending at pos
float acc = 0.0f;
for (int i = 0; i < kernel_size; i++) {
int src_pos = pos - (kernel_size - 1) + i;
float x_val = (src_pos >= 0) ? (float)x_bch[src_pos] : 0.0f;
acc += x_val * (float)w[i];
}
// SiLU
float sig = 1.0f / (1.0f + expf(-acc));
float result = acc * sig;
output[(b * conv_dim + ch) * seq_len + pos] = (T)result;
}
template <typename T>
__global__ void save_conv_state_kernel(
const T *__restrict__ x, // [B, conv_dim, S]
T *__restrict__ conv_state_out, // [B, conv_dim, kernel_size]
int batch_size, int conv_dim, int seq_len, int kernel_size) {
const int ch = blockIdx.x * blockDim.x + threadIdx.x;
const int b = blockIdx.y;
if (ch >= conv_dim || b >= batch_size)
return;
const T *x_bch = x + (b * conv_dim + ch) * seq_len;
T *cs = conv_state_out + (b * conv_dim + ch) * kernel_size;
// Save last kernel_size positions (zero-pad if seq_len < kernel_size)
int pad = kernel_size - seq_len;
for (int i = 0; i < kernel_size; i++) {
if (i < pad) {
cs[i] = (T)0.0f;
} else {
cs[i] = x_bch[seq_len - kernel_size + i];
}
}
}
extern "C" void causal_conv1d_full(const void *x, const void *weight,
void *conv_state_out, void *output,
int batch_size, int conv_dim, int seq_len,
int kernel_size, int dtype, int64_t stream) {
const cudaStream_t custream = (cudaStream_t)stream;
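// Main convolution kernel. seq_len maps to gridDim.y and batch_size to
// gridDim.z, which CUDA caps at 65535 each, so this launch assumes both
// stay within that limit.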
dim3 block(256);
dim3 grid((conv_dim + 255) / 256, seq_len, batch_size);
if (dtype == 0) {
causal_conv1d_full_kernel<__half><<<grid, block, 0, custream>>>(
(const __half *)x, (const __half *)weight, (__half *)output, batch_size,
conv_dim, seq_len, kernel_size);
// Save conv state
dim3 grid2((conv_dim + 255) / 256, batch_size);
save_conv_state_kernel<__half><<<grid2, block, 0, custream>>>(
(const __half *)x, (__half *)conv_state_out, batch_size, conv_dim,
seq_len, kernel_size);
} else {
causal_conv1d_full_kernel<__nv_bfloat16><<<grid, block, 0, custream>>>(
(const __nv_bfloat16 *)x, (const __nv_bfloat16 *)weight,
(__nv_bfloat16 *)output, batch_size, conv_dim, seq_len, kernel_size);
dim3 grid2((conv_dim + 255) / 256, batch_size);
save_conv_state_kernel<__nv_bfloat16><<<grid2, block, 0, custream>>>(
(const __nv_bfloat16 *)x, (__nv_bfloat16 *)conv_state_out, batch_size,
conv_dim, seq_len, kernel_size);
}
}
// ============================================================================
// Kernel 3: fused_gdn_gating
//
// Fuses: beta = sigmoid(b), g = -exp(a_log) * softplus(a + dt_bias)
// a_log and dt_bias are per-head (broadcast over batch*seq).
//
// b, a: [total] a_log, dt_bias: [num_heads]
// beta_out, g_out: [total]
// ============================================================================
template <typename T>
__global__ void
fused_gdn_gating_kernel(const T *__restrict__ b, // [total]
const T *__restrict__ a, // [total]
const float *__restrict__ a_log, // [num_heads]
const float *__restrict__ dt_bias, // [num_heads]
T *__restrict__ beta_out, // [total]
T *__restrict__ g_out, // [total]
int total_elements, int num_heads) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= total_elements)
return;
// Head index: elements are laid out as [..., num_heads]
int head_idx = idx % num_heads;
// beta = sigmoid(b)
float b_val = (float)b[idx];
float beta = 1.0f / (1.0f + expf(-b_val));
// g = -exp(a_log) * softplus(a + dt_bias)
float a_val = (float)a[idx];
float a_log_val = a_log[head_idx];
float dt_bias_val = dt_bias[head_idx];
float sp_input = a_val + dt_bias_val;
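// Numerically stable softplus: for large inputs expf() overflows to inf,
// while log(1 + e^x) already equals x to f32 precision, so return x
// directly (threshold 20 matches the PyTorch softplus default).
float softplus_val =
(sp_input > 20.0f) ? sp_input : log1pf(expf(sp_input));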
float g_val = -expf(a_log_val) * softplus_val;
beta_out[idx] = (T)beta;
g_out[idx] = (T)g_val;
}
extern "C" void fused_gdn_gating(const void *b, const void *a,
const float *a_log, const float *dt_bias,
void *beta_out, void *g_out,
int total_elements, int num_heads, int dtype,
int64_t stream) {
const cudaStream_t custream = (cudaStream_t)stream;
dim3 block(256);
dim3 grid((total_elements + 255) / 256);
if (dtype == 0) {
fused_gdn_gating_kernel<__half><<<grid, block, 0, custream>>>(
(const __half *)b, (const __half *)a, a_log, dt_bias,
(__half *)beta_out, (__half *)g_out, total_elements, num_heads);
} else {
fused_gdn_gating_kernel<__nv_bfloat16><<<grid, block, 0, custream>>>(
(const __nv_bfloat16 *)b, (const __nv_bfloat16 *)a, a_log, dt_bias,
(__nv_bfloat16 *)beta_out, (__nv_bfloat16 *)g_out, total_elements,
num_heads);
}
}
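
The gating kernel above is small enough to check against a CPU reference. The following is a minimal sketch in plain Rust, assuming f32 inputs and the `[..., num_heads]` layout stated in the kernel comment. The function names (`softplus`, `sigmoid`, `gdn_gating_ref`) are hypothetical and not part of this change, and the softplus threshold of 20 is an assumption borrowed from PyTorch's default.

```rust
/// Numerically stable softplus, ln(1 + e^x).
fn softplus(x: f32) -> f32 {
    // Above the threshold, ln(1 + e^x) equals x to f32 precision and
    // e^x would eventually overflow, so return x directly.
    // The threshold value 20 is an assumption (PyTorch's default).
    if x > 20.0 {
        x
    } else {
        x.exp().ln_1p()
    }
}

fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

/// CPU reference for the fused gating.
/// `b` and `a` are flattened `[..., num_heads]` buffers;
/// `a_log` and `dt_bias` hold one value per head.
/// Returns `(beta, g)` with the same length as `b`.
fn gdn_gating_ref(
    b: &[f32],
    a: &[f32],
    a_log: &[f32],
    dt_bias: &[f32],
) -> (Vec<f32>, Vec<f32>) {
    assert_eq!(b.len(), a.len());
    assert_eq!(a_log.len(), dt_bias.len());
    let num_heads = a_log.len();
    assert!(num_heads > 0 && b.len() % num_heads == 0);

    let mut beta = Vec::with_capacity(b.len());
    let mut g = Vec::with_capacity(b.len());
    for idx in 0..b.len() {
        // The head is the innermost axis, so it is the index modulo num_heads.
        let h = idx % num_heads;
        beta.push(sigmoid(b[idx]));
        g.push(-a_log[h].exp() * softplus(a[idx] + dt_bias[h]));
    }
    (beta, g)
}
```

Comparing the kernel output (cast back to f32) against `gdn_gating_ref` with a tolerance suited to f16/bf16 would be one way to test the port; the tolerance itself has to be chosen by the caller.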


@@ -0,0 +1,486 @@
//! Rust wrappers around the Gated DeltaNet CUDA kernels in `gdn.cu`.
//!
//! Ported verbatim from `EricLBuehler/mistral.rs` under MIT terms.
//! Upstream path: `mistralrs-core/src/cuda/gdn.rs`. The only edit in
//! this file is this header comment; the FFI module path is
//! `crate::cuda::ffi`, identical to upstream's layout.
#![allow(clippy::cast_possible_truncation)]
use candle_core::{Result, Tensor};
#[cfg(feature = "cuda")]
use candle_core::DType;
/// CUDA-accelerated gated delta rule recurrence.
///
/// Inputs (all contiguous, f32):
/// q, k: [BH, S, K] v: [BH, S, V] g, beta: [BH, S]
/// state: [BH, K, V] (mutated in place)
///
/// Returns: output [BH, S, V]
#[cfg(feature = "cuda")]
pub fn gated_delta_rule_recurrence_cuda(
q: &Tensor,
k: &Tensor,
v: &Tensor,
g: &Tensor,
beta: &Tensor,
state: &mut Tensor,
) -> Result<Tensor> {
use candle::cuda_backend::cudarc::driver::DevicePtr;
use candle_core as candle;
let (bh, seq_len, k_dim) = q.dims3()?;
let v_dim = v.dim(2)?;
let dev = q.device().as_cuda_device()?;
let (q_s, q_l) = q.storage_and_layout();
let q_s = match &*q_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("q must be a cuda tensor"),
};
let q_offset = q_l.start_offset();
let (k_s, k_l) = k.storage_and_layout();
let k_s = match &*k_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("k must be a cuda tensor"),
};
let k_offset = k_l.start_offset();
let (v_s, v_l) = v.storage_and_layout();
let v_s = match &*v_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("v must be a cuda tensor"),
};
let v_offset = v_l.start_offset();
let (g_s, g_l) = g.storage_and_layout();
let g_s = match &*g_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("g must be a cuda tensor"),
};
let g_offset = g_l.start_offset();
let (beta_s, beta_l) = beta.storage_and_layout();
let beta_s = match &*beta_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("beta must be a cuda tensor"),
};
let beta_offset = beta_l.start_offset();
let (state_s, state_l) = state.storage_and_layout();
let state_s = match &*state_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("state must be a cuda tensor"),
};
let state_offset = state_l.start_offset();
let output_buf = unsafe { dev.alloc::<f32>(bh * seq_len * v_dim) }?;
let stream = dev.cuda_stream().cu_stream() as i64;
unsafe {
crate::cuda::ffi::gated_delta_rule_recurrence(
q_s.slice(q_offset..).device_ptr(q_s.stream()).0 as *const f32,
k_s.slice(k_offset..).device_ptr(k_s.stream()).0 as *const f32,
v_s.slice(v_offset..).device_ptr(v_s.stream()).0 as *const f32,
g_s.slice(g_offset..).device_ptr(g_s.stream()).0 as *const f32,
beta_s.slice(beta_offset..).device_ptr(beta_s.stream()).0 as *const f32,
state_s.slice(state_offset..).device_ptr(state_s.stream()).0 as *mut f32,
output_buf.device_ptr(output_buf.stream()).0 as *mut f32,
bh as i32,
seq_len as i32,
k_dim as i32,
v_dim as i32,
stream,
);
}
// The kernel updated `state` in place through the raw pointer, so the
// state tensor's underlying CudaSlice already holds the new values.
// Only the freshly allocated output buffer needs wrapping in a Tensor.
let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone());
Ok(Tensor::from((
candle::Storage::Cuda(output_storage),
(bh, seq_len, v_dim),
)))
}
#[cfg(not(feature = "cuda"))]
#[allow(unused)]
pub fn gated_delta_rule_recurrence_cuda(
_q: &Tensor,
_k: &Tensor,
_v: &Tensor,
_g: &Tensor,
_beta: &Tensor,
_state: &mut Tensor,
) -> Result<Tensor> {
candle_core::bail!("gated_delta_rule_recurrence_cuda requires the cuda feature")
}
/// CUDA-accelerated chunked gated delta rule recurrence (prefill optimization).
///
/// Processes prefill tokens in 64-token chunks instead of one at a time.
/// Same interface as `gated_delta_rule_recurrence_cuda`.
///
/// Inputs (all contiguous, f32):
/// q, k: [BH, S, K] v: [BH, S, V] g, beta: [BH, S]
/// state: [BH, K, V] (mutated in place)
///
/// Returns: output [BH, S, V]
#[cfg(feature = "cuda")]
pub fn chunked_gated_delta_rule_recurrence_cuda(
q: &Tensor,
k: &Tensor,
v: &Tensor,
g: &Tensor,
beta: &Tensor,
state: &mut Tensor,
) -> Result<Tensor> {
use candle::cuda_backend::cudarc::driver::DevicePtr;
use candle_core as candle;
let (bh, seq_len, k_dim) = q.dims3()?;
let v_dim = v.dim(2)?;
let dev = q.device().as_cuda_device()?;
let (q_s, q_l) = q.storage_and_layout();
let q_s = match &*q_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("q must be a cuda tensor"),
};
let q_offset = q_l.start_offset();
let (k_s, k_l) = k.storage_and_layout();
let k_s = match &*k_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("k must be a cuda tensor"),
};
let k_offset = k_l.start_offset();
let (v_s, v_l) = v.storage_and_layout();
let v_s = match &*v_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("v must be a cuda tensor"),
};
let v_offset = v_l.start_offset();
let (g_s, g_l) = g.storage_and_layout();
let g_s = match &*g_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("g must be a cuda tensor"),
};
let g_offset = g_l.start_offset();
let (beta_s, beta_l) = beta.storage_and_layout();
let beta_s = match &*beta_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("beta must be a cuda tensor"),
};
let beta_offset = beta_l.start_offset();
let (state_s, state_l) = state.storage_and_layout();
let state_s = match &*state_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("state must be a cuda tensor"),
};
let state_offset = state_l.start_offset();
let output_buf = unsafe { dev.alloc::<f32>(bh * seq_len * v_dim) }?;
let stream = dev.cuda_stream().cu_stream() as i64;
unsafe {
crate::cuda::ffi::chunked_gated_delta_rule_recurrence(
q_s.slice(q_offset..).device_ptr(q_s.stream()).0 as *const f32,
k_s.slice(k_offset..).device_ptr(k_s.stream()).0 as *const f32,
v_s.slice(v_offset..).device_ptr(v_s.stream()).0 as *const f32,
g_s.slice(g_offset..).device_ptr(g_s.stream()).0 as *const f32,
beta_s.slice(beta_offset..).device_ptr(beta_s.stream()).0 as *const f32,
state_s.slice(state_offset..).device_ptr(state_s.stream()).0 as *mut f32,
output_buf.device_ptr(output_buf.stream()).0 as *mut f32,
bh as i32,
seq_len as i32,
k_dim as i32,
v_dim as i32,
stream,
);
}
let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone());
Ok(Tensor::from((
candle::Storage::Cuda(output_storage),
(bh, seq_len, v_dim),
)))
}
#[cfg(not(feature = "cuda"))]
#[allow(unused)]
pub fn chunked_gated_delta_rule_recurrence_cuda(
_q: &Tensor,
_k: &Tensor,
_v: &Tensor,
_g: &Tensor,
_beta: &Tensor,
_state: &mut Tensor,
) -> Result<Tensor> {
candle_core::bail!("chunked_gated_delta_rule_recurrence_cuda requires the cuda feature")
}
/// CUDA-accelerated causal conv1d (both update and full paths).
///
/// For update (is_update=true):
/// x: [B, conv_dim, 1] weight: [conv_dim, kernel_size]
/// conv_state: [B, conv_dim, kernel_size] (mutated in place for update)
/// Returns: (output [B, conv_dim, 1], updated conv_state)
///
/// For full (is_update=false):
/// x: [B, conv_dim, S] weight: [conv_dim, kernel_size]
/// Returns: (output [B, conv_dim, S], new conv_state [B, conv_dim, kernel_size])
#[cfg(feature = "cuda")]
pub fn causal_conv1d_cuda(
x: &Tensor,
weight: &Tensor,
conv_state: &Tensor,
kernel_size: usize,
is_update: bool,
) -> Result<(Tensor, Tensor)> {
use candle::cuda_backend::cudarc::driver::DevicePtr;
use candle_core as candle;
use core::ffi::c_void;
fn cuda_fwd<
T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
>(
x: &Tensor,
weight: &Tensor,
conv_state: &Tensor,
kernel_size: usize,
is_update: bool,
dtype_code: i32,
) -> Result<(Tensor, Tensor)> {
let dev = x.device().as_cuda_device()?;
let (batch_size, conv_dim, seq_len) = x.dims3()?;
let (x_s, x_l) = x.storage_and_layout();
let x_s = match &*x_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
_ => candle::bail!("x must be a cuda tensor"),
};
let x_offset = x_l.start_offset();
let (w_s, w_l) = weight.storage_and_layout();
let w_s = match &*w_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
_ => candle::bail!("weight must be a cuda tensor"),
};
let w_offset = w_l.start_offset();
let stream = dev.cuda_stream().cu_stream() as i64;
if is_update {
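// `Tensor::clone` is a shallow handle clone that shares storage with
// `conv_state`, so the kernel's in-place update is visible through
// both handles; `conv_state_new` is the handle we return.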
let conv_state_new = conv_state.clone();
let output_buf = unsafe { dev.alloc::<T>(batch_size * conv_dim) }?;
// Scope the borrow of conv_state_new so we can move it later
{
let (cs_s, cs_l) = conv_state_new.storage_and_layout();
let cs_s = match &*cs_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
_ => candle::bail!("conv_state must be a cuda tensor"),
};
let cs_offset = cs_l.start_offset();
unsafe {
crate::cuda::ffi::causal_conv1d_update(
x_s.slice(x_offset..).device_ptr(x_s.stream()).0 as *const c_void,
w_s.slice(w_offset..).device_ptr(w_s.stream()).0 as *const c_void,
cs_s.slice(cs_offset..).device_ptr(cs_s.stream()).0 as *mut c_void,
output_buf.device_ptr(output_buf.stream()).0 as *mut c_void,
batch_size as i32,
conv_dim as i32,
kernel_size as i32,
dtype_code,
stream,
);
}
}
let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone());
let output = Tensor::from((
candle::Storage::Cuda(output_storage),
(batch_size, conv_dim, 1usize),
));
Ok((output, conv_state_new))
} else {
// Full path: allocate new conv_state and output
let output_buf = unsafe { dev.alloc::<T>(batch_size * conv_dim * seq_len) }?;
let cs_buf = unsafe { dev.alloc::<T>(batch_size * conv_dim * kernel_size) }?;
unsafe {
crate::cuda::ffi::causal_conv1d_full(
x_s.slice(x_offset..).device_ptr(x_s.stream()).0 as *const c_void,
w_s.slice(w_offset..).device_ptr(w_s.stream()).0 as *const c_void,
cs_buf.device_ptr(cs_buf.stream()).0 as *mut c_void,
output_buf.device_ptr(output_buf.stream()).0 as *mut c_void,
batch_size as i32,
conv_dim as i32,
seq_len as i32,
kernel_size as i32,
dtype_code,
stream,
);
}
let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone());
let output = Tensor::from((
candle::Storage::Cuda(output_storage),
(batch_size, conv_dim, seq_len),
));
let cs_storage = candle::CudaStorage::wrap_cuda_slice(cs_buf, dev.clone());
let new_conv_state = Tensor::from((
candle::Storage::Cuda(cs_storage),
(batch_size, conv_dim, kernel_size),
));
Ok((output, new_conv_state))
}
}
match x.dtype() {
DType::F16 => cuda_fwd::<half::f16>(x, weight, conv_state, kernel_size, is_update, 0),
DType::BF16 => cuda_fwd::<half::bf16>(x, weight, conv_state, kernel_size, is_update, 1),
other => candle_core::bail!("causal_conv1d_cuda only supports f16/bf16, got {:?}", other),
}
}
#[cfg(not(feature = "cuda"))]
#[allow(unused)]
pub fn causal_conv1d_cuda(
_x: &Tensor,
_weight: &Tensor,
_conv_state: &Tensor,
_kernel_size: usize,
_is_update: bool,
) -> Result<(Tensor, Tensor)> {
candle_core::bail!("causal_conv1d_cuda requires the cuda feature")
}
/// CUDA-accelerated fused GDN gating computation.
///
/// Computes: beta = sigmoid(b), g = -exp(a_log) * softplus(a + dt_bias)
///
/// b, a: [..., num_heads] (`total_elements` values, head innermost) in f16/bf16
/// a_log, dt_bias: [num_heads] in f32
///
/// Returns: (beta, g) in original dtype
#[cfg(feature = "cuda")]
pub fn fused_gdn_gating_cuda(
b: &Tensor,
a: &Tensor,
a_log: &Tensor,
dt_bias: &Tensor,
) -> Result<(Tensor, Tensor)> {
use candle::cuda_backend::cudarc::driver::DevicePtr;
use candle_core as candle;
use core::ffi::c_void;
fn cuda_fwd<
T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
>(
b: &Tensor,
a: &Tensor,
a_log: &Tensor,
dt_bias: &Tensor,
dtype_code: i32,
) -> Result<(Tensor, Tensor)> {
let total_elements = b.elem_count();
let num_heads = a_log.elem_count();
let shape = b.shape().clone();
let dev = b.device().as_cuda_device()?;
let (b_s, b_l) = b.storage_and_layout();
let b_s = match &*b_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
_ => candle::bail!("b must be a cuda tensor"),
};
let b_offset = b_l.start_offset();
let (a_s, a_l) = a.storage_and_layout();
let a_s = match &*a_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
_ => candle::bail!("a must be a cuda tensor"),
};
let a_offset = a_l.start_offset();
let (alog_s, alog_l) = a_log.storage_and_layout();
let alog_s = match &*alog_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("a_log must be a cuda tensor"),
};
let alog_offset = alog_l.start_offset();
let (dtb_s, dtb_l) = dt_bias.storage_and_layout();
let dtb_s = match &*dtb_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("dt_bias must be a cuda tensor"),
};
let dtb_offset = dtb_l.start_offset();
let beta_buf = unsafe { dev.alloc::<T>(total_elements) }?;
let g_buf = unsafe { dev.alloc::<T>(total_elements) }?;
let stream = dev.cuda_stream().cu_stream() as i64;
unsafe {
crate::cuda::ffi::fused_gdn_gating(
b_s.slice(b_offset..).device_ptr(b_s.stream()).0 as *const c_void,
a_s.slice(a_offset..).device_ptr(a_s.stream()).0 as *const c_void,
alog_s.slice(alog_offset..).device_ptr(alog_s.stream()).0 as *const f32,
dtb_s.slice(dtb_offset..).device_ptr(dtb_s.stream()).0 as *const f32,
beta_buf.device_ptr(beta_buf.stream()).0 as *mut c_void,
g_buf.device_ptr(g_buf.stream()).0 as *mut c_void,
total_elements as i32,
num_heads as i32,
dtype_code,
stream,
);
}
let beta_storage = candle::CudaStorage::wrap_cuda_slice(beta_buf, dev.clone());
let beta = Tensor::from((candle::Storage::Cuda(beta_storage), shape.clone()));
let g_storage = candle::CudaStorage::wrap_cuda_slice(g_buf, dev.clone());
let g = Tensor::from((candle::Storage::Cuda(g_storage), shape));
Ok((beta, g))
}
match b.dtype() {
DType::F16 => cuda_fwd::<half::f16>(b, a, a_log, dt_bias, 0),
DType::BF16 => cuda_fwd::<half::bf16>(b, a, a_log, dt_bias, 1),
other => candle_core::bail!(
"fused_gdn_gating_cuda only supports f16/bf16, got {:?}",
other
),
}
}
#[cfg(not(feature = "cuda"))]
#[allow(unused)]
pub fn fused_gdn_gating_cuda(
_b: &Tensor,
_a: &Tensor,
_a_log: &Tensor,
_dt_bias: &Tensor,
) -> Result<(Tensor, Tensor)> {
candle_core::bail!("fused_gdn_gating_cuda requires the cuda feature")
}
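
The update and full paths of `causal_conv1d_cuda` are expected to agree: feeding a sequence one token at a time through the update path should reproduce the full path's outputs and final `conv_state`. The following is a rough CPU sketch for a single channel in f32, mirroring the kernel comments in `gdn.cu` (shift-and-insert for update, zero-padded causal window for full). The function names are hypothetical and not part of this change.

```rust
fn silu(x: f32) -> f32 {
    x / (1.0 + (-x).exp())
}

/// Prefill path for one channel: `x` has S samples, `w` has K taps.
/// Returns (outputs [S], conv_state [K]), zero-padded on the left.
fn conv1d_full(x: &[f32], w: &[f32]) -> (Vec<f32>, Vec<f32>) {
    let (s, k) = (x.len(), w.len());
    assert!(k > 0, "kernel_size must be positive");
    let mut out = Vec::with_capacity(s);
    for pos in 0..s {
        let mut acc = 0.0f32;
        for i in 0..k {
            // Source index is pos - (k - 1) + i; negative indices read as zero.
            if pos + i >= k - 1 {
                acc += x[pos + i - (k - 1)] * w[i];
            }
        }
        out.push(silu(acc));
    }
    // State is the last K inputs, zero-padded on the left when S < K.
    let mut state = vec![0.0f32; k];
    let n = s.min(k);
    state[k - n..].copy_from_slice(&x[s - n..]);
    (out, state)
}

/// Decode path for one channel: shift the state left by one,
/// append the new sample, take the dot product with `w`, apply SiLU.
fn conv1d_update(state: &mut [f32], w: &[f32], x_new: f32) -> f32 {
    let k = state.len();
    assert!(k > 0 && k == w.len());
    state.copy_within(1.., 0);
    state[k - 1] = x_new;
    let acc: f32 = state.iter().zip(w).map(|(s, w)| s * w).sum();
    silu(acc)
}
```

Starting `conv1d_update` from an all-zero state and streaming the same samples is the equivalence a port test could assert, since a zero state plays the role of the zero padding in the full path.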


@@ -0,0 +1,15 @@
//! CUDA kernels and their Rust wrappers.
//!
//! Currently scoped to what we need for Qwen3-Next (`qwen3_5`)
//! inference performance — the Gated DeltaNet kernels ported from
//! `EricLBuehler/mistral.rs` (MIT). Each kernel lives in a `.cu`
//! file alongside this module; `build.rs` compiles them all into a
//! static lib via `cudaforge` and links it under the `cuda` feature.
//!
//! When we absorb more upstream kernels (MoE GEMM, top-k, Mamba SSM,
//! etc.) they land here in their own `.cu` + `.rs` pairs.
#[cfg(feature = "cuda")]
pub mod ffi;
#[cfg(feature = "cuda")]
pub mod gdn;
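
For orientation, the recurrence that the `gdn` wrappers accelerate can be written on the CPU in a few lines. The following is a sketch under stated assumptions: it follows the per-token update listed in the `linear_attn` module docs of this change (decay, read, delta, write, query), handles a single (batch, head) slice in f32, and leaves out any query scaling or L2 normalisation that callers may apply beforehand. The name `gated_delta_rule_ref` is hypothetical.

```rust
/// One (batch, head) slice of the gated delta rule, token by token.
/// q, k: [S][K]   v: [S][V]   g, beta: [S]
/// state: [K][V], updated in place. Returns the output [S][V].
fn gated_delta_rule_ref(
    q: &[Vec<f32>],
    k: &[Vec<f32>],
    v: &[Vec<f32>],
    g: &[f32],
    beta: &[f32],
    state: &mut [Vec<f32>],
) -> Vec<Vec<f32>> {
    let k_dim = state.len();
    assert!(k_dim > 0, "state must have at least one row");
    let v_dim = state[0].len();
    let mut out = Vec::with_capacity(q.len());

    for t in 0..q.len() {
        // state *= exp(g_t)
        let decay = g[t].exp();
        for row in state.iter_mut() {
            for s in row.iter_mut() {
                *s *= decay;
            }
        }
        // mem = state^T k_t ; delta = (v_t - mem) * beta_t
        let mut delta = vec![0.0f32; v_dim];
        for j in 0..v_dim {
            let mem = (0..k_dim).map(|i| state[i][j] * k[t][i]).sum::<f32>();
            delta[j] = (v[t][j] - mem) * beta[t];
        }
        // state += outer(k_t, delta)
        for i in 0..k_dim {
            for j in 0..v_dim {
                state[i][j] += k[t][i] * delta[j];
            }
        }
        // out_t = state^T q_t
        let out_t = (0..v_dim)
            .map(|j| (0..k_dim).map(|i| state[i][j] * q[t][i]).sum::<f32>())
            .collect::<Vec<f32>>();
        out.push(out_t);
    }
    out
}
```

With `beta = 1`, `g = 0` and a unit-norm key, writing a value and then querying with the same key returns that value, and a second write under the same key overwrites the first. That property makes a convenient sanity check for any faster implementation.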


@@ -0,0 +1,23 @@
//! Custom architecture implementations.
//!
//! When candle-transformers ships a model family unchanged
//! (`models::llama`, `models::qwen3`, `models::qwen3_moe`, etc.), the
//! handler in `harness/candle.rs` just wraps the upstream type in a
//! `ModelArch` variant.
//!
//! When candle has nothing for the architecture and we have to write
//! it from scratch — Qwen3-Next / Qwen3.6 (`qwen3_5`) being the
//! motivating example — the implementation lands here, one file per
//! architecture.
//!
//! Each architecture module is expected to expose:
//! - A `Config` type deserialised from the model's `config.json`
//! (some architectures nest the real hyperparams under `text_config`,
//! in which case the module owns the unwrapping).
//! - A `ForCausalLM` struct with `new`, `forward(&mut self, x, offset)
//! -> Result<Tensor>`, and `clear_kv_cache(&mut self)`.
//!
//! TP-aware analogues live in `harness/tp/tp_<family>.rs` and follow
//! the pattern set by `tp_qwen3.rs`.
pub mod qwen3_5;


@@ -0,0 +1,117 @@
//! Qwen3-Next decoder layer.
//!
//! Standard pre-norm transformer block (RMSNorm → attention →
//! residual → RMSNorm → MLP → residual) where the attention slot
//! dispatches on the per-layer `layer_types[i]` value in the config:
//!
//! - `"full_attention"` → [`Qwen3_5Attention`] (GQA causal + output
//! gate + RoPE + KV cache).
//! - `"linear_attention"` → [`GatedDeltaNet`] (recurrent delta rule +
//! causal conv + per-head state).
//!
//! In Qwen3.6-27B every 4th layer is full_attention; the rest are
//! linear_attention. `full_attention_interval` in the config is a
//! hint; `layer_types` is authoritative.
use anyhow::Result;
use candle_core::{Module, Tensor};
use candle_nn::var_builder::ShardedVarBuilder;
use std::sync::Arc;
use super::TextConfig;
use super::full_attn::Qwen3_5Attention;
use super::linear_attn::GatedDeltaNet;
use super::mlp::Qwen3_5MLP;
use super::rmsnorm::Qwen3_5RmsNorm;
use super::rope::RotaryEmbedding;
/// One of the two attention flavours sitting in a decoder layer's
/// attention slot. Full-attention layers need the rotary table and
/// take an attention mask; linear-attention layers carry their own
/// recurrent state and ignore the mask.
enum AttentionKind {
Full(Qwen3_5Attention),
Linear(GatedDeltaNet),
}
pub struct Qwen3_5DecoderLayer {
input_layernorm: Qwen3_5RmsNorm,
post_attention_layernorm: Qwen3_5RmsNorm,
mlp: Qwen3_5MLP,
attention: AttentionKind,
}
impl Qwen3_5DecoderLayer {
pub fn load(
cfg: &TextConfig,
rotary: Arc<RotaryEmbedding>,
layer_idx: usize,
vb: &ShardedVarBuilder,
) -> Result<Self> {
let layer_type = cfg
.layer_types
.get(layer_idx)
.map(String::as_str)
.ok_or_else(|| {
anyhow::anyhow!(
"layer_types[{layer_idx}] missing (have {} entries)",
cfg.layer_types.len()
)
})?;
let attention = match layer_type {
"full_attention" => {
AttentionKind::Full(Qwen3_5Attention::load(cfg, rotary, &vb.pp("self_attn"))?)
}
"linear_attention" => {
AttentionKind::Linear(GatedDeltaNet::load(cfg, &vb.pp("linear_attn"))?)
}
other => anyhow::bail!(
"unknown layer_type '{other}' for layer {layer_idx} (expected \
'full_attention' or 'linear_attention')"
),
};
let mlp = Qwen3_5MLP::load(cfg, &vb.pp("mlp"))?;
let input_layernorm =
Qwen3_5RmsNorm::load(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?;
let post_attention_layernorm = Qwen3_5RmsNorm::load(
&vb.pp("post_attention_layernorm"),
cfg.hidden_size,
cfg.rms_norm_eps,
)?;
Ok(Self {
input_layernorm,
post_attention_layernorm,
mlp,
attention,
})
}
pub fn forward(
&mut self,
x: &Tensor,
attn_mask: Option<&Tensor>,
offset: usize,
) -> candle_core::Result<Tensor> {
let h = self.input_layernorm.forward(x)?;
let attn_out = match &mut self.attention {
AttentionKind::Full(attn) => attn.forward(&h, attn_mask, offset)?,
// Linear attention ignores attn_mask + offset; its causal
// structure is baked into the recurrent state lifecycle.
AttentionKind::Linear(net) => net.forward(&h)?,
};
let x = (x + attn_out)?;
let h2 = self.post_attention_layernorm.forward(&x)?;
let h2 = self.mlp.forward(&h2)?;
x + h2
}
pub fn clear_kv_cache(&mut self) {
match &mut self.attention {
AttentionKind::Full(attn) => attn.clear_kv_cache(),
AttentionKind::Linear(net) => net.clear_kv_cache(),
}
}
}
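
The dispatch in `Qwen3_5DecoderLayer::load` treats `layer_types` as authoritative and `full_attention_interval` as a hint. The following sketch shows that relationship without any candle types. `LayerKind`, `layer_kind` and `layer_types_from_interval` are hypothetical names, and the rule that layers `interval - 1`, `2 * interval - 1`, ... are full attention is an assumption about how the hint is meant to be read.

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LayerKind {
    Full,
    Linear,
}

/// Resolve the attention flavour for one layer from the authoritative list.
fn layer_kind(layer_types: &[String], layer_idx: usize) -> Result<LayerKind, String> {
    let name = layer_types.get(layer_idx).ok_or_else(|| {
        format!(
            "layer_types[{layer_idx}] missing (have {} entries)",
            layer_types.len()
        )
    })?;
    match name.as_str() {
        "full_attention" => Ok(LayerKind::Full),
        "linear_attention" => Ok(LayerKind::Linear),
        other => Err(format!(
            "unknown layer_type '{other}' for layer {layer_idx}"
        )),
    }
}

/// Hypothetical helper: the list an interval hint would imply, assuming
/// every `interval`-th layer (counting from 1) is full attention.
fn layer_types_from_interval(num_layers: usize, interval: usize) -> Vec<String> {
    assert!(interval > 0, "interval must be positive");
    (0..num_layers)
        .map(|i| {
            if (i + 1) % interval == 0 {
                "full_attention"
            } else {
                "linear_attention"
            }
        })
        .map(String::from)
        .collect()
}
```

A loader could compare the list from the config against `layer_types_from_interval` and log a warning on mismatch, while still using the list from the config for the actual dispatch.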


@@ -0,0 +1,179 @@
//! Qwen3-Next's `full_attention` layer.
//!
//! Standard GQA causal attention with two Qwen3-Next-specific quirks:
//!
//! 1. **Output gate (`attn_output_gate=True`).** `q_proj` is widened
//! to `num_heads * head_dim * 2`. Within each head, the first
//! `head_dim` columns are the query and the second `head_dim`
//! columns are the gate. The gate is flattened to
//! `(B, L, num_heads * head_dim)` and fed through a sigmoid; the
//! attention output is pointwise-multiplied by it before `o_proj`,
//! giving a per-channel, per-position attenuation of the attention
//! output.
//!
//! 2. **`(1 + w) * x` RmsNorm** on q and k (see `rmsnorm::Qwen3_5RmsNorm`).
//! candle_nn's RmsNorm applies `w * x`; the upstream Qwen3-Next
//! checkpoints expect the `(1 + w)` form.
//!
//! Otherwise: GQA with `num_attention_heads / num_key_value_heads`
//! repeat, q_norm + k_norm on the head dim, GLM-style rotary (see
//! `rope::RotaryEmbedding`), and the usual causal mask.
use anyhow::{Context, Result};
use candle_core::{Module, Tensor};
use candle_nn::Linear;
use candle_nn::kv_cache::ConcatKvCache;
use candle_nn::var_builder::ShardedVarBuilder;
use candle_transformers::utils::repeat_kv;
use std::sync::Arc;
use super::TextConfig;
use super::rmsnorm::Qwen3_5RmsNorm;
use super::rope::RotaryEmbedding;
pub struct Qwen3_5Attention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
q_norm: Qwen3_5RmsNorm,
k_norm: Qwen3_5RmsNorm,
num_heads: usize,
num_kv_heads: usize,
num_kv_groups: usize,
head_dim: usize,
hidden_size: usize,
rotary: Arc<RotaryEmbedding>,
kv_cache: ConcatKvCache,
}
impl Qwen3_5Attention {
pub fn load(
cfg: &TextConfig,
rotary: Arc<RotaryEmbedding>,
vb: &ShardedVarBuilder,
) -> Result<Self> {
let head_dim = cfg.head_dim;
let num_heads = cfg.num_attention_heads;
let num_kv_heads = cfg.num_key_value_heads;
if num_kv_heads == 0 || !num_heads.is_multiple_of(num_kv_heads) {
anyhow::bail!(
"num_attention_heads ({num_heads}) must be a positive multiple of \
num_key_value_heads ({num_kv_heads})"
);
}
let num_kv_groups = num_heads / num_kv_heads;
// q_proj is 2x wide: the extra `num_heads * head_dim` slice is
// the gate (see attn_output_gate notes above).
let q_proj = load_linear_no_bias(vb, "q_proj", cfg.hidden_size, num_heads * head_dim * 2)?;
let k_proj = load_linear_no_bias(vb, "k_proj", cfg.hidden_size, num_kv_heads * head_dim)?;
let v_proj = load_linear_no_bias(vb, "v_proj", cfg.hidden_size, num_kv_heads * head_dim)?;
let o_proj = load_linear_no_bias(vb, "o_proj", num_heads * head_dim, cfg.hidden_size)?;
let q_norm = Qwen3_5RmsNorm::load(&vb.pp("q_norm"), head_dim, cfg.rms_norm_eps)?;
let k_norm = Qwen3_5RmsNorm::load(&vb.pp("k_norm"), head_dim, cfg.rms_norm_eps)?;
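// Width of the attention output fed to `o_proj` (num_heads * head_dim).
// Despite the name, this need not equal `cfg.hidden_size`.
let hidden_size = head_dim * num_heads;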
let kv_cache = ConcatKvCache::new(2);
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
q_norm,
k_norm,
num_heads,
num_kv_heads,
num_kv_groups,
head_dim,
hidden_size,
rotary,
kv_cache,
})
}
pub fn forward(
&mut self,
x: &Tensor,
attn_mask: Option<&Tensor>,
offset: usize,
) -> candle_core::Result<Tensor> {
let (b, l, _) = x.dims3()?;
// 1. q_proj — widened output, split into (query, gate).
let q_raw = self
.q_proj
.forward(x)?
.reshape((b, l, self.num_heads, self.head_dim * 2))?;
let q = q_raw.narrow(3, 0, self.head_dim)?;
let gate = q_raw.narrow(3, self.head_dim, self.head_dim)?;
// Flatten the gate's head dim back into hidden_size for the
// post-attention pointwise multiply.
let gate = gate
.contiguous()?
.reshape((b, l, self.num_heads * self.head_dim))?;
// 2. q_norm + k_norm + reshape to (B, H, L, D).
let q = self.q_norm.forward(&q.contiguous()?)?;
let q = q.transpose(1, 2)?.contiguous()?; // (B, H, L, D)
let k = self
.k_proj
.forward(x)?
.reshape((b, l, self.num_kv_heads, self.head_dim))?;
let k = self.k_norm.forward(&k.contiguous()?)?;
let k = k.transpose(1, 2)?.contiguous()?;
let v = self
.v_proj
.forward(x)?
.reshape((b, l, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
// 3. RoPE on q, k.
let (q, k) = self.rotary.apply(&q, &k, offset)?;
// 4. KV cache.
let (k, v) = self.kv_cache.append(&k, &v)?;
// 5. GQA repeat: tile each KV head `num_kv_groups` times so K and V
// match the query head count.
let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;
// 6. Scaled dot-product + causal mask.
let scale = 1.0_f64 / (self.head_dim as f64).sqrt();
let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
if let Some(m) = attn_mask {
scores = scores.broadcast_add(m)?;
}
let probs = candle_nn::ops::softmax_last_dim(&scores)?;
let ctx = probs.matmul(&v)?; // (B, H, L, D)
// 7. Reshape back, apply the output gate, project.
let ctx = ctx
.transpose(1, 2)?
.contiguous()?
.reshape((b, l, self.hidden_size))?;
let gate_sig = candle_nn::ops::sigmoid(&gate)?;
let gated = (ctx * gate_sig)?;
self.o_proj.forward(&gated)
}
pub fn clear_kv_cache(&mut self) {
self.kv_cache.reset();
}
}
fn load_linear_no_bias(
vb: &ShardedVarBuilder,
name: &str,
in_dim: usize,
out_dim: usize,
) -> Result<Linear> {
let weight = vb
.pp(name)
.get((out_dim, in_dim), "weight")
.with_context(|| format!("load '{}/{name}/weight'", vb.prefix()))?;
Ok(Linear::new(weight, None))
}
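
The query/gate split in `Qwen3_5Attention::forward` is interleaved per head rather than a split of the whole row in half, which is easy to get wrong when porting. The following is a small sketch on plain slices for a single token, assuming the per-head `[query | gate]` layout implied by the reshape to `(b, l, num_heads, head_dim * 2)`. The function names are hypothetical.

```rust
fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

/// Split one token's widened q_proj row, laid out as
/// [head0: query(D) | gate(D)], [head1: query(D) | gate(D)], ...
/// into (query [H * D], gate [H * D]).
fn split_query_gate(row: &[f32], num_heads: usize, head_dim: usize) -> (Vec<f32>, Vec<f32>) {
    assert!(head_dim > 0, "head_dim must be positive");
    assert_eq!(row.len(), num_heads * head_dim * 2);
    let mut query = Vec::with_capacity(num_heads * head_dim);
    let mut gate = Vec::with_capacity(num_heads * head_dim);
    for head in row.chunks_exact(head_dim * 2) {
        query.extend_from_slice(&head[..head_dim]);
        gate.extend_from_slice(&head[head_dim..]);
    }
    (query, gate)
}

/// Elementwise output gate: ctx[i] * sigmoid(gate[i]).
fn apply_output_gate(ctx: &[f32], gate: &[f32]) -> Vec<f32> {
    assert_eq!(ctx.len(), gate.len());
    ctx.iter().zip(gate).map(|(c, g)| c * sigmoid(*g)).collect()
}
```

A gate value of 0 halves the corresponding attention output channel, while a large positive value passes it through almost unchanged.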


@@ -0,0 +1,793 @@
//! Qwen3-Next's `linear_attention` layer: Gated DeltaNet.
//!
//! The recurrent linear-attention block that occupies 3 out of every 4
//! decoder layers in Qwen3.6 (`layer_types[i] == "linear_attention"`).
//! Implemented against the reference Python in
//! `huggingface/transformers/src/transformers/models/qwen3_5/modeling_qwen3_5.py`
//! (class `Qwen3_5GatedDeltaNet`).
//!
//! ## Block structure
//!
//! ```text
//! x ── in_proj_qkv ── transpose ─► (B, conv_dim, L)
//! │
//! ┌──────────────── conv_state ──┤ prepend cached state (decode)
//! ▼
//! depthwise causal Conv1d (k=4) → SiLU
//! │
//! └─ split → q (k_dim), k (k_dim), v (v_dim) ─► per-head reshape
//!
//! x ── in_proj_z ────────────────► z (gate for the output RMSNorm)
//! x ── in_proj_b ── sigmoid ─────► beta (per-head per-token update rate)
//! x ── in_proj_a ── softplus ────► g (decay; see eqn below)
//!
//! g = -exp(A_log) * softplus(a + dt_bias) # discretisation
//! beta = sigmoid(b)
//!
//! (q, k) ─── L2norm ─── delta rule loop ──── core_attn_out
//! (per-token, per-head):
//! state *= exp(g_t)
//! mem = state^T · k_t
//! delta = (v_t - mem) * beta_t
//! state += outer(k_t, delta)
//! out_t = state^T · q_t
//!
//! core_attn_out ── RMSNormGated(z) ── reshape ── out_proj ── y
//! ```
//!
//! ## State
//!
//! Two tensors persist across decode steps:
//! - `conv_state`: `(B, conv_dim, conv_kernel_size)` — left-padded
//! tail of the input to the depthwise conv, so the next causal
//! window has the right left-context.
//! - `recurrent_state`: `(B, num_v_heads, head_k_dim, head_v_dim)` —
//! the delta-rule outer-product memory.
//!
//! Both are cleared via [`GatedDeltaNet::clear_kv_cache`] at the start
//! of every new request.
//!
//! ## Performance note
//!
//! The pure-Rust fallback is the **recurrent** delta rule for both
//! prefill and decode, i.e. the algorithm in
//! `torch_recurrent_gated_delta_rule`. Correctness-first. On cuda,
//! prefills of 64 tokens or more dispatch to a chunked kernel instead;
//! a CPU port of `torch_chunk_gated_delta_rule` (chunk_size=64) can be
//! added later without changing the surface.
use anyhow::{Context, Result};
use candle_core::{Module, Tensor};
use candle_nn::Linear;
use candle_nn::var_builder::ShardedVarBuilder;
#[cfg(test)]
use super::RopeParameters;
use super::TextConfig;
use super::rmsnorm::{Qwen3_5RmsNormGated, l2norm};
/// Per-rank, per-layer state for the linear-attention block.
///
/// `conv_state` is left-padded with zeros on first use; `recurrent_state`
/// is initialised lazily to zeros once we know the batch size.
#[derive(Default)]
pub struct GatedDeltaNetState {
pub conv_state: Option<Tensor>,
pub recurrent_state: Option<Tensor>,
}
pub struct GatedDeltaNet {
// Projections.
in_proj_qkv: Linear,
in_proj_z: Linear,
in_proj_b: Linear,
in_proj_a: Linear,
out_proj: Linear,
// Depthwise causal Conv1d weight; shape (conv_dim, 1, kernel_size).
// No bias (Python sets bias=False).
conv1d_weight: Tensor,
// Per-head discretisation params.
dt_bias: Tensor,
a_log: Tensor,
// Output norm + gate.
norm: Qwen3_5RmsNormGated,
// Shape hyperparams (cached for forward).
num_v_heads: usize,
num_k_heads: usize,
head_k_dim: usize,
head_v_dim: usize,
key_dim: usize,
value_dim: usize,
conv_dim: usize,
conv_kernel_size: usize,
// Recurrent state held inline. Each request resets via
// `clear_kv_cache`; otherwise the state persists across forwards
// and the per-token offset advances naturally.
state: GatedDeltaNetState,
}
impl GatedDeltaNet {
pub fn load(cfg: &TextConfig, vb: &ShardedVarBuilder) -> Result<Self> {
let num_v_heads = cfg.linear_num_value_heads;
let num_k_heads = cfg.linear_num_key_heads;
let head_k_dim = cfg.linear_key_head_dim;
let head_v_dim = cfg.linear_value_head_dim;
let conv_kernel_size = cfg.linear_conv_kernel_dim;
if num_v_heads == 0 || num_k_heads == 0 {
anyhow::bail!(
"Qwen3-Next linear_num_*_heads must be set; got v={num_v_heads}, k={num_k_heads}"
);
}
if !num_v_heads.is_multiple_of(num_k_heads) {
anyhow::bail!(
"linear_num_value_heads ({num_v_heads}) must be a multiple of \
linear_num_key_heads ({num_k_heads}) for GQA-style head expansion"
);
}
let key_dim = head_k_dim * num_k_heads;
let value_dim = head_v_dim * num_v_heads;
let conv_dim = key_dim * 2 + value_dim;
// ----- Linear projections (all `bias=False` in the reference). -----
let in_proj_qkv = load_linear_no_bias(vb, "in_proj_qkv", cfg.hidden_size, conv_dim)?;
let in_proj_z = load_linear_no_bias(vb, "in_proj_z", cfg.hidden_size, value_dim)?;
let in_proj_b = load_linear_no_bias(vb, "in_proj_b", cfg.hidden_size, num_v_heads)?;
let in_proj_a = load_linear_no_bias(vb, "in_proj_a", cfg.hidden_size, num_v_heads)?;
let out_proj = load_linear_no_bias(vb, "out_proj", value_dim, cfg.hidden_size)?;
// ----- Conv1d weight (depthwise, bias=False). -----
let conv1d_weight = vb
.pp("conv1d")
.get((conv_dim, 1, conv_kernel_size), "weight")
.with_context(|| format!("load '{}/conv1d/weight'", vb.prefix()))?;
// ----- dt_bias + A_log: per-head 1D params. -----
let dt_bias = vb
.get(num_v_heads, "dt_bias")
.with_context(|| format!("load '{}/dt_bias'", vb.prefix()))?;
let a_log = vb
.get(num_v_heads, "A_log")
.with_context(|| format!("load '{}/A_log'", vb.prefix()))?;
// ----- Output gated RMSNorm (per-head_v_dim). -----
let norm = Qwen3_5RmsNormGated::load(&vb.pp("norm"), head_v_dim, cfg.rms_norm_eps)?;
Ok(Self {
in_proj_qkv,
in_proj_z,
in_proj_b,
in_proj_a,
out_proj,
conv1d_weight,
dt_bias,
a_log,
norm,
num_v_heads,
num_k_heads,
head_k_dim,
head_v_dim,
key_dim,
value_dim,
conv_dim,
conv_kernel_size,
state: GatedDeltaNetState::default(),
})
}
pub fn clear_kv_cache(&mut self) {
self.state = GatedDeltaNetState::default();
}
/// `x` shape: `(B, L, hidden_size)`. Returns the same shape.
pub fn forward(&mut self, x: &Tensor) -> candle_core::Result<Tensor> {
let (batch_size, seq_len, _) = x.dims3()?;
let dtype = x.dtype();
let device = x.device().clone();
// ----- Projections. -----
// mixed_qkv: (B, L, conv_dim)
let mixed_qkv = self.in_proj_qkv.forward(x)?;
// (B, conv_dim, L) for the conv1d.
let mixed_qkv_chw = mixed_qkv.transpose(1, 2)?.contiguous()?;
// z: (B, L, value_dim) → (B, L, num_v_heads, head_v_dim)
let z = self.in_proj_z.forward(x)?.reshape((
batch_size,
seq_len,
self.num_v_heads,
self.head_v_dim,
))?;
// b, a: (B, L, num_v_heads)
let b = self.in_proj_b.forward(x)?;
let a = self.in_proj_a.forward(x)?;
// ----- Depthwise causal Conv1d + SiLU (with state continuation). -----
// Dispatches to a cuda kernel that fuses conv1d + silu when
// available; falls back to candle's `conv1d` + `silu` on cpu.
let (conv_out, new_state) = run_causal_conv1d(
&mixed_qkv_chw,
&self.conv1d_weight,
self.state.conv_state.take(),
batch_size,
self.conv_dim,
seq_len,
self.conv_kernel_size,
)?;
self.state.conv_state = Some(new_state);
// Back to (B, L, conv_dim).
let mixed_qkv = conv_out.transpose(1, 2)?.contiguous()?;
// ----- Split into q, k, v. -----
let q = mixed_qkv.narrow(2, 0, self.key_dim)?;
let k = mixed_qkv.narrow(2, self.key_dim, self.key_dim)?;
let v = mixed_qkv.narrow(2, 2 * self.key_dim, self.value_dim)?;
let q = q.reshape((batch_size, seq_len, self.num_k_heads, self.head_k_dim))?;
let k = k.reshape((batch_size, seq_len, self.num_k_heads, self.head_k_dim))?;
let v = v.reshape((batch_size, seq_len, self.num_v_heads, self.head_v_dim))?;
// ----- beta + g (per-head, per-token gates). -----
// Fused on cuda; per-op Rust on cpu. Both paths produce:
// beta = sigmoid(b)
// g = -exp(A_log) * softplus(a + dt_bias)
let (beta, g) = run_fused_gating(&b, &a, &self.a_log, &self.dt_bias)?;
// ----- GQA-style key expansion if num_v_heads > num_k_heads. -----
let (q, k) = if self.num_v_heads > self.num_k_heads {
let rep = self.num_v_heads / self.num_k_heads;
(
repeat_interleave(&q, rep, 2)?,
repeat_interleave(&k, rep, 2)?,
)
} else {
(q, k)
};
// ----- L2-norm on q, k (use_qk_l2norm_in_kernel=True in ref). -----
let q = l2norm(&q, 1e-6)?;
let k = l2norm(&k, 1e-6)?;
// ----- Recurrent delta rule. -----
// Inputs: q, k (B, L, H, D_k); v (B, L, H, D_v); g (B, L, H); beta (B, L, H).
// The reference transposes to (B, H, L, D) before the loop. We
// do the same — it makes per-token indexing trivial.
let q = q.transpose(1, 2)?.contiguous()?; // (B, H, L, D_k)
let k = k.transpose(1, 2)?.contiguous()?;
let v = v.transpose(1, 2)?.contiguous()?; // (B, H, L, D_v)
let g = g.transpose(1, 2)?.contiguous()?; // (B, H, L)
let beta = beta.transpose(1, 2)?.contiguous()?; // (B, H, L)
// Pre-scale q by 1/sqrt(D_k) once. Everything goes to f32 here
// since the delta rule mixes broadcast_mul ops that candle won't
// accept across mixed dtypes. On the cuda gating path both beta
// and g come back in model dtype; on the cpu path g is already
// f32 — both casts are cheap idempotent ops.
let scale = 1.0_f64 / (self.head_k_dim as f64).sqrt();
let q = (q.to_dtype(candle_core::DType::F32)? * scale)?;
let k = k.to_dtype(candle_core::DType::F32)?;
let v = v.to_dtype(candle_core::DType::F32)?;
let g = g.to_dtype(candle_core::DType::F32)?;
let beta = beta.to_dtype(candle_core::DType::F32)?;
// Initialise the recurrent state from cache or zeros.
let state_init = match self.state.recurrent_state.take() {
Some(s) => s.to_dtype(candle_core::DType::F32)?,
None => Tensor::zeros(
(
batch_size,
self.num_v_heads,
self.head_k_dim,
self.head_v_dim,
),
candle_core::DType::F32,
&device,
)?,
};
// The delta-rule body: cuda-accelerated `gated_delta_rule_recurrence`
// kernel when we have a cuda device + the kernels are linked in,
// pure-Rust per-token fallback otherwise.
let (core_attn_out, new_state) = run_delta_rule(
&q,
&k,
&v,
&g,
&beta,
state_init,
batch_size,
self.num_v_heads,
seq_len,
self.head_k_dim,
self.head_v_dim,
)?;
// Stash the updated recurrent state for the next call.
self.state.recurrent_state = Some(new_state.to_dtype(dtype)?);
// core_attn_out: (B, H, L, D_v) → (B, L, H, D_v) → (B*L*H, D_v).
let core_attn_out = core_attn_out.transpose(1, 2)?.contiguous()?; // (B, L, H, D_v)
let core_attn_out = core_attn_out.to_dtype(dtype)?;
let core_attn_flat =
core_attn_out.reshape((batch_size * seq_len * self.num_v_heads, self.head_v_dim))?;
let z_flat = z.reshape((batch_size * seq_len * self.num_v_heads, self.head_v_dim))?;
// RMSNormGated: RMS-normalise `out`, scale by `weight`, gate with silu(z).
let normed = self.norm.forward(&core_attn_flat, &z_flat)?;
let normed = normed.reshape((batch_size, seq_len, self.num_v_heads * self.head_v_dim))?;
// Output projection: (B, L, value_dim) → (B, L, hidden_size).
self.out_proj.forward(&normed)
}
}
/// Run the per-token delta-rule recurrence.
///
/// `q`, `k`: `(B, H, L, D_k)` (F32). `v`: `(B, H, L, D_v)`. `g`,
/// `beta`: `(B, H, L)`. `state`: `(B, H, D_k, D_v)`.
///
/// Returns `(core_attn_out: (B, H, L, D_v), state: (B, H, D_k, D_v))`,
/// both F32. Caller is responsible for cast back to model dtype.
///
/// Cuda path: dispatches to the `gated_delta_rule_recurrence` kernel
/// ported from `EricLBuehler/mistral.rs::mistralrs-core/src/cuda/gdn.cu`.
/// All five inputs must be cuda f32 tensors. The kernel is V-tiled
/// with compile-time BK; one block per (V-tile, batch*head) and one
/// thread per V-column. Each thread holds BK state floats in
/// registers — eliminates the launch-overhead floor we hit with
/// candle's per-op dispatch (was ~12s/token on Qwen3.6-27B).
///
/// CPU path: pure-Rust per-token loop. Correct, slow.
#[allow(clippy::too_many_arguments)]
pub(crate) fn run_delta_rule(
q: &Tensor,
k: &Tensor,
v: &Tensor,
g: &Tensor,
beta: &Tensor,
state: Tensor,
batch_size: usize,
num_heads: usize,
seq_len: usize,
head_k_dim: usize,
head_v_dim: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
#[cfg(feature = "cuda")]
{
// Only dispatch to the kernel if the inputs are on a CUDA
// device — CPU tests fall back to the Rust loop below.
if q.device().is_cuda() {
return run_delta_rule_cuda(
q, k, v, g, beta, state, batch_size, num_heads, seq_len, head_k_dim, head_v_dim,
);
}
}
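// The shape arguments are only consumed by the cuda path; discard them
// here so non-cuda builds don't emit unused-variable warnings.
let _ = (batch_size, num_heads, head_k_dim, head_v_dim);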
run_delta_rule_rust(q, k, v, g, beta, state, seq_len)
}
/// CUDA path. Flattens (B, H, ...) → (BH, ...) at the kernel boundary
/// (the kernel uses BH = batch*heads as its outer batch axis) and
/// reshapes the kernel's outputs back to (B, H, ...) for the caller.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
fn run_delta_rule_cuda(
q: &Tensor,
k: &Tensor,
v: &Tensor,
g: &Tensor,
beta: &Tensor,
state: Tensor,
batch_size: usize,
num_heads: usize,
seq_len: usize,
head_k_dim: usize,
head_v_dim: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
let q_bh = q.flatten(0, 1)?.contiguous()?;
let k_bh = k.flatten(0, 1)?.contiguous()?;
let v_bh = v.flatten(0, 1)?.contiguous()?;
let g_bh = g.flatten(0, 1)?.contiguous()?;
let beta_bh = beta.flatten(0, 1)?.contiguous()?;
let mut state_bh = state.flatten(0, 1)?.contiguous()?;
// For long prefills, the chunked kernel (BT=64) processes a chunk
// of tokens at a time instead of one-by-one — same delta-rule math,
// far fewer block launches. Threshold matches mistralrs.
const CHUNK_THRESHOLD: usize = 64;
let output_bh = if seq_len >= CHUNK_THRESHOLD {
crate::cuda::gdn::chunked_gated_delta_rule_recurrence_cuda(
&q_bh,
&k_bh,
&v_bh,
&g_bh,
&beta_bh,
&mut state_bh,
)?
} else {
crate::cuda::gdn::gated_delta_rule_recurrence_cuda(
&q_bh,
&k_bh,
&v_bh,
&g_bh,
&beta_bh,
&mut state_bh,
)?
};
let core_attn_out = output_bh.reshape((batch_size, num_heads, seq_len, head_v_dim))?;
let new_state = state_bh.reshape((batch_size, num_heads, head_k_dim, head_v_dim))?;
Ok((core_attn_out, new_state))
}
#[allow(clippy::too_many_arguments)]
fn run_delta_rule_rust(
q: &Tensor,
k: &Tensor,
v: &Tensor,
g: &Tensor,
beta: &Tensor,
mut state: Tensor,
seq_len: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
use candle_core::IndexOp;
let mut outputs: Vec<Tensor> = Vec::with_capacity(seq_len);
for t in 0..seq_len {
let q_t = q.i((.., .., t, ..))?;
let k_t = k.i((.., .., t, ..))?;
let v_t = v.i((.., .., t, ..))?;
let g_t = g.i((.., .., t))?;
let beta_t = beta.i((.., .., t))?;
let decay = g_t
.exp()?
.unsqueeze(candle_core::D::Minus1)?
.unsqueeze(candle_core::D::Minus1)?;
state = state.broadcast_mul(&decay)?;
let k_col = k_t.unsqueeze(candle_core::D::Minus1)?;
let kv_mem = state.broadcast_mul(&k_col)?.sum(2)?;
let beta_col = beta_t.unsqueeze(candle_core::D::Minus1)?;
let delta = (v_t - kv_mem)?.broadcast_mul(&beta_col)?;
let delta_row = delta.unsqueeze(2)?;
let outer = k_col.broadcast_mul(&delta_row)?;
state = (state + outer)?;
let q_col = q_t.unsqueeze(candle_core::D::Minus1)?;
let out_t = state.broadcast_mul(&q_col)?.sum(2)?;
outputs.push(out_t.unsqueeze(2)?);
}
let core_attn_out = Tensor::cat(&outputs, 2)?; // (B, H, L, D_v)
Ok((core_attn_out, state))
}
/// Depthwise causal conv1d + SiLU, with rolling `conv_state`.
///
/// `x`: `(B, conv_dim, L)` model dtype (f16/bf16 on cuda, anything on cpu).
/// `weight`: `(conv_dim, 1, kernel_size)` model dtype.
/// `conv_state`: `Some((B, conv_dim, kernel_size))` for decode continuation,
/// or `None` for fresh prefill.
///
/// Returns `(conv_out: (B, conv_dim, L), new_conv_state: (B, conv_dim, kernel_size))`.
/// SiLU is baked in.
///
/// Cuda path: dispatches to `causal_conv1d_update` (decode, seq_len=1 with
/// existing state) or `causal_conv1d_full` (prefill / first call), both
/// ported from mistralrs `gdn.cu`. Each kernel fuses the depthwise conv
/// and SiLU activation in one launch — that's ~4× fewer cuda launches per
/// linear-attention layer than the candle `conv1d` + `silu` combo.
///
/// CPU path: the original prepend-narrow-conv1d-silu sequence.
pub(crate) fn run_causal_conv1d(
x: &Tensor,
weight: &Tensor,
conv_state: Option<Tensor>,
batch_size: usize,
conv_dim: usize,
seq_len: usize,
conv_kernel_size: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
#[cfg(feature = "cuda")]
{
if x.device().is_cuda() {
return run_causal_conv1d_cuda(
x,
weight,
conv_state,
batch_size,
conv_dim,
seq_len,
conv_kernel_size,
);
}
}
run_causal_conv1d_rust(
x,
weight,
conv_state,
batch_size,
conv_dim,
seq_len,
conv_kernel_size,
)
}
#[cfg(feature = "cuda")]
fn run_causal_conv1d_cuda(
x: &Tensor,
weight: &Tensor,
conv_state: Option<Tensor>,
batch_size: usize,
conv_dim: usize,
seq_len: usize,
conv_kernel_size: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
// Kernel expects weight as (conv_dim, kernel_size) — squeeze the
// depthwise channel-multiplier dim.
let w = weight.squeeze(1)?.to_dtype(x.dtype())?.contiguous()?;
// Decode path: seq_len == 1 AND we have an existing conv_state.
// Otherwise (prefill or fresh-start decode), use the full path which
// zero-pads on the left internally.
if let Some(cs) = conv_state
&& seq_len == 1
{
let cs = cs.contiguous()?;
let (output, new_conv_state) =
crate::cuda::gdn::causal_conv1d_cuda(x, &w, &cs, conv_kernel_size, true)?;
return Ok((output, new_conv_state));
}
// Prefill / fresh-start: the kernel ignores any prior conv_state and
// zero-pads on the left. A non-zero prior state combined with more
// than one input token (multi-turn continuation) is not handled here
// and would need the Rust fallback. This matches mistralrs's
// behaviour: always a fresh prefill.
let device = x.device().clone();
let zeros_cs = Tensor::zeros((batch_size, conv_dim, conv_kernel_size), x.dtype(), &device)?;
let (output, new_conv_state) =
crate::cuda::gdn::causal_conv1d_cuda(x, &w, &zeros_cs, conv_kernel_size, false)?;
Ok((output, new_conv_state))
}
/// Fused GDN gating: computes `beta = sigmoid(b)` and
/// `g = -exp(a_log) * softplus(a + dt_bias)` together.
///
/// `b`, `a`: `(B, L, num_heads)` model dtype.
/// `a_log`, `dt_bias`: `(num_heads,)` model dtype (cast to f32 internally).
///
/// Returns `(beta, g)`. On the cuda path both come back in model dtype;
/// on the cpu fallback `beta` stays in model dtype and `g` is f32. The
/// caller casts both to f32 before the delta rule.
///
/// Cuda path: dispatches to `fused_gdn_gating_cuda` — one kernel
/// replaces sigmoid + neg(exp) + softplus + broadcast_mul (≈10 candle
/// launches per layer).
pub(crate) fn run_fused_gating(
b: &Tensor,
a: &Tensor,
a_log: &Tensor,
dt_bias: &Tensor,
) -> candle_core::Result<(Tensor, Tensor)> {
#[cfg(feature = "cuda")]
{
if b.device().is_cuda() {
let a_log_f32 = a_log.to_dtype(candle_core::DType::F32)?.contiguous()?;
let dt_bias_f32 = dt_bias.to_dtype(candle_core::DType::F32)?.contiguous()?;
return crate::cuda::gdn::fused_gdn_gating_cuda(b, a, &a_log_f32, &dt_bias_f32);
}
}
run_fused_gating_rust(b, a, a_log, dt_bias)
}
fn run_fused_gating_rust(
b: &Tensor,
a: &Tensor,
a_log: &Tensor,
dt_bias: &Tensor,
) -> candle_core::Result<(Tensor, Tensor)> {
let beta = candle_nn::ops::sigmoid(b)?;
let a_log_f32 = a_log.to_dtype(candle_core::DType::F32)?;
let neg_a_exp = a_log_f32.exp()?.neg()?;
let dt_b_f32 = dt_bias.to_dtype(candle_core::DType::F32)?;
let a_f32 = a.to_dtype(candle_core::DType::F32)?;
let a_plus_dt = a_f32.broadcast_add(&dt_b_f32)?;
let softplus_val = softplus(&a_plus_dt)?;
let neg_a_exp_b = neg_a_exp.unsqueeze(0)?.unsqueeze(0)?;
let g = neg_a_exp_b.broadcast_mul(&softplus_val)?;
Ok((beta, g))
}
fn run_causal_conv1d_rust(
x: &Tensor,
weight: &Tensor,
conv_state: Option<Tensor>,
batch_size: usize,
conv_dim: usize,
seq_len: usize,
conv_kernel_size: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
let dtype = x.dtype();
let device = x.device().clone();
let prepended = match &conv_state {
Some(prev) => Tensor::cat(&[prev, x], 2)?,
None => x.clone(),
};
let prep_len = prepended.dims()[2];
let new_state = if prep_len >= conv_kernel_size {
prepended.narrow(2, prep_len - conv_kernel_size, conv_kernel_size)?
} else {
let pad = Tensor::zeros(
(batch_size, conv_dim, conv_kernel_size - prep_len),
dtype,
&device,
)?;
Tensor::cat(&[&pad, &prepended], 2)?
};
let conv_out = prepended.conv1d(weight, conv_kernel_size - 1, 1, 1, conv_dim)?;
let conv_out = conv_out.narrow(2, 0, prep_len)?;
let conv_out = candle_nn::ops::silu(&conv_out)?;
let conv_out = conv_out.narrow(2, prep_len - seq_len, seq_len)?;
Ok((conv_out, new_state))
}
/// Load a no-bias linear from the ShardedVarBuilder. Weight shape is
/// the standard `[out, in]` order.
fn load_linear_no_bias(
vb: &ShardedVarBuilder,
name: &str,
in_dim: usize,
out_dim: usize,
) -> Result<Linear> {
let weight = vb
.pp(name)
.get((out_dim, in_dim), "weight")
.with_context(|| format!("load '{}/{name}/weight'", vb.prefix()))?;
Ok(Linear::new(weight, None))
}
/// Numerically-stable `softplus(x) = ln(1 + exp(x))`. Matches PyTorch's
/// `F.softplus` default (beta=1, threshold=20: for large positive x,
/// returns x as-is to avoid overflow in the exp).
pub(crate) fn softplus(x: &Tensor) -> candle_core::Result<Tensor> {
let threshold = 20.0_f64;
let big = x.ge(threshold)?; // Tensor<u8> mask
let safe = x.minimum(threshold)?; // min(x, threshold), scalar broadcast
let small = ((safe.exp()? + 1.0_f64)?).log()?;
// Select x where big, else small.
big.where_cond(x, &small)
}
/// `repeat_interleave` along a single dim. Candle has no built-in for
/// this; emulate with unsqueeze + expand + reshape.
pub(crate) fn repeat_interleave(
x: &Tensor,
repeats: usize,
dim: usize,
) -> candle_core::Result<Tensor> {
if repeats == 1 {
return Ok(x.clone());
}
// Target shape for the expand: the input dims with a new axis of
// size `repeats` inserted right after `dim`.
let mut expanded_shape = x.dims().to_vec();
let orig = expanded_shape[dim];
expanded_shape.insert(dim + 1, repeats);
let x = x.unsqueeze(dim + 1)?;
let x = x.expand(expanded_shape)?;
let mut out_shape = x.dims().to_vec();
out_shape.remove(dim + 1);
out_shape[dim] = orig * repeats;
x.contiguous()?.reshape(out_shape)
}
#[cfg(test)]
mod tests {
use super::*;
use candle_core::{DType, Device};
#[test]
fn softplus_small_x() {
// softplus(0) = ln(2) ≈ 0.6931
let x = Tensor::new(&[0.0_f32], &Device::Cpu).unwrap();
let out: Vec<f32> = softplus(&x).unwrap().to_vec1().unwrap();
assert!((out[0] - 2.0_f32.ln()).abs() < 1e-4);
}
#[test]
fn softplus_large_x_returns_x() {
// For x = 30, softplus(x) ≈ x (the threshold branch).
let x = Tensor::new(&[30.0_f32], &Device::Cpu).unwrap();
let out: Vec<f32> = softplus(&x).unwrap().to_vec1().unwrap();
assert!((out[0] - 30.0).abs() < 1e-4);
}
#[test]
fn repeat_interleave_doubles_dim() {
let x = Tensor::new(&[[1.0_f32, 2.0], [3.0, 4.0]], &Device::Cpu).unwrap(); // shape (2, 2)
let out = repeat_interleave(&x, 2, 1).unwrap(); // each col duplicated
let v: Vec<Vec<f32>> = out.to_vec2().unwrap();
// Row 0: 1, 1, 2, 2
// Row 1: 3, 3, 4, 4
assert_eq!(v[0], vec![1.0, 1.0, 2.0, 2.0]);
assert_eq!(v[1], vec![3.0, 3.0, 4.0, 4.0]);
}
/// Sanity: the recurrent path produces a finite tensor of the right
/// shape on tiny dimensions. Doesn't validate numerical correctness
/// against the Python reference — that would need a fixed-weight
/// fixture to compare against. Catches structural mistakes
/// (broadcasting shapes, off-by-one slices) early.
#[test]
fn forward_smoke_with_tiny_dimensions() {
let dev = Device::Cpu;
let dtype = DType::F32;
let (b, l) = (1, 3);
let cfg = TextConfig {
vocab_size: 100,
hidden_size: 16,
intermediate_size: 32,
num_hidden_layers: 1,
num_attention_heads: 4,
num_key_value_heads: 1,
head_dim: 4,
max_position_embeddings: 32,
rope_parameters: RopeParameters {
rope_theta: 10000.0,
partial_rotary_factor: 1.0,
rope_type: None,
},
rms_norm_eps: 1e-6,
tie_word_embeddings: false,
attn_output_gate: true,
layer_types: vec!["linear_attention".into()],
full_attention_interval: Some(4),
hidden_act: "silu".into(),
linear_num_value_heads: 4,
linear_num_key_heads: 2,
linear_key_head_dim: 4,
linear_value_head_dim: 4,
linear_conv_kernel_dim: 4,
};
// Rather than building a synthetic VarBuilder, skip `load` and
// construct GatedDeltaNet by hand with all-zero weights (and a
// ones-weight norm).
let zeros = |shape: &[usize]| Tensor::zeros(shape, dtype, &dev).unwrap();
let key_dim = cfg.linear_key_head_dim * cfg.linear_num_key_heads;
let value_dim = cfg.linear_value_head_dim * cfg.linear_num_value_heads;
let conv_dim = key_dim * 2 + value_dim;
let mut net = GatedDeltaNet {
in_proj_qkv: Linear::new(zeros(&[conv_dim, cfg.hidden_size]), None),
in_proj_z: Linear::new(zeros(&[value_dim, cfg.hidden_size]), None),
in_proj_b: Linear::new(zeros(&[cfg.linear_num_value_heads, cfg.hidden_size]), None),
in_proj_a: Linear::new(zeros(&[cfg.linear_num_value_heads, cfg.hidden_size]), None),
out_proj: Linear::new(zeros(&[cfg.hidden_size, value_dim]), None),
conv1d_weight: zeros(&[conv_dim, 1, cfg.linear_conv_kernel_dim]),
dt_bias: zeros(&[cfg.linear_num_value_heads]),
a_log: zeros(&[cfg.linear_num_value_heads]),
norm: {
let weight = Tensor::ones(&[cfg.linear_value_head_dim], dtype, &dev).unwrap();
Qwen3_5RmsNormGated::from_weight(weight, cfg.rms_norm_eps)
},
num_v_heads: cfg.linear_num_value_heads,
num_k_heads: cfg.linear_num_key_heads,
head_k_dim: cfg.linear_key_head_dim,
head_v_dim: cfg.linear_value_head_dim,
key_dim,
value_dim,
conv_dim,
conv_kernel_size: cfg.linear_conv_kernel_dim,
state: GatedDeltaNetState::default(),
};
let x = Tensor::ones(&[b, l, cfg.hidden_size], dtype, &dev).unwrap();
let y = net.forward(&x).unwrap();
assert_eq!(y.dims(), &[b, l, cfg.hidden_size]);
// All-zero weights should give an all-zero output; asserting
// finiteness confirms no NaN/Inf poisoning from the f32 promotions.
let v: Vec<f32> = y.flatten_all().unwrap().to_vec1().unwrap();
assert!(v.iter().all(|x| x.is_finite()));
}
}
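
The per-token update from the block diagram (decay, read, delta, write, query) can be sketched for a single head without any tensor library. This is only an illustration under stated assumptions: `HeadState` and `step` are hypothetical names, the state is stored row-major as a `d_k x d_v` matrix, and `q` and `k` are assumed to be already L2-normalised and scaled by the caller.

```rust
/// One head's recurrent memory: a `d_k x d_v` matrix, row-major.
/// (Hypothetical type, used only for this sketch.)
pub struct HeadState {
    pub d_k: usize,
    pub d_v: usize,
    pub s: Vec<f32>,
}

impl HeadState {
    pub fn zeros(d_k: usize, d_v: usize) -> Self {
        Self { d_k, d_v, s: vec![0.0; d_k * d_v] }
    }

    /// One gated delta-rule step for a single token.
    /// `q`, `k` have length `d_k`; `v` has length `d_v`.
    /// `g` is the log-decay (<= 0), `beta` the update rate in (0, 1).
    pub fn step(&mut self, q: &[f32], k: &[f32], v: &[f32], g: f32, beta: f32) -> Vec<f32> {
        let (d_k, d_v) = (self.d_k, self.d_v);
        // state *= exp(g_t)
        let decay = g.exp();
        for x in self.s.iter_mut() {
            *x *= decay;
        }
        // mem = state^T . k_t   (what the memory currently returns for k_t)
        let mut mem = vec![0.0_f32; d_v];
        for i in 0..d_k {
            for j in 0..d_v {
                mem[j] += self.s[i * d_v + j] * k[i];
            }
        }
        // delta = (v_t - mem) * beta_t
        let delta: Vec<f32> = v.iter().zip(&mem).map(|(v, m)| (v - m) * beta).collect();
        // state += outer(k_t, delta)
        for i in 0..d_k {
            for j in 0..d_v {
                self.s[i * d_v + j] += k[i] * delta[j];
            }
        }
        // out_t = state^T . q_t
        let mut out = vec![0.0_f32; d_v];
        for i in 0..d_k {
            for j in 0..d_v {
                out[j] += self.s[i * d_v + j] * q[i];
            }
        }
        out
    }
}
```

Writing the same key/value pair twice with `beta = 1` and `g = 0` leaves the state unchanged on the second step, because the memory already returns `v` for that key and the delta is zero. That is the property that distinguishes the delta rule from plain additive linear attention.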


@@ -0,0 +1,53 @@
//! SwiGLU MLP block for Qwen3-Next.
//!
//! Identical to plain Qwen3's MLP: `down(silu(gate(x)) * up(x))` with
//! no bias on any of the three projections.
use anyhow::{Context, Result};
use candle_core::{Module, Tensor};
use candle_nn::Linear;
use candle_nn::var_builder::ShardedVarBuilder;
use super::TextConfig;
pub struct Qwen3_5MLP {
gate_proj: Linear,
up_proj: Linear,
down_proj: Linear,
}
impl Qwen3_5MLP {
pub fn load(cfg: &TextConfig, vb: &ShardedVarBuilder) -> Result<Self> {
let h = cfg.hidden_size;
let i = cfg.intermediate_size;
let gate_proj = load_linear_no_bias(vb, "gate_proj", h, i)?;
let up_proj = load_linear_no_bias(vb, "up_proj", h, i)?;
let down_proj = load_linear_no_bias(vb, "down_proj", i, h)?;
Ok(Self {
gate_proj,
up_proj,
down_proj,
})
}
}
impl Module for Qwen3_5MLP {
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
let lhs = candle_nn::ops::silu(&self.gate_proj.forward(x)?)?;
let rhs = self.up_proj.forward(x)?;
self.down_proj.forward(&(lhs * rhs)?)
}
}
fn load_linear_no_bias(
vb: &ShardedVarBuilder,
name: &str,
in_dim: usize,
out_dim: usize,
) -> Result<Linear> {
let weight = vb
.pp(name)
.get((out_dim, in_dim), "weight")
.with_context(|| format!("load '{}/{name}/weight'", vb.prefix()))?;
Ok(Linear::new(weight, None))
}
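
The `down(silu(gate(x)) * up(x))` formula can be sketched on plain vectors. This is a minimal illustration rather than the crate's code: `matvec`, `silu`, and `swiglu` are hypothetical helpers, and each weight matrix is assumed to be stored as rows in the same `[out, in]` order that `load_linear_no_bias` uses.

```rust
/// Multiply a weight matrix stored as `[out][in]` rows by a vector.
fn matvec(w: &[Vec<f32>], x: &[f32]) -> Vec<f32> {
    w.iter()
        .map(|row| row.iter().zip(x).map(|(a, b)| a * b).sum::<f32>())
        .collect()
}

/// SiLU activation: x * sigmoid(x).
fn silu(x: f32) -> f32 {
    x / (1.0 + (-x).exp())
}

/// SwiGLU MLP without biases: down(silu(gate(x)) * up(x)).
pub fn swiglu(gate: &[Vec<f32>], up: &[Vec<f32>], down: &[Vec<f32>], x: &[f32]) -> Vec<f32> {
    let g = matvec(gate, x);
    let u = matvec(up, x);
    // Elementwise product of the activated gate branch and the up branch.
    let h: Vec<f32> = g.iter().zip(&u).map(|(g, u)| silu(*g) * u).collect();
    matvec(down, &h)
}
```

A zero gate projection zeroes the whole output regardless of the up branch, since `silu(0) = 0`.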


@@ -0,0 +1,397 @@
//! Qwen3-Next (`model_type = "qwen3_5"`) architecture — Qwen3.6's
//! upstream architecture revision.
//!
//! ## Naming
//!
//! The model release this targets is `Qwen/Qwen3.6-*` but the
//! architecture name in HuggingFace's `config.json` is `qwen3_5`.
//! mistralrs calls the same architecture `qwen3_next`; that label
//! ages poorly the next time Qwen ship a new arch, so we key on the
//! canonical `qwen3_5` from the model's own config.
//!
//! ## Status
//!
//! **Single-GPU dense path is real**. Both attention flavours
//! (`full_attention` with the output-gated GQA causal attention and
//! `linear_attention` with the Gated DeltaNet recurrent block) are
//! implemented. The model loads from upstream safetensors via the
//! existing `load_arch_dense` dispatch and runs forward end to end.
//!
//! Numerical correctness vs the reference Python is **not yet
//! validated** — the structural code path is right, weight tensor
//! names match the upstream layout, shapes flow through cleanly, but
//! the Tbilisi probe (and any other downstream test) is the next
//! step. Likely places a bug would surface:
//! - Per-rank vs per-token-position offsets in the recurrent delta
//! rule (`linear_attn.rs`).
//! - Off-by-one in the conv state continuation across decode steps.
//! - RoPE phase mismatch from MRoPE simplification (we treat the
//! three position grids as collapsed, which is correct only for
//! text-only inference).
//!
//! ## Submodules
//!
//! - [`rmsnorm`] — `Qwen3_5RmsNorm` (`(1+w)*x` variant), the
//! `Qwen3_5RmsNormGated` used after the delta rule, and the
//! `l2norm` helper.
//! - [`rope`] — text-side rotary embedding (mrope simplified, GLM
//! rotate-half).
//! - [`mlp`] — SwiGLU MLP (gate/up/down, no bias).
//! - [`full_attn`] — `Qwen3_5Attention` with the output-gate
//! widening on `q_proj`.
//! - [`linear_attn`] — `GatedDeltaNet` recurrent delta-rule block
//! (causal depthwise Conv1d → silu → split → L2norm → per-token
//! delta rule → RMSNormGated → out_proj).
//! - [`decoder`] — `Qwen3_5DecoderLayer` dispatching to one of the
//! two attention flavours per layer index.
//!
//! ## Open work
//!
//! - **TP variant.** `harness/tp/tp_qwen3_5.rs` is the next step.
//! Sharding strategy diverges by layer type:
//! - Full-attention layers: column-parallel q/k/v (including the
//! gate half of `q_proj`) + row-parallel `o_proj`, mirroring
//! `tp_qwen3.rs`.
//! - Linear-attention layers: the recurrent state is per-V-head, so
//! V-head-dimension sharding works cleanly — split `num_v_heads`
//! across ranks (`num_v_heads / world_size` per rank), shard
//! `in_proj_qkv` / `in_proj_z` / `in_proj_b` / `in_proj_a` along
//! the V-head dim, and row-parallel `out_proj`. The `A_log` /
//! `dt_bias` per-head params shard with the heads.
//!
//! - **Chunked delta-rule prefill.** The pure-Rust path in
//! `linear_attn.rs` runs the per-token recurrence for prefill too:
//! correct, but one sequential step per token. Porting
//! `torch_chunk_gated_delta_rule` (chunk_size=64) should speed up
//! prefill substantially with no surface change.
use anyhow::{Context, Result};
use candle_core::{DType, Device, IndexOp, Module, Tensor};
use candle_nn::Embedding;
use candle_nn::Linear;
use candle_nn::var_builder::ShardedVarBuilder;
use serde::Deserialize;
use std::sync::Arc;
pub mod decoder;
pub mod full_attn;
pub mod linear_attn;
pub mod mlp;
pub mod rmsnorm;
pub mod rope;
use decoder::Qwen3_5DecoderLayer;
use rmsnorm::Qwen3_5RmsNorm;
use rope::RotaryEmbedding;
/// `model_type` we deserialise from `config.json`. Const so the
/// dispatch in `candle.rs::load_arch_dense` can pattern-match without
/// magic strings.
pub const MODEL_TYPE: &str = "qwen3_5";
/// Top-level shape of Qwen3-Next's `config.json`. The real
/// hyperparameters live in `text_config`; the rest is multimodal /
/// tokeniser glue we don't need for the language-model forward.
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
/// Always `"qwen3_5"` for this architecture. Kept on the struct
/// so the (eventual) dispatch / logging code can show it without
/// re-parsing the JSON.
pub model_type: String,
/// The text-side hyperparameters. Everything we actually need.
pub text_config: TextConfig,
}
/// Inner config (the `text_config` block). Mirrors the Qwen3 layout
/// but with the extras Qwen3-Next adds (`attn_output_gate`,
/// `layer_types`, `full_attention_interval`, larger `head_dim`).
#[derive(Debug, Clone, Deserialize)]
pub struct TextConfig {
pub vocab_size: usize,
pub hidden_size: usize,
pub intermediate_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub num_key_value_heads: usize,
pub head_dim: usize,
pub max_position_embeddings: usize,
/// Nested RoPE settings. Qwen3-Next puts `rope_theta` and
/// `partial_rotary_factor` inside this block rather than at the
/// top level — important because the partial rotary means only
/// `head_dim * partial_rotary_factor` dims get RoPE applied (the
/// rest pass through unchanged).
pub rope_parameters: RopeParameters,
pub rms_norm_eps: f64,
#[serde(default)]
pub tie_word_embeddings: bool,
/// New in Qwen3-Next: a sigmoid gate multiplied into the attention
/// output before the o_proj. The Python reference applies it
/// pointwise after softmax+matmul.
#[serde(default)]
pub attn_output_gate: bool,
/// One entry per decoder layer; values are `"full_attention"` or
/// `"linear_attention"`. Length must equal `num_hidden_layers`.
/// `full_attention_interval` is a derived hint (every 4th layer
/// by default) — `layer_types` is authoritative.
#[serde(default)]
pub layer_types: Vec<String>,
/// Hint for the layer-type pattern (4 upstream; `None` here when the
/// key is absent). Kept for logging / validation; the forward
/// dispatches on `layer_types`.
#[serde(default)]
pub full_attention_interval: Option<usize>,
/// Hidden activation (`"silu"` for Qwen3-Next). Used by the MLP
/// and the linear-attention conv1d.
#[serde(default = "default_hidden_act")]
pub hidden_act: String,
// --- Gated DeltaNet (linear-attention) hyperparams -----------------
/// Per-layer linear-attention V-head count (Qwen3.6-27B: 48).
/// More V-heads than K-heads is fine — query/key get
/// `repeat_interleave`'d to match before the delta rule.
#[serde(default)]
pub linear_num_value_heads: usize,
/// Per-layer linear-attention K-head count (Qwen3.6-27B: 16).
#[serde(default)]
pub linear_num_key_heads: usize,
/// Per-head key dimension for the linear-attention path
/// (Qwen3.6-27B: 128). Separate from `head_dim` which the
/// full-attention layers use.
#[serde(default)]
pub linear_key_head_dim: usize,
/// Per-head value dimension for the linear-attention path
/// (Qwen3.6-27B: 128).
#[serde(default)]
pub linear_value_head_dim: usize,
/// Causal Conv1d kernel size used before the delta rule
/// (Qwen3.6-27B: 4).
#[serde(default)]
pub linear_conv_kernel_dim: usize,
}
fn default_hidden_act() -> String {
"silu".into()
}
/// Nested `rope_parameters` block from a Qwen3-Next `config.json`.
/// `mrope_section` and `mrope_interleaved` are accepted but ignored:
/// serde skips unknown fields by default (this struct does not set
/// `deny_unknown_fields`). We treat MRoPE as plain RoPE for text-only
/// inference (the three position grids carry identical ids when
/// there's no vision input, so the interleaving is a no-op).
#[derive(Debug, Clone, Deserialize)]
pub struct RopeParameters {
/// Base for the inverse-frequency computation. Qwen3.6: 10_000_000.
#[serde(default = "default_rope_theta")]
pub rope_theta: f64,
/// Fraction of `head_dim` that gets the rotation applied. The
/// remaining `head_dim * (1 - partial_rotary_factor)` dims pass
/// through unchanged. Qwen3.6 / Qwen3.5: 0.25.
#[serde(default = "default_partial_rotary_factor")]
pub partial_rotary_factor: f32,
/// `"default"` for the standard inv_freq RoPE; other values (e.g.
/// `"linear"`, `"dynamic"`) are upstream-supported but not yet
/// implemented here.
#[serde(default)]
pub rope_type: Option<String>,
}
fn default_rope_theta() -> f64 {
10_000.0
}
fn default_partial_rotary_factor() -> f32 {
1.0
}
/// Qwen3-Next base transformer (embedding + decoder stack + final
/// norm). Public so a TP variant in `harness/tp/tp_qwen3_5.rs` can
/// also build on it later — for now only `Qwen3_5ForCausalLM` is the
/// loaded handle.
pub struct Qwen3_5Model {
embed_tokens: Embedding,
layers: Vec<Qwen3_5DecoderLayer>,
norm: Qwen3_5RmsNorm,
device: Device,
dtype: DType,
}
impl Qwen3_5Model {
pub fn load(cfg: &TextConfig, vb: &ShardedVarBuilder) -> Result<Self> {
let dtype = vb.dtype();
let device = vb.device().clone();
// Qwen3-Next is a multimodal architecture whose text core lives
// under `model.language_model.*` — sibling to `model.visual.*`
// (the vision tower) and to top-level `lm_head` / `mtp.*`.
// Every text-side tensor in the safetensors files is under
// this prefix; we ignore the vision and MTP weights for
// language-model inference.
let text_vb = vb.pp("model.language_model");
let embed_vb = text_vb.pp("embed_tokens");
let embed_weight = embed_vb
.get((cfg.vocab_size, cfg.hidden_size), "weight")
.with_context(|| format!("load '{}/weight'", embed_vb.prefix()))?;
let embed_tokens = Embedding::new(embed_weight, cfg.hidden_size);
let rotary = Arc::new(RotaryEmbedding::new(dtype, cfg, &device)?);
if cfg.layer_types.len() != cfg.num_hidden_layers {
anyhow::bail!(
"config.text_config.layer_types must have num_hidden_layers ({}) entries; \
got {}",
cfg.num_hidden_layers,
cfg.layer_types.len()
);
}
let vb_l = text_vb.pp("layers");
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
for i in 0..cfg.num_hidden_layers {
layers.push(Qwen3_5DecoderLayer::load(
cfg,
rotary.clone(),
i,
&vb_l.pp(i),
)?);
}
let norm = Qwen3_5RmsNorm::load(&text_vb.pp("norm"), cfg.hidden_size, cfg.rms_norm_eps)?;
Ok(Self {
embed_tokens,
layers,
norm,
device,
dtype,
})
}
pub fn embed_weight(&self) -> &Tensor {
self.embed_tokens.embeddings()
}
pub fn clear_kv_cache(&mut self) {
for l in &mut self.layers {
l.clear_kv_cache();
}
}
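fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> candle_core::Result<Tensor> {
let minf = f32::NEG_INFINITY;
let mask: Vec<_> = (0..tgt)
.flat_map(|i| (0..(tgt + offset)).map(move |j| if j <= i + offset { 0. } else { minf }))
.collect();
// `mask` holds tgt * (tgt + offset) values, i.e. one batch entry.
// Build it with a leading dim of 1 and broadcast across the batch;
// passing `b` straight to `from_slice` fails the element-count
// check whenever b > 1.
Tensor::from_slice(&mask, (1, 1, tgt, tgt + offset), &self.device)?
.to_dtype(self.dtype)?
.broadcast_as((b, 1, tgt, tgt + offset))
}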
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
let (b, l) = input.dims2()?;
let mut h = self.embed_tokens.forward(input)?;
// Causal mask only needed for L > 1 prefill; full-attention
// layers consume it via broadcast_add. Linear-attention layers
// ignore the mask.
let causal = if l == 1 {
None
} else {
Some(self.causal_mask(b, l, offset)?)
};
for layer in &mut self.layers {
h = layer.forward(&h, causal.as_ref(), offset)?;
}
self.norm.forward(&h)
}
}
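
The additive causal mask built in `causal_mask` can be illustrated without candle. The following is a minimal std-only sketch, not the harness code: `causal_mask_rows` is a hypothetical free function that returns the same row-major values for a single batch entry, with `offset` standing for the number of already-cached positions.

```rust
/// Row-major additive causal mask of shape (tgt, tgt + offset).
/// Entry (i, j) is 0.0 when key position j is visible to query i
/// (j <= i + offset) and -inf otherwise.
/// Hypothetical helper for illustration only.
fn causal_mask_rows(tgt: usize, offset: usize) -> Vec<f32> {
    (0..tgt)
        .flat_map(|i| {
            (0..tgt + offset).map(move |j| {
                if j <= i + offset {
                    0.0
                } else {
                    f32::NEG_INFINITY
                }
            })
        })
        .collect()
}
```

For `tgt = 2, offset = 1` the first query row sees two keys and masks the third, while the second row sees all three. The buffer length is `tgt * (tgt + offset)` and does not depend on the batch size.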
pub struct Qwen3_5ForCausalLM {
base: Qwen3_5Model,
lm_head: Linear,
}
impl Qwen3_5ForCausalLM {
pub fn new(config: Config, vb: ShardedVarBuilder) -> Result<Self> {
let cfg = &config.text_config;
let base = Qwen3_5Model::load(cfg, &vb)?;
let lm_head = if cfg.tie_word_embeddings {
Linear::new(base.embed_weight().clone(), None)
} else {
let weight = vb
.pp("lm_head")
.get((cfg.vocab_size, cfg.hidden_size), "weight")
.with_context(|| format!("load '{}/lm_head/weight'", vb.prefix()))?;
Linear::new(weight, None)
};
Ok(Self { base, lm_head })
}
/// `input`: token-id tensor of shape `(B, L)`. Returns logits at
/// the last position, shape `(B, 1, vocab_size)` — same contract
/// as `qwen3::ModelForCausalLM::forward` so the harness's
/// `squeeze_to_vocab` helper handles both uniformly.
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
let (_, l) = input.dims2()?;
let hidden = self.base.forward(input, offset)?;
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
}
pub fn clear_kv_cache(&mut self) {
self.base.clear_kv_cache();
}
}
#[cfg(test)]
mod tests {
use super::*;
/// Confirms we can deserialise the real upstream config shape.
/// Sample taken from `Qwen/Qwen3.6-27B/config.json`, trimmed to
/// the fields the architecture cares about. Note `rope_theta` and
/// `partial_rotary_factor` are nested under `rope_parameters` —
/// Qwen3-Next does NOT have a top-level `rope_theta`.
#[test]
fn config_deserialises_the_real_qwen3_6_shape() {
let raw = r#"{
"architectures": ["Qwen3_5ForConditionalGeneration"],
"model_type": "qwen3_5",
"image_token_id": 248056,
"language_model_only": false,
"text_config": {
"vocab_size": 248064,
"hidden_size": 5120,
"intermediate_size": 17408,
"num_hidden_layers": 64,
"num_attention_heads": 64,
"num_key_value_heads": 8,
"head_dim": 256,
"max_position_embeddings": 32768,
"rope_parameters": {
"mrope_interleaved": true,
"mrope_section": [11, 11, 10],
"partial_rotary_factor": 0.25,
"rope_theta": 10000000,
"rope_type": "default"
},
"rms_norm_eps": 1e-6,
"tie_word_embeddings": false,
"attn_output_gate": true,
"full_attention_interval": 4,
"layer_types": [
"linear_attention", "linear_attention",
"linear_attention", "full_attention"
]
}
}"#;
let cfg: Config = serde_json::from_str(raw).expect("parse Qwen3.6 config");
assert_eq!(cfg.model_type, "qwen3_5");
assert_eq!(cfg.text_config.hidden_size, 5120);
assert_eq!(cfg.text_config.head_dim, 256);
assert!(cfg.text_config.attn_output_gate);
assert_eq!(cfg.text_config.full_attention_interval, Some(4));
assert_eq!(cfg.text_config.layer_types.len(), 4);
assert_eq!(cfg.text_config.rope_parameters.rope_theta, 10_000_000.0);
assert!((cfg.text_config.rope_parameters.partial_rotary_factor - 0.25).abs() < 1e-6);
}
}
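
The relationship between `full_attention_interval` and `layer_types` can be sketched as follows. This assumes the pattern visible in the sample config above (every `interval`-th layer, counting from 1, is `"full_attention"`); `layer_types_from_interval` and `matches_hint` are hypothetical helpers, not part of this commit, and `layer_types` remains authoritative.

```rust
/// Hypothetical helper: rebuild the per-layer type list from the
/// `full_attention_interval` hint. Assumes every `interval`-th layer
/// (1-based) is full attention and the rest are linear attention.
fn layer_types_from_interval(num_layers: usize, interval: usize) -> Vec<&'static str> {
    assert!(interval > 0, "interval must be non-zero");
    (0..num_layers)
        .map(|i| {
            if (i + 1) % interval == 0 {
                "full_attention"
            } else {
                "linear_attention"
            }
        })
        .collect()
}

/// Hypothetical validation: does an explicit `layer_types` list agree
/// with the hint?
fn matches_hint(layer_types: &[&str], interval: usize) -> bool {
    let expected = layer_types_from_interval(layer_types.len(), interval);
    layer_types == expected.as_slice()
}
```

With 64 layers and an interval of 4 this yields 48 linear-attention layers and 16 full-attention layers. A check like `matches_hint` would only be suitable for logging a warning, since the forward pass dispatches on `layer_types`.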

@@ -0,0 +1,161 @@
//! Norm primitives for Qwen3-Next.
//!
//! Two reasons we can't reuse `candle_nn::RmsNorm` directly:
//!
//! 1. **`(1.0 + weight)` scaling.** Qwen3-Next's `Qwen3_5RMSNorm`
//! initialises `weight` to zeros and applies `(1.0 + weight)` to
//! the normalised vector. `candle_nn::RmsNorm` applies `weight`
//! directly. The two are equivalent only when the operator has
//! pre-shifted the weights — the upstream checkpoints have not. See
//! `huggingface/transformers#29402` for the upstream PR that
//! introduced the `(1 + w)` form to recover from the zero-init.
//!
//! 2. **Gated variant.** The linear-attention layer post-normalises
//! its output by an RMSNorm *gated* with a per-element SiLU on
//! a sibling `z` projection — fused for numerical reasons (the
//! norm's float32 promotion has to happen before the SiLU
//! multiply). Not a single existing candle op.
//!
//! Both ops accept inputs in any compute dtype; promotion to f32 for
//! the variance calculation matches the Python reference.
use anyhow::{Context, Result};
use candle_core::{D, Module, Tensor};
use candle_nn::var_builder::ShardedVarBuilder;
/// L2-normalise along the last dim with a small epsilon. Matches the
/// `l2norm` helper in `transformers/models/qwen3_5/modeling_qwen3_5.py`
/// — `x * rsqrt(sum(x*x) + eps)`. The linear-attention path uses this
/// on Q and K before the delta rule when
/// `use_qk_l2norm_in_kernel=True` (which Qwen3-Next always sets).
pub fn l2norm(x: &Tensor, eps: f32) -> candle_core::Result<Tensor> {
let dtype = x.dtype();
let x_f32 = x.to_dtype(candle_core::DType::F32)?;
let sq = x_f32.sqr()?;
let sum = sq.sum_keepdim(D::Minus1)?;
let inv = (sum + eps as f64)?.sqrt()?.recip()?;
x_f32.broadcast_mul(&inv)?.to_dtype(dtype)
}
/// Qwen3-Next's RMSNorm. Stores the raw weight tensor; forward applies
/// `(1.0 + weight) * x_normed`.
pub struct Qwen3_5RmsNorm {
weight: Tensor,
eps: f32,
size: usize,
}
impl Qwen3_5RmsNorm {
/// Load `weight` from the ShardedVarBuilder. `vb` should already be
/// `.pp(...)`-ed to the norm's tensor prefix.
pub fn load(vb: &ShardedVarBuilder, size: usize, eps: f64) -> Result<Self> {
let weight = vb
.get(size, "weight")
.with_context(|| format!("load '{}/weight'", vb.prefix()))?;
Ok(Self {
weight,
eps: eps as f32,
size,
})
}
pub fn size(&self) -> usize {
self.size
}
}
impl Module for Qwen3_5RmsNorm {
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
let dtype = x.dtype();
let x_f32 = x.to_dtype(candle_core::DType::F32)?;
let var = x_f32.sqr()?.mean_keepdim(D::Minus1)?;
let normed = x_f32.broadcast_mul(&(var + self.eps as f64)?.sqrt()?.recip()?)?;
// Promote weight to f32 and shift by 1.0 *before* multiplying.
// Doing the (1 + w) operation in fp16 lands at -inf for the
// bottom-of-range weights at load time.
let w_f32 = self.weight.to_dtype(candle_core::DType::F32)?;
let scale = (w_f32 + 1.0_f64)?;
normed.broadcast_mul(&scale)?.to_dtype(dtype)
}
}
/// Gated RMSNorm used at the tail of `Qwen3_5GatedDeltaNet`. Equivalent
/// to `x_normed * weight * silu(gate)` but with both the norm and the
/// gate evaluated in float32 to avoid mid-pipeline underflow.
///
/// Note: unlike `Qwen3_5RmsNorm`, this variant matches the Python
/// reference's `Qwen3_5RMSNormGated` which uses `weight` directly (not
/// `1.0 + weight`).
pub struct Qwen3_5RmsNormGated {
weight: Tensor,
eps: f32,
size: usize,
}
impl Qwen3_5RmsNormGated {
pub fn load(vb: &ShardedVarBuilder, size: usize, eps: f64) -> Result<Self> {
let weight = vb
.get(size, "weight")
.with_context(|| format!("load '{}/weight'", vb.prefix()))?;
Ok(Self {
weight,
eps: eps as f32,
size,
})
}
/// Direct constructor — used by unit tests that build a layer
/// without going through a VarBuilder.
#[cfg(test)]
pub(crate) fn from_weight(weight: Tensor, eps: f64) -> Self {
let size = weight.dims()[0];
Self {
weight,
eps: eps as f32,
size,
}
}
pub fn size(&self) -> usize {
self.size
}
/// `x` and `gate` share the same last-dim shape (`size`).
pub fn forward(&self, x: &Tensor, gate: &Tensor) -> candle_core::Result<Tensor> {
let dtype = x.dtype();
let x_f32 = x.to_dtype(candle_core::DType::F32)?;
let var = x_f32.sqr()?.mean_keepdim(D::Minus1)?;
let normed = x_f32.broadcast_mul(&(var + self.eps as f64)?.sqrt()?.recip()?)?;
let w = self.weight.to_dtype(candle_core::DType::F32)?;
let out = normed.broadcast_mul(&w)?;
// SiLU on the float32 gate, multiply back into the normed
// tensor, then cast to the model dtype.
let g = gate.to_dtype(candle_core::DType::F32)?;
let silu_gate = candle_nn::ops::silu(&g)?;
(out * silu_gate)?.to_dtype(dtype)
}
}
#[cfg(test)]
mod tests {
use super::*;
use candle_core::Device;
#[test]
fn l2norm_matches_hand_calc() {
let x = Tensor::new(&[3.0_f32, 4.0_f32], &Device::Cpu).unwrap();
let out = l2norm(&x, 1e-6).unwrap();
let v: Vec<f32> = out.to_vec1().unwrap();
// |x| = 5, so x/|x| = [0.6, 0.8] (eps is tiny).
assert!((v[0] - 0.6).abs() < 1e-4);
assert!((v[1] - 0.8).abs() < 1e-4);
}
#[test]
fn l2norm_zero_vector_is_safe_via_epsilon() {
let x = Tensor::new(&[0.0_f32, 0.0_f32], &Device::Cpu).unwrap();
let out = l2norm(&x, 1e-6).unwrap();
let v: Vec<f32> = out.to_vec1().unwrap();
assert!(v.iter().all(|x| x.is_finite()));
}
}
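
The difference between the two norm variants above can be shown on plain slices. This is a std-only sketch under the assumptions stated in the doc comments (the plain norm scales by `1 + weight`, the gated norm scales by `weight` and then by `silu(gate)`); the function names are hypothetical and everything is computed in f32.

```rust
/// RMS-normalise `x`, then scale by `(1 + w)` (the zero-initialised form).
fn rms_norm_one_plus_w(x: &[f32], w: &[f32], eps: f32) -> Vec<f32> {
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv = 1.0 / (mean_sq + eps).sqrt();
    x.iter().zip(w).map(|(v, wi)| v * inv * (1.0 + wi)).collect()
}

/// SiLU activation: g * sigmoid(g).
fn silu(g: f32) -> f32 {
    g / (1.0 + (-g).exp())
}

/// Gated variant: scale the normalised vector by `w` directly, then
/// multiply element-wise by `silu(gate)`.
fn rms_norm_gated(x: &[f32], gate: &[f32], w: &[f32], eps: f32) -> Vec<f32> {
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv = 1.0 / (mean_sq + eps).sqrt();
    x.iter()
        .zip(w)
        .zip(gate)
        .map(|((v, wi), g)| v * inv * wi * silu(*g))
        .collect()
}
```

With an all-zero weight the plain variant returns a vector whose mean square is close to 1, which is why a zero-initialised checkpoint still produces a usable norm. The gated variant with an all-zero gate returns zeros, since `silu(0) = 0`.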

@@ -0,0 +1,114 @@
//! Rotary position embedding for Qwen3-Next's full-attention layers.
//!
//! Qwen3.6 ships with MRoPE (multimodal RoPE) machinery in the
//! reference Python — three position grids interleaved per
//! `mrope_section`. For text-only inference all three grids carry the
//! same position ids and the interleave is a no-op, so this module
//! implements the plain (non-mrope) flavour: the standard inv_freq
//! cosine/sine tables driven by `rope_theta` and `head_dim`.
//!
//! Rotation flavour: **NeoX-style** rotate-half (the second half of
//! the rotated dims is negated and swapped into the first). The
//! reference Python uses `apply_rotary_pos_emb` with `rotate_half`;
//! candle's `rope_slow` is the matching helper.
use anyhow::Result;
use candle_core::{DType, Device, Tensor};
use super::TextConfig;
#[derive(Debug, Clone)]
pub struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
/// Number of dims at the head's leading edge that the rotation
/// covers. The remaining `head_dim - rotary_dim` dims pass through
/// unchanged. Qwen3-Next uses `partial_rotary_factor = 0.25`, so
/// for `head_dim = 256` only 64 dims rotate.
rotary_dim: usize,
head_dim: usize,
}
impl RotaryEmbedding {
pub fn new(dtype: DType, cfg: &TextConfig, dev: &Device) -> Result<Self> {
let head_dim = cfg.head_dim;
let rope = &cfg.rope_parameters;
let rotary_dim = (head_dim as f32 * rope.partial_rotary_factor) as usize;
if !rotary_dim.is_multiple_of(2) {
anyhow::bail!(
"rotary_dim = head_dim * partial_rotary_factor = {head_dim} * {} = {rotary_dim} \
must be even (cos/sin are paired)",
rope.partial_rotary_factor
);
}
if rotary_dim == 0 {
anyhow::bail!(
"rotary_dim = 0 (partial_rotary_factor = {} too small)",
rope.partial_rotary_factor
);
}
let max_seq_len = cfg.max_position_embeddings;
let inv_freq: Vec<f32> = (0..rotary_dim)
.step_by(2)
.map(|i| 1f32 / rope.rope_theta.powf(i as f64 / rotary_dim as f64) as f32)
.collect();
let n = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, n), dev)?.to_dtype(DType::F32)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(DType::F32)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
Ok(Self {
sin: freqs.sin()?.to_dtype(dtype)?,
cos: freqs.cos()?.to_dtype(dtype)?,
rotary_dim,
head_dim,
})
}
/// Apply RoPE to q, k.
///
/// `q`, `k` shape: `(B, H, L, head_dim)`. `offset` is the index
/// into the cached cos/sin table — the position of the first token
/// in the current step.
///
/// When `rotary_dim < head_dim` the rotation is applied only to the
/// first `rotary_dim` dims of each head; the tail passes through
/// unchanged (matches the reference Python's
/// `apply_rotary_pos_emb` with non-trivial `partial_rotary_factor`).
pub fn apply(
&self,
q: &Tensor,
k: &Tensor,
offset: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
let (_, _, seq_len, head_dim_in) = q.dims4()?;
debug_assert_eq!(head_dim_in, self.head_dim, "q head_dim mismatch");
let cos = self.cos.narrow(0, offset, seq_len)?;
let sin = self.sin.narrow(0, offset, seq_len)?;
if self.rotary_dim == self.head_dim {
// Full rotation.
let q_embed = candle_nn::rotary_emb::rope_slow(&q.contiguous()?, &cos, &sin)?;
let k_embed = candle_nn::rotary_emb::rope_slow(&k.contiguous()?, &cos, &sin)?;
Ok((q_embed, k_embed))
} else {
// Partial rotation: narrow → rotate → cat the untouched tail.
let tail = self.head_dim - self.rotary_dim;
let q_rot = q
.narrow(candle_core::D::Minus1, 0, self.rotary_dim)?
.contiguous()?;
let q_pass = q.narrow(candle_core::D::Minus1, self.rotary_dim, tail)?;
let k_rot = k
.narrow(candle_core::D::Minus1, 0, self.rotary_dim)?
.contiguous()?;
let k_pass = k.narrow(candle_core::D::Minus1, self.rotary_dim, tail)?;
let q_rotated = candle_nn::rotary_emb::rope_slow(&q_rot, &cos, &sin)?;
let k_rotated = candle_nn::rotary_emb::rope_slow(&k_rot, &cos, &sin)?;
let q_embed =
Tensor::cat(&[&q_rotated, &q_pass.contiguous()?], candle_core::D::Minus1)?;
let k_embed =
Tensor::cat(&[&k_rotated, &k_pass.contiguous()?], candle_core::D::Minus1)?;
Ok((q_embed, k_embed))
}
}
}
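
The partial rotation can be sketched on a single head vector. This is a std-only illustration, assuming the rotate-half pairing described in the module docs (element `i` pairs with element `i + rotary_dim / 2`); `partial_rope` is a hypothetical function that recomputes the angle per call instead of using cached cos/sin tables.

```rust
/// Apply rotate-half RoPE to the first `rotary_dim` entries of one head
/// vector at position `pos`; the tail is copied through unchanged.
/// Hypothetical helper for illustration only.
fn partial_rope(x: &[f32], rotary_dim: usize, theta: f64, pos: usize) -> Vec<f32> {
    assert!(rotary_dim % 2 == 0 && rotary_dim <= x.len());
    let half = rotary_dim / 2;
    let mut out = x.to_vec();
    for i in 0..half {
        // inv_freq[i] = theta^(-2i / rotary_dim), as in the table build.
        let inv_freq = 1.0 / theta.powf((2 * i) as f64 / rotary_dim as f64);
        let angle = pos as f64 * inv_freq;
        let (sin, cos) = (angle.sin() as f32, angle.cos() as f32);
        // rotate_half pairs element i with element i + half.
        let (a, b) = (x[i], x[i + half]);
        out[i] = a * cos - b * sin;
        out[i + half] = b * cos + a * sin;
    }
    out
}
```

With `head_dim = 8` and `partial_rotary_factor = 0.25` only the first two dims rotate; the same arithmetic gives 64 rotated dims for `head_dim = 256`. Position 0 leaves the vector unchanged because every angle is zero.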

File diff suppressed because it is too large.

@@ -1,6 +1,8 @@
//! Harness registry — maps harness names to trait implementations.
pub mod arch;
pub mod candle;
pub mod tp;
use anyhow::Result;
use cortex_core::harness::{Harness, HarnessConfig, ModelInfo, ModelSpec};

@@ -0,0 +1,119 @@
//! `AllReduce` as a candle `CustomOp1` — the bridge between candle's
//! `Tensor` graph and `cudarc::nccl::Comm::all_reduce`.
//!
//! Ported from the canonical
//! `candle-examples/examples/llama_multiprocess/model.rs` pattern.
//! Row-parallel layers apply this op after their local matmul to sum
//! partial outputs across NCCL ranks.
//!
//! Available only under `--features cuda`; on CPU builds this module
//! is empty and row-parallel layers degenerate to local matmul only
//! (useful for compile-checking the model code; correctness requires
//! cuda).
//!
//! Thread-safety caveat: an NCCL communicator is only safe to use
//! from one thread at a time
//! (https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html).
//! `AllReduce` holds the `Comm` behind an `Arc` and only issues ops
//! against it from the dedicated `spawn_blocking` thread the inference
//! pipeline already uses for candle's forward passes.
#![cfg(feature = "cuda")]
use candle_core::backend::BackendStorage;
use candle_core::{CpuStorage, CudaStorage, CustomOp1, DType, Layout, Result, Shape};
use cudarc::nccl::{Comm, ReduceOp};
use half::{bf16, f16};
use std::sync::Arc;
/// Wraps an NCCL `Comm` so it can be plugged into a candle forward
/// graph as a custom op. Each row-parallel layer holds one of these.
pub struct AllReduce {
comm: Arc<Comm>,
}
// SAFETY: `Comm` contains a raw `ncclComm_t` pointer; NCCL's docs note
// that issuing ops against one comm from multiple threads concurrently
// is unsafe. We serialise via the single spawn_blocking thread that
// drives the model's forward pass. The Send/Sync impl is necessary
// because candle's CustomOp1 trait bounds require it; the correctness
// invariant is enforced at the call site, not the type level.
unsafe impl Send for AllReduce {}
unsafe impl Sync for AllReduce {}
impl AllReduce {
pub fn new(comm: Arc<Comm>) -> Self {
Self { comm }
}
pub fn comm(&self) -> &Arc<Comm> {
&self.comm
}
}
impl CustomOp1 for AllReduce {
fn name(&self) -> &'static str {
"neuron.tp.all_reduce"
}
fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
candle_core::bail!("AllReduce custom-op invoked on CPU storage; TP requires CUDA")
}
fn cuda_fwd(&self, s: &CudaStorage, l: &Layout) -> Result<(CudaStorage, Shape)> {
// Reject non-contiguous inputs explicitly. Silently copying them
// into a contiguous buffer would mask shape bugs (a TP layer
// feeding a strided activation into all_reduce is almost
// certainly a model construction error).
fn require_contiguous<T: cudarc::driver::DeviceRepr>(
slice: &cudarc::driver::CudaSlice<T>,
l: &Layout,
) -> Result<()> {
match l.contiguous_offsets() {
Some((0, n)) if n == slice.len() => Ok(()),
_ => candle_core::bail!(
"AllReduce input is non-contiguous: layout={:?}, slice_len={}",
l,
slice.len()
),
}
}
let elem_count = l.shape().elem_count();
let dev = s.device().clone();
let out = match s.dtype() {
DType::BF16 => {
let src = s.as_cuda_slice::<bf16>()?;
require_contiguous(src, l)?;
let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }?;
self.comm
.all_reduce(src, &mut dst, &ReduceOp::Sum)
.map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce bf16: {e:?}")))?;
CudaStorage::wrap_cuda_slice(dst, dev)
}
DType::F16 => {
let src = s.as_cuda_slice::<f16>()?;
require_contiguous(src, l)?;
let mut dst = unsafe { dev.alloc::<f16>(elem_count) }?;
self.comm
.all_reduce(src, &mut dst, &ReduceOp::Sum)
.map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce f16: {e:?}")))?;
CudaStorage::wrap_cuda_slice(dst, dev)
}
DType::F32 => {
let src = s.as_cuda_slice::<f32>()?;
require_contiguous(src, l)?;
let mut dst = unsafe { dev.alloc::<f32>(elem_count) }?;
self.comm
.all_reduce(src, &mut dst, &ReduceOp::Sum)
.map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce f32: {e:?}")))?;
CudaStorage::wrap_cuda_slice(dst, dev)
}
dtype => candle_core::bail!(
"AllReduce: unsupported dtype {dtype:?}; TP path expects bf16/f16/f32"
),
};
Ok((out, l.shape().clone()))
}
}
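
Why a row-parallel layer needs a sum all-reduce can be shown without NCCL. In this std-only sketch each rank multiplies its own column slice of the weight by the matching slice of the input, and the element-wise sum of the partial outputs equals the full matmul. `partial_matvec` and `all_reduce_sum` are hypothetical stand-ins; the real op runs on device buffers through `Comm::all_reduce`.

```rust
/// One rank's partial output: y_r = W[:, cols_r] * x[cols_r], where
/// cols_r is this rank's contiguous block of input columns.
fn partial_matvec(w: &[Vec<f32>], x: &[f32], rank: usize, world_size: usize) -> Vec<f32> {
    let per_rank = x.len() / world_size;
    let cols = rank * per_rank..(rank + 1) * per_rank;
    w.iter()
        .map(|row| cols.clone().map(|j| row[j] * x[j]).sum::<f32>())
        .collect()
}

/// Stand-in for the sum all-reduce: element-wise sum of every rank's
/// buffer. After the real op, each rank holds this same result.
fn all_reduce_sum(partials: &[Vec<f32>]) -> Vec<f32> {
    let mut out = vec![0.0_f32; partials[0].len()];
    for p in partials {
        for (o, v) in out.iter_mut().zip(p) {
            *o += v;
        }
    }
    out
}
```

For a 2x4 weight with rows `[1, 2, 3, 4]` and `[5, 6, 7, 8]` and an all-ones input split over two ranks, rank 0 produces `[3, 11]`, rank 1 produces `[7, 15]`, and the reduced result `[10, 26]` matches the unsharded product.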

@@ -0,0 +1,213 @@
//! Direct safetensors readers for fused-region weight tensors.
//!
//! Qwen3-Next's `in_proj_qkv` and `conv1d` weights are *fused* —
//! three regions stored sequentially along dim 0 (`[key_q, key_k,
//! value]`). The per-rank shard for each region has unequal size
//! (`key_dim/ws` vs `value_dim/ws`), so candle's `ShardedSafeTensors`
//! built-in `Shard { dim, rank, world_size }` (uniform split) doesn't
//! map to the right slices.
//!
//! The previous approach loaded the full fused tensor onto the device,
//! `narrow`ed the three regions, and `Tensor::cat(...).contiguous()`'d
//! the per-rank slice. That left ~100 MB of transient device memory
//! per linear-attention layer — 48 layers × 100 MB = ~4.8 GB of
//! allocator pressure during load, enough to trigger fragmentation
//! OOM on tight-VRAM consumer GPUs.
//!
//! This module reads the three per-rank byte ranges *directly from
//! the safetensors mmap* (host-side), concatenates them into a single
//! contiguous byte buffer, and uploads as one device allocation. No
//! full-tensor device materialisation.
use anyhow::{Context, Result, bail};
use candle_core::safetensors::MmapedSafetensors;
use candle_core::{DType, Device, Tensor};
/// Read a 2D fused-QKV tensor `[conv_dim, hidden_size]` and return
/// this rank's per-region slice as a `[per_rank_conv_dim, hidden_size]`
/// device tensor.
///
/// `tensor_name` must be the fully-qualified safetensors key (e.g.
/// `"model.language_model.layers.5.linear_attn.in_proj_qkv.weight"`).
#[allow(clippy::too_many_arguments)]
pub fn load_fused_qkv_2d(
mmap: &MmapedSafetensors,
tensor_name: &str,
hidden_size: usize,
key_dim: usize,
value_dim: usize,
rank: u32,
world_size: u32,
target_dtype: DType,
device: &Device,
) -> Result<Tensor> {
let ws = world_size as usize;
let r = rank as usize;
if !key_dim.is_multiple_of(ws) || !value_dim.is_multiple_of(ws) {
bail!(
"fused qkv shard: key_dim ({key_dim}) and value_dim ({value_dim}) \
must each be divisible by world_size ({ws})"
);
}
let per_rank_key = key_dim / ws;
let per_rank_value = value_dim / ws;
let per_rank_conv_dim = per_rank_key * 2 + per_rank_value;
let view = mmap
.get(tensor_name)
.with_context(|| format!("mmap.get('{tensor_name}') for fused qkv 2D"))?;
let view_dtype: DType = view
.dtype()
.try_into()
.with_context(|| format!("safetensors dtype unsupported for '{tensor_name}'"))?;
let shape = view.shape();
if shape.len() != 2 {
bail!(
"fused qkv tensor '{tensor_name}' has shape {shape:?}, expected 2D \
[conv_dim, hidden_size]"
);
}
let conv_dim = key_dim * 2 + value_dim;
if shape[0] != conv_dim || shape[1] != hidden_size {
bail!(
"fused qkv tensor '{tensor_name}' shape {shape:?} \
doesn't match expected [{conv_dim}, {hidden_size}]"
);
}
let q_bytes = slice_dim0_bytes(&view, r * per_rank_key, per_rank_key, tensor_name, "q")?;
let k_bytes = slice_dim0_bytes(
&view,
key_dim + r * per_rank_key,
per_rank_key,
tensor_name,
"k",
)?;
let v_bytes = slice_dim0_bytes(
&view,
2 * key_dim + r * per_rank_value,
per_rank_value,
tensor_name,
"v",
)?;
let mut bytes = Vec::with_capacity(q_bytes.len() + k_bytes.len() + v_bytes.len());
bytes.extend_from_slice(&q_bytes);
bytes.extend_from_slice(&k_bytes);
bytes.extend_from_slice(&v_bytes);
let tensor = Tensor::from_raw_buffer(
&bytes,
view_dtype,
&[per_rank_conv_dim, hidden_size],
device,
)
.with_context(|| format!("Tensor::from_raw_buffer for per-rank fused qkv '{tensor_name}'"))?;
tensor
.to_dtype(target_dtype)
.with_context(|| format!("cast '{tensor_name}' to {target_dtype:?}"))
}
/// Read a 3D fused-QKV tensor `[conv_dim, 1, kernel_size]` (the
/// depthwise conv1d weight) and return this rank's per-region slice
/// as a `[per_rank_conv_dim, 1, kernel_size]` device tensor.
#[allow(clippy::too_many_arguments)]
pub fn load_fused_qkv_3d(
mmap: &MmapedSafetensors,
tensor_name: &str,
mid: usize,
kernel_size: usize,
key_dim: usize,
value_dim: usize,
rank: u32,
world_size: u32,
target_dtype: DType,
device: &Device,
) -> Result<Tensor> {
let ws = world_size as usize;
let r = rank as usize;
if !key_dim.is_multiple_of(ws) || !value_dim.is_multiple_of(ws) {
bail!(
"fused conv shard: key_dim ({key_dim}) and value_dim ({value_dim}) \
must each be divisible by world_size ({ws})"
);
}
let per_rank_key = key_dim / ws;
let per_rank_value = value_dim / ws;
let per_rank_conv_dim = per_rank_key * 2 + per_rank_value;
let view = mmap
.get(tensor_name)
.with_context(|| format!("mmap.get('{tensor_name}') for fused qkv 3D"))?;
let view_dtype: DType = view
.dtype()
.try_into()
.with_context(|| format!("safetensors dtype unsupported for '{tensor_name}'"))?;
let shape = view.shape();
if shape.len() != 3 {
bail!(
"fused conv tensor '{tensor_name}' has shape {shape:?}, expected 3D \
[conv_dim, mid, kernel_size]"
);
}
let conv_dim = key_dim * 2 + value_dim;
if shape[0] != conv_dim || shape[1] != mid || shape[2] != kernel_size {
bail!(
"fused conv tensor '{tensor_name}' shape {shape:?} \
doesn't match expected [{conv_dim}, {mid}, {kernel_size}]"
);
}
let q_bytes = slice_dim0_bytes(&view, r * per_rank_key, per_rank_key, tensor_name, "q")?;
let k_bytes = slice_dim0_bytes(
&view,
key_dim + r * per_rank_key,
per_rank_key,
tensor_name,
"k",
)?;
let v_bytes = slice_dim0_bytes(
&view,
2 * key_dim + r * per_rank_value,
per_rank_value,
tensor_name,
"v",
)?;
let mut bytes = Vec::with_capacity(q_bytes.len() + k_bytes.len() + v_bytes.len());
bytes.extend_from_slice(&q_bytes);
bytes.extend_from_slice(&k_bytes);
bytes.extend_from_slice(&v_bytes);
let tensor = Tensor::from_raw_buffer(
&bytes,
view_dtype,
&[per_rank_conv_dim, mid, kernel_size],
device,
)
.with_context(|| format!("Tensor::from_raw_buffer for per-rank fused conv '{tensor_name}'"))?;
tensor
.to_dtype(target_dtype)
.with_context(|| format!("cast '{tensor_name}' to {target_dtype:?}"))
}
/// Read `len` consecutive rows along dim 0 starting at `start` from
/// the safetensors view, returning the raw bytes. Wraps the same
/// `view.slice(start..stop)` machinery that candle's
/// `ShardedSafeTensors::get` uses internally.
fn slice_dim0_bytes(
view: &safetensors::tensor::TensorView<'_>,
start: usize,
len: usize,
tensor_name: &str,
region: &str,
) -> Result<Vec<u8>> {
use safetensors::slice::IndexOp;
let stop = start + len;
let iter = view.slice(start..stop).map_err(|e| {
anyhow::anyhow!("slice '{tensor_name}' region {region} ({start}..{stop}): {e:?}")
})?;
Ok(iter.into_iter().flatten().copied().collect())
}
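
The row offsets that both loaders pass to `slice_dim0_bytes` follow one pattern, sketched here with plain integers. `per_rank_regions` is a hypothetical helper that returns `(start, len)` along dim 0 for the q, k and v regions of a fused tensor laid out as `[key_dim | key_dim | value_dim]`; it returns `None` where the real loaders bail.

```rust
/// Row ranges (start, len) along dim 0 that `rank` reads from a fused
/// `[q | k | v]` tensor with region sizes `[key_dim, key_dim, value_dim]`.
/// Hypothetical helper for illustration only.
fn per_rank_regions(
    key_dim: usize,
    value_dim: usize,
    rank: usize,
    world_size: usize,
) -> Option<[(usize, usize); 3]> {
    if world_size == 0
        || rank >= world_size
        || key_dim % world_size != 0
        || value_dim % world_size != 0
    {
        return None;
    }
    let per_rank_key = key_dim / world_size;
    let per_rank_value = value_dim / world_size;
    Some([
        (rank * per_rank_key, per_rank_key),                   // q region
        (key_dim + rank * per_rank_key, per_rank_key),         // k region
        (2 * key_dim + rank * per_rank_value, per_rank_value), // v region
    ])
}
```

With `key_dim = 8`, `value_dim = 16` and two ranks, rank 1 reads rows 4..8, 12..16 and 24..32. A uniform split of the 32 rows would instead hand rank 1 rows 16..32, which lie entirely in the v region, so it cannot express this layout.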

@@ -0,0 +1,791 @@
//! Tensor-parallel inference plumbing.
//!
//! The leader process (the neuron daemon proper) drives one
//! subprocess per non-zero NCCL rank: `tokio::process::Command` on
//! `/proc/self/exe --worker --rank R --tp-size N --cuda-device D`.
//! It talks to each over a newline-delimited JSON RPC channel on
//! the worker's stdin/stdout (see `rpc.rs`).
//!
//! Sub-staging:
//!
//! - **7a-i (this commit):** process lifecycle. `WorkerPool::spawn`
//! forks N workers; `ping` round-trips every worker to confirm
//! they're alive; `shutdown` cleanly drains and reaps. `Init` /
//! `NcclSanityCheck` are stubbed.
//! - **7a-ii:** real NCCL `Comm` setup via `Init`, sanity check via
//! `NcclSanityCheck`. CUDA-gated.
//! - **7b:** TP-aware Qwen3 inference dispatched through the pool.
//! - **7c:** crash detection, streaming SSE, graceful unload.
pub mod all_reduce;
pub mod fused_load;
pub mod nccl_state;
pub mod rpc;
pub mod tp_linear;
pub mod tp_qwen3;
pub mod tp_qwen3_5;
pub mod worker;
use anyhow::{Context, Result};
use std::path::{Path, PathBuf};
use std::process::Stdio;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, Lines};
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
use rpc::{WorkerRequest, WorkerResponse};
/// Leader-side handle for any TP-loaded model. The pool's
/// `load_dense_shard` dispatches on `config.json#/model_type` to build
/// the right variant; downstream callers (the harness's
/// `chat_completion_tp` path, `generate_step`, `clear_kv_cache`,
/// `unload_model`) all hold this enum and let the variant dispatch
/// determine the concrete forward.
///
/// Variants gated on `cuda` because the underlying TP models hold
/// `Arc<cudarc::nccl::Comm>` references — irrelevant on CPU builds.
#[cfg(feature = "cuda")]
pub enum TpLeaderModel {
Qwen3(tp_qwen3::TpQwen3ForCausalLM),
Qwen3_5(tp_qwen3_5::TpQwen3_5ForCausalLM),
}
#[cfg(feature = "cuda")]
impl TpLeaderModel {
pub fn forward(
&mut self,
input: &candle_core::Tensor,
offset: usize,
) -> candle_core::Result<candle_core::Tensor> {
match self {
TpLeaderModel::Qwen3(m) => m.forward(input, offset),
TpLeaderModel::Qwen3_5(m) => m.forward(input, offset),
}
}
pub fn clear_kv_cache(&mut self) {
match self {
TpLeaderModel::Qwen3(m) => m.clear_kv_cache(),
TpLeaderModel::Qwen3_5(m) => m.clear_kv_cache(),
}
}
pub fn device(&self) -> &candle_core::Device {
match self {
TpLeaderModel::Qwen3(m) => m.device(),
TpLeaderModel::Qwen3_5(m) => m.device(),
}
}
}
/// One worker subprocess plus its bidirectional stdio handles.
struct Worker {
rank: u32,
/// Captured so the leader can log "spawned rank N on device M" and
/// future stages can re-issue Init after a CUDA reset. Unused in
/// the Stage 7a-i RPC paths themselves.
#[allow(dead_code)]
cuda_device: u32,
child: Child,
stdin: ChildStdin,
stdout: Lines<BufReader<ChildStdout>>,
}
impl Worker {
/// Send a request and wait for the response. Used for sequenced
/// ops like `Ping` / `Shutdown` where the caller doesn't need to
/// overlap the worker's execution with the leader's.
async fn request(&mut self, req: &WorkerRequest) -> Result<WorkerResponse> {
self.send_only(req).await?;
self.recv_only().await
}
/// Write a request without awaiting its response. Pair with
/// `recv_only` from the caller when leader and worker need to do
/// work concurrently — e.g. during `Init`, where the leader
/// itself calls `Comm::from_rank` on rank 0 in parallel with the
/// workers, then collects `InitOk` after NCCL completes.
async fn send_only(&mut self, req: &WorkerRequest) -> Result<()> {
let mut line = serde_json::to_string(req).context("serialise WorkerRequest")?;
line.push('\n');
self.stdin
.write_all(line.as_bytes())
.await
.with_context(|| format!("write request to rank {}", self.rank))?;
self.stdin
.flush()
.await
.with_context(|| format!("flush stdin to rank {}", self.rank))?;
Ok(())
}
async fn recv_only(&mut self) -> Result<WorkerResponse> {
let reply = self
.stdout
.next_line()
.await
.with_context(|| format!("read reply from rank {}", self.rank))?
.ok_or_else(|| anyhow::anyhow!("rank {} stdout closed before reply", self.rank))?;
serde_json::from_str(&reply)
.with_context(|| format!("parse reply from rank {}: {reply:?}", self.rank))
}
}
/// Drain one response from every worker, classifying each via the
/// supplied checker. Always reads from every worker — even if some
/// fail — so the next call's recv doesn't pick up stale responses
/// from this one (pipe-poisoning was the cause of the
/// "ClearKvCache: expected KvCacheCleared, got GenerateStepOk" class
/// of bugs).
///
/// Returns one `rank N <detail>` string for each worker that replied
/// with an error, replied with an unexpected variant, or failed to
/// respond. The caller decides how to combine these with the leader's
/// outcome.
async fn drain_workers(
workers: &mut [Worker],
mut check: impl FnMut(WorkerResponse) -> std::result::Result<(), String>,
) -> Vec<String> {
let mut errs = Vec::new();
for w in workers {
match w.recv_only().await {
Ok(resp) => {
if let Err(detail) = check(resp) {
errs.push(format!("rank {} {detail}", w.rank));
}
}
Err(e) => errs.push(format!("rank {} recv: {e:#}", w.rank)),
}
}
errs
}
/// Combine a leader's `Result<Result<T>>` (the typical
/// `spawn_blocking → JoinHandle<Result<T>>` shape) with the worker
/// drain results into a single `Result<T>`. Leader failures take
/// precedence in the error message but worker errors get appended so
/// the operator sees both halves.
#[cfg(feature = "cuda")]
fn combine_leader_workers<T>(
leader: Result<Result<T>>,
worker_errors: Vec<String>,
op: &str,
) -> Result<T> {
match leader {
Ok(Ok(value)) => {
if worker_errors.is_empty() {
Ok(value)
} else {
anyhow::bail!(
"{op}: leader succeeded but workers failed: {}",
worker_errors.join("; ")
)
}
}
Ok(Err(e)) => {
if worker_errors.is_empty() {
Err(e.context(format!("{op}: leader forward failed")))
} else {
Err(e.context(format!(
"{op}: leader forward failed and workers also failed: {}",
worker_errors.join("; ")
)))
}
}
Err(panic_err) => {
if worker_errors.is_empty() {
Err(panic_err)
} else {
Err(panic_err.context(format!(
"{op}: leader task panicked and workers failed: {}",
worker_errors.join("; ")
)))
}
}
}
}
/// A live pool of worker subprocesses. Owns the `Child` handles so
/// dropping the pool kills the children; explicit `shutdown()` is
/// the graceful path.
pub struct WorkerPool {
world_size: u32,
workers: Vec<Worker>,
/// Path to the neuron binary used to launch workers.
#[allow(dead_code)]
exe: PathBuf,
/// Leader's own NCCL rank-0 state. Defaults to empty; populated by
/// `init_nccl()`. Held here so the leader can participate in
/// collectives as rank 0 without spawning a subprocess for that rank.
leader_nccl: nccl_state::NcclState,
}
impl WorkerPool {
/// Spawn `world_size - 1` worker subprocesses. Rank 0 is the
/// leader (in-process) and is *not* spawned here — the leader
/// holds rank 0's NCCL Comm and shard in its own address space.
///
/// `binary` is the path to the neuron executable to run for each
/// worker (production passes `/proc/self/exe`; tests pass the
/// sibling-binary path from `env!("CARGO_BIN_EXE_neuron")`).
/// `cuda_devices` is one entry per rank including rank 0. Worker
/// `i` (rank `i`) gets `cuda_devices[i]` as its `--cuda-device`.
pub async fn spawn(binary: &Path, world_size: u32, cuda_devices: &[u32]) -> Result<Self> {
if world_size < 2 {
anyhow::bail!(
"WorkerPool::spawn called with world_size={world_size}; \
use the single-process path for world_size < 2"
);
}
if cuda_devices.len() as u32 != world_size {
anyhow::bail!(
"expected {world_size} cuda_devices entries, got {}",
cuda_devices.len()
);
}
let exe = binary.to_path_buf();
let mut workers = Vec::with_capacity(world_size as usize - 1);
// Rank 0 stays in-process. Spawn ranks 1..world_size.
for rank in 1..world_size {
let cuda_device = cuda_devices[rank as usize];
let mut cmd = Command::new(&exe);
cmd.arg("--worker")
.arg("--rank")
.arg(rank.to_string())
.arg("--tp-size")
.arg(world_size.to_string())
.arg("--cuda-device")
.arg(cuda_device.to_string())
.stdin(Stdio::piped())
.stdout(Stdio::piped())
// Inherit stderr so worker tracing surfaces alongside
// the leader's journalctl stream.
.stderr(Stdio::inherit())
.kill_on_drop(true);
let mut child = cmd
.spawn()
.with_context(|| format!("spawn worker rank {rank}"))?;
let stdin = child
.stdin
.take()
.ok_or_else(|| anyhow::anyhow!("rank {rank}: no stdin handle"))?;
let stdout = child
.stdout
.take()
.ok_or_else(|| anyhow::anyhow!("rank {rank}: no stdout handle"))?;
let stdout = BufReader::new(stdout).lines();
workers.push(Worker {
rank,
cuda_device,
child,
stdin,
stdout,
});
tracing::info!(rank, cuda_device, "spawned tp worker");
}
Ok(Self {
world_size,
workers,
exe,
leader_nccl: nccl_state::NcclState::new(),
})
}
/// Establish the NCCL communicator across the leader (rank 0) and
/// every worker subprocess. Rendezvous is via a freshly-generated
/// `Id` broadcast over the RPC stream; the actual handshake blocks
/// inside `Comm::from_rank` until all `world_size` ranks check in.
///
/// `leader_cuda_device` is the CUDA device the leader binds rank 0
/// to — typically the first entry of the `cuda_devices` slice
/// originally passed to `spawn()`.
///
/// On the non-cuda build this immediately fails because the leader
/// can't generate an `Id` without libnccl. The same call works in
/// the worker path (returning a no-cuda error response) so the
/// failure surface is uniform.
pub async fn init_nccl(&mut self, leader_cuda_device: u32) -> Result<()> {
let comm_id = nccl_state::generate_comm_id_hex()
.map_err(|m| anyhow::anyhow!("generate NCCL id: {m}"))?;
// 1. Write Init to every worker's stdin without awaiting the
// response. Workers will parse and call Comm::from_rank
// concurrently with the leader below.
for w in &mut self.workers {
let req = WorkerRequest::Init {
comm_id: comm_id.clone(),
};
w.send_only(&req).await?;
}
// 2. Leader rank 0 calls Comm::from_rank on its own device.
// Runs on spawn_blocking because NCCL's init blocks until
// every rank has called in — that's exactly the workers
// above. The leader's NcclState is moved through the
// blocking task and returned to the pool.
let leader_cfg = worker::WorkerConfig {
rank: 0,
world_size: self.world_size,
cuda_device: leader_cuda_device,
};
let comm_id_for_leader = comm_id.clone();
// Swap out the leader's NcclState into a fresh empty one so we
// can move it into spawn_blocking; restore after the task
// returns. (NcclState isn't Clone — it owns a real NCCL Comm.)
let mut leader_state = std::mem::take(&mut self.leader_nccl);
let (returned_state, leader_resp) = tokio::task::spawn_blocking(move || {
let resp = leader_state.init(leader_cfg, &comm_id_for_leader);
(leader_state, resp)
})
.await
.context("leader NCCL init task panicked")?;
self.leader_nccl = returned_state;
match leader_resp {
rpc::WorkerResponse::InitOk => {}
rpc::WorkerResponse::Error { kind, message } => {
anyhow::bail!("leader rank 0 init failed [{kind}]: {message}");
}
other => anyhow::bail!("leader rank 0 init: unexpected {other:?}"),
}
// 3. Read InitOk from each worker. By now every worker has
// completed its Comm::from_rank call (NCCL released them
// when the leader joined the handshake) and is writing its
// response.
for w in &mut self.workers {
let resp = w.recv_only().await?;
match &resp {
rpc::WorkerResponse::InitOk => {}
rpc::WorkerResponse::Error { kind, message } => {
anyhow::bail!("worker rank {} init failed [{kind}]: {message}", w.rank);
}
other => anyhow::bail!(
"worker rank {} init: expected InitOk, got {other:?}",
w.rank
),
}
}
tracing::info!(
world_size = self.world_size,
"NCCL communicator established across all ranks"
);
Ok(())
}
/// Validate the NCCL communicator: every rank `all_reduce`s a
/// sentinel `1u32` with `ReduceOp::Sum`; the expected total is
/// `world_size`. Confirms the handshake is live, not just
/// configured.
///
/// Must be called after `init_nccl()`; before that the leader has
/// no Comm and the workers reply with `nccl_not_initialised`.
pub async fn nccl_sanity_check(&mut self) -> Result<()> {
// 1. Trigger the all_reduce on every worker (write-only).
for w in &mut self.workers {
w.send_only(&WorkerRequest::NcclSanityCheck).await?;
}
// 2. Leader's own all_reduce, in spawn_blocking. NCCL operations
// block until every rank participates.
let mut leader_state = std::mem::take(&mut self.leader_nccl);
let (returned_state, leader_resp) = tokio::task::spawn_blocking(move || {
let resp = leader_state.sanity_check();
(leader_state, resp)
})
.await
.context("leader NCCL sanity task panicked")?;
self.leader_nccl = returned_state;
let expected = self.world_size;
let leader_sum = match leader_resp {
rpc::WorkerResponse::NcclSanityResult { observed_sum } => observed_sum,
rpc::WorkerResponse::Error { kind, message } => {
anyhow::bail!("leader rank 0 sanity failed [{kind}]: {message}");
}
other => anyhow::bail!("leader rank 0 sanity: unexpected {other:?}"),
};
if leader_sum != expected {
anyhow::bail!("leader observed_sum={leader_sum}, expected {expected}");
}
// 3. Read sanity result from each worker. All must match
// world_size — anything else means the collective didn't
// complete consistently across ranks.
for w in &mut self.workers {
let resp = w.recv_only().await?;
match resp {
rpc::WorkerResponse::NcclSanityResult { observed_sum }
if observed_sum == expected => {}
rpc::WorkerResponse::NcclSanityResult { observed_sum } => {
anyhow::bail!(
"worker rank {} observed_sum={observed_sum}, expected {expected}",
w.rank
);
}
rpc::WorkerResponse::Error { kind, message } => {
anyhow::bail!("worker rank {} sanity failed [{kind}]: {message}", w.rank);
}
other => anyhow::bail!("worker rank {} sanity: unexpected {other:?}", w.rank),
}
}
tracing::info!(
world_size = expected,
"NCCL sanity check OK across all ranks"
);
Ok(())
}
/// Ping every worker and return their Pong payloads in rank order.
/// Useful right after `spawn` to confirm the lifecycle plumbing is
/// intact before kicking off any heavier work.
pub async fn ping_all(&mut self) -> Result<Vec<WorkerResponse>> {
let mut out = Vec::with_capacity(self.workers.len());
for w in &mut self.workers {
let resp = w.request(&WorkerRequest::Ping).await?;
match &resp {
WorkerResponse::Pong { rank, .. } if *rank == w.rank => {}
WorkerResponse::Pong { rank, .. } => {
anyhow::bail!("rank mismatch: expected {}, got {rank}", w.rank);
}
other => anyhow::bail!("expected Pong from rank {}, got {other:?}", w.rank),
}
out.push(resp);
}
Ok(out)
}
/// Load each rank's shard of a dense Qwen3-family model (`qwen3` or
/// `qwen3_5`).
///
/// The leader builds rank 0's shard as a `TpLeaderModel` and returns
/// it in an `Arc<Mutex<_>>`; workers build their rank-local shards in
/// their own address spaces and confirm via `LoadDenseShardOk`. All
/// ranks see the same `safetensors_paths`; `ShardedVarBuilder` slices
/// each tensor by rank at materialisation time, so the per-rank VRAM
/// footprint is roughly `1/world_size` of the full model (plus the
/// replicated embedding/norm/lm_head).
///
/// `leader_device` is the candle `Device` the leader's shard lives
/// on, typically `Device::new_cuda(leader_cuda_device)` matching
/// the same index passed to `init_nccl`. `dtype` is the on-device
/// element type; bf16 is the canonical Qwen3 distribution dtype.
/// `quant` optionally names an in-situ quantization format; on the
/// leader it is only consulted on the `qwen3_5` path.
///
/// `init_nccl` must have completed first. Bails if the leader's
/// NCCL comm isn't set up yet.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub async fn load_dense_shard(
&mut self,
model_id: &str,
config_json: &str,
safetensors_paths: &[std::path::PathBuf],
leader_device: &candle_core::Device,
dtype: candle_core::DType,
quant: Option<String>,
) -> Result<std::sync::Arc<tokio::sync::Mutex<TpLeaderModel>>> {
use candle_nn::var_builder::ShardedSafeTensors;
use std::sync::Arc;
use tokio::sync::Mutex;
// Wrap the comm in SendComm immediately so it stays Send across
// the await points in this method — bare Arc<Comm> would
// poison the async fn's Send bound (Comm's raw NCCL pointer is
// !Send). The wrapper's safety contract is satisfied by the
// pool's outer Mutex serialising callers + the spawn_blocking
// thread being the only place ops are issued.
let leader_comm =
nccl_state::SendComm(self.leader_nccl.comm().ok_or_else(|| {
anyhow::anyhow!("leader NCCL not initialised; call init_nccl first")
})?);
let world_size = self.world_size;
let safetensors_str: Vec<String> = safetensors_paths
.iter()
.map(|p| p.to_string_lossy().into_owned())
.collect();
// 1. Fan out the LoadDenseShard request to every worker without
// awaiting their replies — they'll build their shards in
// parallel with the leader below.
for w in &mut self.workers {
w.send_only(&WorkerRequest::LoadDenseShard {
model_id: model_id.to_string(),
config_json: config_json.to_string(),
safetensors_paths: safetensors_str.clone(),
quant: quant.clone(),
})
.await?;
}
// 2. Build rank 0's shard on the leader. Dispatch on model_type
// — for `qwen3` we build a `TpQwen3ForCausalLM`, for
// `qwen3_5` (Qwen3-Next, Qwen3.6's architecture) we build
// `TpQwen3_5ForCausalLM`. Both end up wrapped in the
// `TpLeaderModel` enum so downstream callers don't care.
let model_type = serde_json::from_str::<serde_json::Value>(config_json)
.ok()
.as_ref()
.and_then(|v| v.get("model_type"))
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let paths_for_leader: Vec<std::path::PathBuf> = safetensors_paths.to_vec();
let device_for_leader = leader_device.clone();
let comm_for_leader = leader_comm;
let model_id_for_log = model_id.to_string();
let config_json_for_leader = config_json.to_string();
let quant_for_leader = quant.clone();
let leader_model = tokio::task::spawn_blocking(move || -> Result<TpLeaderModel> {
// SAFETY: same invariant as the single-GPU dense path —
// the HF cache files are treated as immutable while the
// mmap is held.
let vb = unsafe {
ShardedSafeTensors::var_builder(&paths_for_leader, dtype, &device_for_leader)
.context("build ShardedVarBuilder over safetensors")?
};
// SAFETY: as above — the HF cache files are immutable.
let mmap = unsafe {
candle_core::safetensors::MmapedSafetensors::multi(&paths_for_leader)
.context("build MmapedSafetensors for leader load")?
};
let comm = comm_for_leader.into_inner();
let loaded = match model_type.as_str() {
"qwen3" => {
let cfg: super::tp::tp_qwen3::Config = serde_json::from_str(&config_json_for_leader)
.context("parse Qwen3 Config JSON for leader load")?;
TpLeaderModel::Qwen3(super::tp::tp_qwen3::TpQwen3ForCausalLM::load(
&cfg, &vb, 0, world_size, comm,
)?)
}
"qwen3_5" => {
let cfg: super::tp::tp_qwen3_5::Config =
serde_json::from_str(&config_json_for_leader)
.context("parse Qwen3-Next Config JSON for leader load")?;
let quant_dtype =
super::tp::worker::parse_quant_string(quant_for_leader.as_deref())?;
TpLeaderModel::Qwen3_5(super::tp::tp_qwen3_5::TpQwen3_5ForCausalLM::load(
cfg,
&vb,
&mmap,
0,
world_size,
comm,
quant_dtype,
)?)
}
other => anyhow::bail!(
"TP dispatch: unsupported model_type '{other}' on leader (supported: qwen3, qwen3_5)"
),
};
tracing::info!(rank = 0, model = %model_id_for_log, model_type = %model_type, "loaded TP shard (leader)");
Ok(loaded)
})
.await
.context("leader load task panicked")??;
// 3. Collect worker confirmations. Anything other than
// LoadDenseShardOk aborts the whole load — the leader's
// already-loaded shard drops when this fn returns Err.
for w in &mut self.workers {
let resp = w.recv_only().await?;
match resp {
WorkerResponse::LoadDenseShardOk => {}
WorkerResponse::Error { kind, message } => {
anyhow::bail!("worker rank {} LoadDenseShard [{kind}]: {message}", w.rank)
}
other => anyhow::bail!(
"worker rank {} LoadDenseShard: expected LoadDenseShardOk, got {other:?}",
w.rank
),
}
}
Ok(Arc::new(Mutex::new(leader_model)))
}
/// Run one forward step across every rank. The leader's forward
/// returns the last-position logits as a candle Tensor on the
/// leader's device; the caller does sampling out-of-band. Workers
/// run their own forwards (the AllReduce inside row-parallel layers
/// is what lets the leader's collective complete) and reply with
/// `GenerateStepOk` — they do not ship logits over the wire.
///
/// `tokens` is the input for this step (prompt for prefill, the
/// previously-sampled token for decode). `offset` is the KV-cache
/// position before this step.
#[cfg(feature = "cuda")]
pub async fn generate_step(
&mut self,
model_id: &str,
leader_model: std::sync::Arc<tokio::sync::Mutex<TpLeaderModel>>,
tokens: Vec<u32>,
offset: usize,
) -> Result<candle_core::Tensor> {
let step_start = std::time::Instant::now();
let tokens_len = tokens.len();
tracing::debug!(
model = %model_id,
tokens = tokens_len,
offset,
"WorkerPool::generate_step: fan-out"
);
// 1. Fan-out to workers.
for w in &mut self.workers {
w.send_only(&WorkerRequest::GenerateStep {
model_id: model_id.to_string(),
tokens: tokens.clone(),
offset,
})
.await?;
}
// 2. Leader's forward in spawn_blocking. The AllReduce CustomOps
// inside the row-parallel layers block until every worker's
// forward issues the matching collective.
let leader_start = std::time::Instant::now();
let leader_result = tokio::task::spawn_blocking(move || -> Result<candle_core::Tensor> {
let mut model = leader_model.blocking_lock();
let device = model.device().clone();
let input = candle_core::Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
// ForCausalLM::forward returns [B, 1, V] — squeeze both
// leading dims to the rank-1 vocab logits the sampler wants.
let logits = model.forward(&input, offset)?.squeeze(0)?.squeeze(0)?;
Ok(logits)
})
.await
.context("leader forward task panicked");
let leader_ok = matches!(leader_result, Ok(Ok(_)));
tracing::debug!(
model = %model_id,
tokens = tokens_len,
leader_ms = leader_start.elapsed().as_millis(),
leader_ok,
"WorkerPool::generate_step: leader forward returned"
);
// 3. ALWAYS drain worker responses, regardless of whether the
// leader succeeded. Skipping this on the leader's error
// path leaves stale GenerateStepOk replies in the worker
// pipes that poison the NEXT request's recv (was seeing
// "ClearKvCache: expected KvCacheCleared, got
// GenerateStepOk" the call after any forward-time failure).
let drain_start = std::time::Instant::now();
let worker_errors = drain_workers(&mut self.workers, |r| match r {
WorkerResponse::GenerateStepOk => Ok(()),
WorkerResponse::Error { kind, message } => Err(format!("[{kind}]: {message}")),
other => Err(format!("expected GenerateStepOk, got {other:?}")),
})
.await;
tracing::debug!(
model = %model_id,
drain_ms = drain_start.elapsed().as_millis(),
errors = worker_errors.len(),
total_ms = step_start.elapsed().as_millis(),
"WorkerPool::generate_step: workers drained"
);
combine_leader_workers(leader_result, worker_errors, "GenerateStep")
}
/// Reset the KV cache for `model_id` on every rank. Called at the
/// start of every inference so a fresh request doesn't attend over
/// the previous one's tokens.
pub async fn clear_kv_cache(
&mut self,
model_id: &str,
#[cfg(feature = "cuda")] leader_model: std::sync::Arc<tokio::sync::Mutex<TpLeaderModel>>,
) -> Result<()> {
let start = std::time::Instant::now();
tracing::debug!(model = %model_id, "WorkerPool::clear_kv_cache: fan-out");
for w in &mut self.workers {
w.send_only(&WorkerRequest::ClearKvCache {
model_id: model_id.to_string(),
})
.await?;
}
#[cfg(feature = "cuda")]
{
let mut m = leader_model.lock().await;
m.clear_kv_cache();
}
// Drain workers — same rationale as `generate_step`. The
// leader's clear_kv_cache is in-process and infallible, but we
// still always drain so an error on one worker doesn't leave
// pending responses for the others.
let worker_errors = drain_workers(&mut self.workers, |r| match r {
WorkerResponse::KvCacheCleared => Ok(()),
WorkerResponse::Error { kind, message } => Err(format!("[{kind}]: {message}")),
other => Err(format!("expected KvCacheCleared, got {other:?}")),
})
.await;
tracing::debug!(
model = %model_id,
elapsed_ms = start.elapsed().as_millis(),
errors = worker_errors.len(),
"WorkerPool::clear_kv_cache: workers drained"
);
if !worker_errors.is_empty() {
anyhow::bail!("ClearKvCache: {}", worker_errors.join("; "));
}
Ok(())
}
/// Drop this model's shards on every rank. The leader's shard is
/// expected to have been dropped by the caller (its `Arc` was held
/// in the TpLoadedModel and goes away when that's removed).
pub async fn unload_model(&mut self, model_id: &str) -> Result<()> {
for w in &mut self.workers {
w.send_only(&WorkerRequest::UnloadModel {
model_id: model_id.to_string(),
})
.await?;
}
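// Always drain every worker, as in `clear_kv_cache`: bailing on the
// first bad reply would leave the remaining replies queued in the
// pipes for the next request to misread.
let worker_errors = drain_workers(&mut self.workers, |r| match r {
WorkerResponse::Unloaded => Ok(()),
WorkerResponse::Error { kind, message } => Err(format!("[{kind}]: {message}")),
other => Err(format!("expected Unloaded, got {other:?}")),
})
.await;
if !worker_errors.is_empty() {
anyhow::bail!("UnloadModel: {}", worker_errors.join("; "));
}
Ok(())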
}
/// Send `Shutdown` to every worker, await each `Bye`, and reap the
/// children. Best-effort — individual worker failures are logged
/// but don't abort the rest of the sweep.
pub async fn shutdown(mut self) -> Result<()> {
for w in &mut self.workers {
match w.request(&WorkerRequest::Shutdown).await {
Ok(WorkerResponse::Bye) => {}
Ok(other) => tracing::warn!(
rank = w.rank,
response = ?other,
"expected Bye on shutdown"
),
Err(e) => tracing::warn!(rank = w.rank, error = %e, "shutdown request failed"),
}
}
for w in &mut self.workers {
match w.child.wait().await {
Ok(status) => tracing::info!(rank = w.rank, %status, "worker exited"),
Err(e) => tracing::warn!(rank = w.rank, error = %e, "wait on worker failed"),
}
}
Ok(())
}
pub fn world_size(&self) -> u32 {
self.world_size
}
pub fn binary_path(&self) -> &PathBuf {
&self.exe
}
}
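
The always-drain rule used by `generate_step` and `clear_kv_cache` above can be illustrated with a small std-only model. This is a sketch under stated assumptions, not the pool's real code: `FakeWorker`, `round`, and `next_round_errors` are hypothetical names, and an in-memory queue stands in for the worker's stdout pipe. It shows how returning early after a leader failure leaves a stale reply queued, which the following request then misreads.

```rust
use std::collections::VecDeque;

/// Hypothetical stand-in for one worker's stdout pipe: replies queue up
/// in order until the leader reads them.
struct FakeWorker {
    rank: u32,
    pending: VecDeque<&'static str>,
}

impl FakeWorker {
    /// The fake worker answers every request immediately.
    fn send(&mut self, op: &str) {
        let reply = match op {
            "generate_step" => "generate_step_ok",
            "clear_kv_cache" => "kv_cache_cleared",
            _ => "error",
        };
        self.pending.push_back(reply);
    }

    fn recv(&mut self) -> Option<&'static str> {
        self.pending.pop_front()
    }
}

/// One fan-out / drain round. `always_drain` controls whether worker
/// replies are still read when the leader's own half of the op failed.
fn round(
    workers: &mut [FakeWorker],
    op: &str,
    expect: &str,
    leader_ok: bool,
    always_drain: bool,
) -> Vec<String> {
    for w in workers.iter_mut() {
        w.send(op);
    }
    if !leader_ok && !always_drain {
        // Early return: every worker's reply stays queued in its pipe.
        return vec!["leader failed".to_string()];
    }
    let mut errs = Vec::new();
    if !leader_ok {
        errs.push("leader failed".to_string());
    }
    for w in workers.iter_mut() {
        match w.recv() {
            Some(r) if r == expect => {}
            Some(r) => errs.push(format!("rank {} expected {expect}, got {r}", w.rank)),
            None => errs.push(format!("rank {} recv: pipe closed", w.rank)),
        }
    }
    errs
}

/// Errors seen by a ClearKvCache round that follows a failed GenerateStep.
fn next_round_errors(always_drain: bool) -> Vec<String> {
    let mut workers: Vec<FakeWorker> = (1..3)
        .map(|rank| FakeWorker { rank, pending: VecDeque::new() })
        .collect();
    let _ = round(&mut workers, "generate_step", "generate_step_ok", false, always_drain);
    round(&mut workers, "clear_kv_cache", "kv_cache_cleared", true, always_drain)
}

fn main() {
    println!("early return: {:?}", next_round_errors(false));
    println!("always drain: {:?}", next_round_errors(true));
}
```

With the early return, both fake workers report a stale `generate_step_ok` on the next round; with the drain in place the next round is clean.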


@@ -0,0 +1,293 @@
//! NCCL state held by both the worker process and the leader's pool.
//!
//! Split into its own module so the worker (`tp/worker.rs`) and the
//! leader (`tp/mod.rs`) share the same hex-encoding/decoding code and
//! the same shape of `Option<Comm>` state machine.
//!
//! When the `cuda` feature is off, `NcclState` is a zero-sized
//! placeholder that returns `Error{kind="cuda_feature_not_enabled"}`
//! from every operation. When it's on, the same struct holds the
//! actual `cudarc::nccl::Comm`.
use super::rpc::WorkerResponse;
use super::worker::WorkerConfig;
/// Encode bytes as lowercase hex. Used for ferrying NCCL `Id::internal()`
/// across the leader→worker RPC boundary inside a JSON string.
pub fn encode_hex(bytes: &[u8]) -> String {
let mut out = String::with_capacity(bytes.len() * 2);
for b in bytes {
use std::fmt::Write;
let _ = write!(out, "{b:02x}");
}
out
}
/// Decode lowercase-or-uppercase hex into bytes. Errors on odd length
/// or non-hex characters; the caller bubbles those up via the RPC's
/// `Error{kind="bad_request"}` variant.
pub fn decode_hex(s: &str) -> Result<Vec<u8>, String> {
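if !s.len().is_multiple_of(2) {
return Err(format!("hex string has odd length {}", s.len()));
}
// Reject anything outside [0-9a-fA-F] up front: slicing a `&str` at a
// non-ASCII char boundary would panic, and `u8::from_str_radix`
// accepts a leading `+`, so "+a" would otherwise decode as 0x0a.
if let Some(i) = s.bytes().position(|b| !b.is_ascii_hexdigit()) {
return Err(format!("non-hex character at byte {i}"));
}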
(0..s.len())
.step_by(2)
.map(|i| {
u8::from_str_radix(&s[i..i + 2], 16).map_err(|e| format!("bad hex byte at {i}: {e}"))
})
.collect()
}
#[cfg(not(feature = "cuda"))]
pub struct NcclState;
#[cfg(not(feature = "cuda"))]
impl Default for NcclState {
fn default() -> Self {
Self::new()
}
}
#[cfg(not(feature = "cuda"))]
impl NcclState {
pub fn new() -> Self {
Self
}
pub fn init(&mut self, _cfg: WorkerConfig, _comm_id_hex: &str) -> WorkerResponse {
WorkerResponse::Error {
kind: "cuda_feature_not_enabled".into(),
message: "this neuron binary was built without --features cuda; \
NCCL Init requires CUDA"
.into(),
}
}
pub fn sanity_check(&mut self) -> WorkerResponse {
WorkerResponse::Error {
kind: "cuda_feature_not_enabled".into(),
message: "NCCL sanity check requires --features cuda".into(),
}
}
}
#[cfg(feature = "cuda")]
mod cuda_impl {
use super::*;
use cudarc::driver::CudaContext;
use cudarc::nccl::{Comm, Id, ReduceOp};
use std::sync::Arc;
/// Number of bytes in NCCL's unique-id type; matches `Id::internal()`'s
/// `[c_char; 128]`. Wire-encoded as 256 lowercase hex chars.
const NCCL_ID_BYTES: usize = 128;
pub struct NcclState {
/// Wrapped in `Arc` so we can hand a clone to `TpQwen3ForCausalLM`
/// at load time (every row-parallel layer needs a reference to
/// run its trailing `AllReduce`). The `Arc` is the single source
/// of truth for the comm's lifetime — when the pool drops and
/// every layer that captured a clone drops, NCCL releases the
/// underlying `ncclComm_t`.
comm: Option<Arc<Comm>>,
/// Held alongside the Comm so the device isn't dropped
/// underneath the NCCL handle.
#[allow(dead_code)]
ctx: Option<Arc<CudaContext>>,
}
impl Default for NcclState {
fn default() -> Self {
Self::new()
}
}
impl NcclState {
pub fn new() -> Self {
Self {
comm: None,
ctx: None,
}
}
/// Clone the comm out as an `Arc` so callers (the leader-side
/// `TpQwen3ForCausalLM::load`, or the worker's own model load)
/// can hold a reference for the lifetime of the model. Returns
/// `None` before `init` has run.
pub fn comm(&self) -> Option<Arc<Comm>> {
self.comm.clone()
}
}
/// `Arc<Comm>` doesn't impl `Send` because `Comm` wraps a raw
/// `ncclComm_t` pointer. The NCCL contract is "operations against a
/// given comm must be serialised", not "the handle must stay on the
/// thread that created it" — so it's safe to move an `Arc<Comm>`
/// across threads as long as no concurrent ops are issued. The
/// pool's outer Mutex serialises us into `spawn_blocking`, so this
/// wrapper at the move boundary is the only thing missing.
///
/// `Sync` is also marked safe because the `Arc<Comm>` clones held
/// by the row-parallel layers are only used from the
/// `spawn_blocking` thread driving the forward pass; concurrent
/// access from another thread would still be a bug.
pub struct SendComm(pub Arc<Comm>);
// SAFETY: see the doc-comment above; the invariant is enforced at
// the call site (pool Mutex + single spawn_blocking thread), not at
// the type level.
unsafe impl Send for SendComm {}
unsafe impl Sync for SendComm {}
impl SendComm {
pub fn into_inner(self) -> Arc<Comm> {
self.0
}
}
// SAFETY: `cudarc::nccl::Comm` contains a raw `ncclComm_t` pointer
// (libnccl-allocated state). NCCL requires that operations against
// one Comm be issued one at a time; we serialise access by storing
// NcclState behind a Mutex in `WorkerPool`. The Comm itself is
// move-safe — NCCL doesn't track the calling OS thread, only the
// stream the operations are dispatched against.
unsafe impl Send for NcclState {}
unsafe impl Sync for NcclState {}
/// Generate a fresh NCCL `Id` and return it hex-encoded. Used by
/// the leader to mint the shared communicator id which is then
/// broadcast to every worker via the RPC `Init` message.
pub fn generate_comm_id_hex() -> Result<String, String> {
// NcclError lacks a Display impl in cudarc 0.19.x — surface
// via Debug throughout this module.
let id = Id::new().map_err(|e| format!("Id::new(): {e:?}"))?;
let bytes_u8: [u8; NCCL_ID_BYTES] = std::array::from_fn(|i| id.internal()[i] as u8);
Ok(encode_hex(&bytes_u8))
}
impl NcclState {
pub fn init(&mut self, cfg: WorkerConfig, comm_id_hex: &str) -> WorkerResponse {
match try_init(self, cfg, comm_id_hex) {
Ok(()) => WorkerResponse::InitOk,
Err(msg) => WorkerResponse::Error {
kind: "nccl_init_failed".into(),
message: msg,
},
}
}
pub fn sanity_check(&mut self) -> WorkerResponse {
let Some(comm) = self.comm.as_ref() else {
return WorkerResponse::Error {
kind: "nccl_not_initialised".into(),
message: "sanity_check requires Init to have completed first".into(),
};
};
match try_sanity_check(comm.as_ref()) {
Ok(sum) => WorkerResponse::NcclSanityResult { observed_sum: sum },
Err(msg) => WorkerResponse::Error {
kind: "nccl_sanity_failed".into(),
message: msg,
},
}
}
}
fn try_init(state: &mut NcclState, cfg: WorkerConfig, comm_id_hex: &str) -> Result<(), String> {
let bytes = decode_hex(comm_id_hex)?;
if bytes.len() != NCCL_ID_BYTES {
return Err(format!(
"comm_id is {} bytes, expected {NCCL_ID_BYTES}",
bytes.len()
));
}
let id_bytes: [std::ffi::c_char; NCCL_ID_BYTES] =
std::array::from_fn(|i| bytes[i] as std::ffi::c_char);
let id = Id::uninit(id_bytes);
let ctx = CudaContext::new(cfg.cuda_device as usize)
.map_err(|e| format!("CudaContext::new({}) failed: {e}", cfg.cuda_device))?;
let stream = ctx.default_stream();
let comm = Comm::from_rank(stream, cfg.rank as usize, cfg.world_size as usize, id)
.map_err(|e| {
format!(
"Comm::from_rank(rank={}, world={}) failed: {e:?}",
cfg.rank, cfg.world_size
)
})?;
state.ctx = Some(ctx);
// `Comm` is !Send + !Sync at the type level because it wraps a
// raw `ncclComm_t`. The `Arc` is fine in practice — we
// serialise operations through the pool's outer Mutex and the
// SendComm wrapper at thread-crossing boundaries enforces this
// at every move site. clippy's `arc_with_non_send_sync` lint
// can't see that invariant; allow once at the canonical
// construction site.
#[allow(clippy::arc_with_non_send_sync)]
{
state.comm = Some(Arc::new(comm));
}
Ok(())
}
fn try_sanity_check(comm: &Comm) -> Result<u32, String> {
let stream = comm.stream().clone();
let input = stream
.clone_htod(&[1u32])
.map_err(|e| format!("htod sentinel: {e}"))?;
let mut output = stream
.alloc_zeros::<u32>(1)
.map_err(|e| format!("alloc output: {e}"))?;
// cudarc::nccl::NcclError doesn't impl Display in 0.19.x —
// surface via Debug so we still see the variant + ncclResult
// code instead of a generic "{e}" failure.
comm.all_reduce(&input, &mut output, &ReduceOp::Sum)
.map_err(|e| format!("all_reduce: {e:?}"))?;
let result = stream
.clone_dtoh(&output)
.map_err(|e| format!("dtoh result: {e}"))?;
Ok(result[0])
}
}
#[cfg(feature = "cuda")]
pub use cuda_impl::{NcclState, SendComm, generate_comm_id_hex};
/// Non-cuda stub for the leader: returns a clear marker error rather
/// than letting `init_nccl` succeed vacuously.
#[cfg(not(feature = "cuda"))]
pub fn generate_comm_id_hex() -> Result<String, String> {
Err("cuda_feature_not_enabled: build with --features cuda".into())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn hex_roundtrip() {
let original: Vec<u8> = (0u8..=255).collect();
let encoded = encode_hex(&original);
assert_eq!(encoded.len(), 512);
let decoded = decode_hex(&encoded).expect("decode");
assert_eq!(decoded, original);
}
#[test]
fn hex_decode_rejects_odd_length() {
assert!(decode_hex("a").is_err());
assert!(decode_hex("abc").is_err());
}
#[test]
fn hex_decode_rejects_non_hex() {
assert!(decode_hex("zz").is_err());
assert!(decode_hex("ab_d").is_err());
}
#[test]
fn hex_encode_is_lowercase_padded() {
assert_eq!(encode_hex(&[0x0a, 0xff]), "0aff");
}
}
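
The `Default` impl on `NcclState` is what lets the leader's pool `std::mem::take` the state, move it into a blocking task, and put it back afterwards. A minimal sketch of that take, move, restore pattern follows, using only the standard library. `State`, `Pool`, and the string "handle" are hypothetical stand-ins, and `std::thread::spawn` is swapped in for `tokio::task::spawn_blocking`.

```rust
use std::thread;

/// Hypothetical stand-in for state that owns a non-`Clone` handle.
#[derive(Debug, Default)]
struct State {
    handle: Option<String>,
}

impl State {
    /// Pretend-blocking initialisation; returns a status string the way
    /// the real state returns a response value.
    fn init(&mut self, id: &str) -> &'static str {
        self.handle = Some(format!("comm-{id}"));
        "init_ok"
    }
}

struct Pool {
    leader: State,
}

impl Pool {
    fn init(&mut self, id: &str) -> &'static str {
        // Leave an empty `State::default()` behind and move the real
        // state out so it can cross the thread boundary by value.
        let mut state = std::mem::take(&mut self.leader);
        let id = id.to_string();
        let (state, resp) = thread::spawn(move || {
            let resp = state.init(&id);
            // Hand the state back together with the response.
            (state, resp)
        })
        .join()
        .expect("init thread panicked");
        // Restore the state so later calls see the initialised handle.
        self.leader = state;
        resp
    }
}

fn main() {
    let mut pool = Pool { leader: State::default() };
    let resp = pool.init("a1b2");
    println!("{resp}: {:?}", pool.leader.handle);
}
```

One consequence of this shape: if the task panics, the moved state is lost and the pool is left holding the empty default, so a later call would need to run initialisation again.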


@@ -0,0 +1,257 @@
//! Wire protocol between the neuron leader process and its
//! `--worker` subprocesses.
//!
//! Every frame is one newline-delimited JSON object on the worker's
//! stdin (request) or stdout (response). Both directions are tagged
//! sum types from the start so new ops in Stage 7b/7c slot in without
//! breaking compatibility — no "14 message types and a version field"
//! drift later. Adding a new variant is the canonical way to evolve
//! the protocol; existing peers that don't recognise an op return
//! `WorkerResponse::Error { kind: "unknown_op", .. }`.
//!
//! The serialised shape uses `tag = "op"` so a request looks like:
//!
//!     {"op":"ping"}
//!     {"op":"init","comm_id":"a1b2..."}
//!
//! and a response:
//!
//!     {"op":"pong","rank":0,"world_size":2,"cuda_device":0}
//!     {"op":"error","kind":"nccl_init_failed","message":"..."}
use serde::{Deserialize, Serialize};
/// Leader → worker. Worker handles one at a time; replies with exactly
/// one `WorkerResponse` per request.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "op", rename_all = "snake_case")]
pub enum WorkerRequest {
/// Liveness probe. Worker replies with `Pong` containing its own
/// identity. Used by the leader to confirm the subprocess is up
/// and ready before kicking off any heavier work.
Ping,
/// One-shot NCCL communicator setup. The leader generates the
/// `comm_id` once (rank 0 of NCCL), broadcasts it to every worker
/// via this message, then every rank (leader included) calls
/// `Comm::from_rank` with the same id — NCCL blocks until all
/// `world_size` ranks check in. The hex-encoded bytes are the
/// canonical `cudarc::nccl::Id::internal()` content.
Init {
/// Hex-encoded NCCL id bytes (128 bytes → 256 hex chars).
comm_id: String,
},
/// Sanity check: after Init, every rank runs an `all_reduce` over
/// a sentinel value (`1u32`). The expected sum is `world_size`.
/// Worker replies with the observed value so the leader can verify
/// the NCCL handshake is genuinely live, not just configured.
NcclSanityCheck,
/// Load this rank's shard of a dense Qwen3 model from mmapped
/// safetensors. The same `safetensors_paths` list is sent to every
/// rank — the ShardedVarBuilder reads only the rank-local slice of
/// each tensor at materialisation time, so the worker's VRAM
/// footprint is `1 / world_size` of the full model (plus replicated
/// embedding/norm/lm_head).
LoadDenseShard {
/// Caller-supplied id for later `GenerateStep` / `UnloadModel`
/// lookups. Typically the HF model id verbatim.
model_id: String,
/// JSON-serialised `candle_transformers::models::qwen3::Config`
/// — the same blob the leader parsed from the HF cache's
/// `config.json`. Threaded through verbatim so the worker uses
/// identical hyperparameters.
config_json: String,
/// Absolute paths the worker should mmap. The same set on every
/// rank; ShardedVarBuilder slices into them per rank.
safetensors_paths: Vec<String>,
/// Optional in-situ quantization dtype (e.g. "q5k", "q8_0",
/// "q6k"). When set, each linear-layer weight is quantized
/// at load time to the named ggml format, shrinking weight
/// memory versus bf16/f16 (roughly 2-3x for the formats named
/// here) at the cost of some accuracy. `None` keeps the
/// weights in the on-disk dtype (typically bf16).
#[serde(default)]
quant: Option<String>,
},
/// Run one forward step on this rank's loaded model. The worker
/// reaches into its NCCL Comm for the row-parallel `AllReduce`s
/// inside the model — and so blocks on every other rank issuing the
/// same op. The leader does *not* receive logits back over RPC; it
/// runs its own rank-0 forward in parallel and uses its own logits
/// for sampling.
GenerateStep {
model_id: String,
/// Input token ids for this step. For prefill, the whole prompt;
/// for decode, a single token. Identical on every rank.
tokens: Vec<u32>,
/// KV cache offset (count of tokens already in the cache before
/// this step).
offset: usize,
},
/// Reset the KV cache for this model on this rank. Sent at the
/// start of every inference so a fresh request doesn't accidentally
/// attend over the previous one's tokens.
ClearKvCache { model_id: String },
/// Drop this rank's shard for the given model. Releases the VRAM
/// the shard's weights occupied; subsequent `GenerateStep` calls
/// against the same `model_id` return an `Error`.
UnloadModel { model_id: String },
/// Worker should release resources and exit. Worker replies `Bye`
/// and then closes stdout / exits zero. The leader reaps the
/// child via the `tokio::process::Child` it kept.
Shutdown,
}
/// Worker → leader. Always exactly one of these per `WorkerRequest`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "op", rename_all = "snake_case")]
pub enum WorkerResponse {
/// Reply to `Ping`. Carries enough identity for the leader to log
/// what it actually got back.
Pong {
rank: u32,
world_size: u32,
cuda_device: u32,
},
/// Reply to `Init`. Empty payload — success is the absence of
/// `Error`. NCCL's internal blocking handshake means by the time
/// this comes back, every other rank has also reached
/// `Comm::from_rank`.
InitOk,
/// Reply to `NcclSanityCheck`. The observed sum after a single
/// `all_reduce(SUM, 1u32)` across all ranks. The leader checks
/// this matches `world_size`.
NcclSanityResult { observed_sum: u32 },
/// Reply to `LoadDenseShard`. Empty payload — success is the
/// absence of `Error`. By the time this comes back, the rank's
/// `TpQwen3ForCausalLM` is constructed in memory and ready for
/// `GenerateStep`.
LoadDenseShardOk,
/// Reply to `GenerateStep`. Empty payload — workers don't ship
/// logits over the wire. The leader uses its own rank-0 logits;
/// workers only need to confirm the collective completed.
GenerateStepOk,
/// Reply to `ClearKvCache`. Empty payload.
KvCacheCleared,
/// Reply to `UnloadModel`. Empty payload. The named model is no
/// longer present on this rank.
Unloaded,
/// Reply to `Shutdown`. Worker exits immediately after writing this.
Bye,
/// Any request can produce this instead of its dedicated success
/// variant. `kind` is a machine-readable category so the leader
/// can branch on failure mode without string-matching `message`.
Error {
/// Short tag — `nccl_init_failed`, `unknown_op`, etc.
kind: String,
/// Human-readable detail for logs.
message: String,
},
}
#[cfg(test)]
mod tests {
use super::*;
fn roundtrip<T>(value: &T) -> T
where
T: Serialize + for<'de> Deserialize<'de>,
{
serde_json::from_str(&serde_json::to_string(value).expect("serialise"))
.expect("deserialise")
}
#[test]
fn request_ping_round_trip() {
let req = WorkerRequest::Ping;
let wire = serde_json::to_string(&req).unwrap();
assert_eq!(wire, r#"{"op":"ping"}"#);
match roundtrip(&req) {
WorkerRequest::Ping => {}
other => panic!("expected Ping, got {other:?}"),
}
}
#[test]
fn request_init_carries_hex_id() {
let req = WorkerRequest::Init {
comm_id: "deadbeef".into(),
};
let wire = serde_json::to_string(&req).unwrap();
assert_eq!(wire, r#"{"op":"init","comm_id":"deadbeef"}"#);
}
#[test]
fn request_shutdown_round_trip() {
assert_eq!(
serde_json::to_string(&WorkerRequest::Shutdown).unwrap(),
r#"{"op":"shutdown"}"#
);
}
#[test]
fn response_pong_round_trip() {
let resp = WorkerResponse::Pong {
rank: 1,
world_size: 4,
cuda_device: 1,
};
let wire = serde_json::to_string(&resp).unwrap();
assert!(wire.contains(r#""op":"pong""#));
assert!(wire.contains(r#""rank":1"#));
assert!(wire.contains(r#""world_size":4"#));
match roundtrip(&resp) {
WorkerResponse::Pong {
rank,
world_size,
cuda_device,
} => {
assert_eq!(rank, 1);
assert_eq!(world_size, 4);
assert_eq!(cuda_device, 1);
}
other => panic!("expected Pong, got {other:?}"),
}
}
#[test]
fn response_error_carries_kind_and_message() {
let resp = WorkerResponse::Error {
kind: "nccl_init_failed".into(),
message: "could not bind device".into(),
};
let wire = serde_json::to_string(&resp).unwrap();
assert!(wire.contains(r#""op":"error""#));
assert!(wire.contains(r#""kind":"nccl_init_failed""#));
}
#[test]
fn response_sanity_result_round_trip() {
let resp = WorkerResponse::NcclSanityResult { observed_sum: 4 };
match roundtrip(&resp) {
WorkerResponse::NcclSanityResult { observed_sum } => {
assert_eq!(observed_sum, 4);
}
other => panic!("expected NcclSanityResult, got {other:?}"),
}
}
/// Unknown ops on the wire deserialise to an error rather than
/// silently mis-matching — confirms our `serde(tag = "op")`
/// configuration rejects unknowns instead of doing fuzzy matching.
#[test]
fn unknown_op_fails_to_parse() {
let result: Result<WorkerRequest, _> = serde_json::from_str(r#"{"op":"explode"}"#);
assert!(result.is_err(), "should reject unknown op, got {result:?}");
}
}
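The framing rule documented at the top of this file (one JSON object per line in, exactly one response line out) can be sketched without any dependencies. This is not the worker's actual loop, which uses tokio and serde. `handle_line` and `serve` are hypothetical names, and `handle_line` matches whole request lines instead of parsing JSON, purely to keep the example std-only.

```rust
use std::io::{BufRead, Write};

/// Hypothetical dispatcher: maps one raw request line to one raw response line.
/// A real worker would deserialise into `WorkerRequest` instead of string matching.
fn handle_line(line: &str) -> &'static str {
    match line.trim() {
        r#"{"op":"ping"}"# => r#"{"op":"pong","rank":0,"world_size":1,"cuda_device":0}"#,
        r#"{"op":"shutdown"}"# => r#"{"op":"bye"}"#,
        _ => r#"{"op":"error","kind":"unknown_op","message":"unrecognised request"}"#,
    }
}

/// Read one request per line and write exactly one response line for each.
/// Returns the number of requests answered.
fn serve<R: BufRead, W: Write>(input: R, mut output: W) -> std::io::Result<usize> {
    let mut answered = 0;
    for line in input.lines() {
        let line = line?;
        let reply = handle_line(&line);
        writeln!(output, "{reply}")?;
        // Flush per frame so the peer is never left waiting on a buffered reply.
        output.flush()?;
        answered += 1;
        if reply == r#"{"op":"bye"}"# {
            break; // Shutdown acknowledged: stop reading further requests.
        }
    }
    Ok(answered)
}
```

Feeding `serve` a `std::io::Cursor` over a few request lines and a `Vec<u8>` as the output is enough to exercise it: anything after a `shutdown` line is left unread, and an unrecognised op still gets its single `error` reply.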

@@ -0,0 +1,283 @@
//! Tensor-parallel linear layers built on candle's `ShardedVarBuilder`
//! and `Shard` sharding hints.
//!
//! candle reads only the rank's slice of each weight tensor from
//! safetensors via `view.slice(start..stop)` — no full-tensor host
//! materialisation. That's a memory-efficiency win over hand-rolled
//! "load full + narrow" sharding (which the earlier
//! `sharded_linear.rs` exploration demonstrated but didn't pay for).
//!
//! Two layer types:
//!
//! - [`ColumnParallelLinear`] — output-sharded; forward is a plain
//! local matmul. The downstream consumer either accepts a sharded
//! activation (next layer is also column-parallel) or all-gathers.
//! - [`RowParallelLinear`] — input-sharded; forward = local matmul
//! then `AllReduce` `CustomOp1` to sum partials across ranks.
//!
//! Both assume **no bias** — every Qwen3-family weight layout we
//! actually target (Qwen3, Qwen3-Coder, Qwen3.6 base, etc.) sets
//! `attention_bias=false` and the MLP layers are no-bias. Adding bias
//! support is mechanical when a future model needs it; the design
//! choice would be: column-parallel shards the bias along dim 0;
//! row-parallel holds the bias only on rank 0 so the post-`AllReduce`
//! sum carries it exactly once.
use anyhow::{Context, Result};
use candle_core::quantized::{GgmlDType, QMatMul, QTensor};
use candle_core::{Module, Tensor};
use candle_nn::Linear;
use candle_nn::var_builder::{Shard, ShardedVarBuilder};
use std::sync::Arc;
#[cfg(feature = "cuda")]
use super::all_reduce::AllReduce;
/// Linear primitive that holds either a plain `Linear` (bf16/f16/f32)
/// or a quantized `QMatMul` (Q4K/Q5K/Q6K/Q8_0/etc.).
///
/// Constructed via [`MaybeQuantLinear::from_weight`] — pass `None` to
/// keep the weight in its loaded dtype (no quantization), or
/// `Some(dtype)` to quantize at load time.
///
/// On the forward path both arms look identical to the caller:
/// `Module::forward` returns an output in the caller's input dtype
/// (the quantized arm casts through f32 internally where its kernel
/// requires it). Subsequent ops don't need to know whether the
/// layer was quantized.
pub enum MaybeQuantLinear {
Plain(Linear),
Quant(QMatMul),
}
impl MaybeQuantLinear {
/// Build a linear from a loaded weight tensor. If `quant` is set,
/// the weight is quantized in-situ and stored as a `QMatMul`;
/// otherwise it's wrapped in a plain `Linear`.
pub fn from_weight(weight: Tensor, quant: Option<GgmlDType>) -> Result<Self> {
match quant {
Some(dtype) => {
let qt = QTensor::quantize(&weight, dtype).with_context(|| {
format!(
"QTensor::quantize to {dtype:?} for shape {:?}",
weight.shape()
)
})?;
let qmm = QMatMul::from_arc(Arc::new(qt))
.context("QMatMul::from_arc on freshly quantized weight")?;
Ok(Self::Quant(qmm))
}
None => Ok(Self::Plain(Linear::new(weight, None))),
}
}
}
/// Above this M (the product of all input dims except the last)
/// dispatch the quantized matmul through `QMatMul::forward_via_f16`,
/// which dequantizes the weight to f16 once and runs cuBLAS GEMM.
/// At or below this M the GGUF GEMV kernel inside
/// `QMatMul::forward` wins (it operates on quantized blocks directly
/// and accumulates in registers).
///
/// Empirical: at M=30 on Qwen3.6-27B / RTX 5090, forward_via_f16 was
/// slightly *slower* than the GGUF GEMV kernel — the per-call dequant
/// cost (~30 MB f16 written to global memory per linear × ~480 calls
/// per prefill) eats the cuBLAS GEMM speedup at small M. The
/// crossover where the GEMM scaling actually beats the fixed dequant
/// tax sits above M=30 on that setup.
///
/// 64 is a conservative crossover that keeps short-prompt prefills
/// on the GGUF kernel (where the per-call cost is comparable to the
/// f16 path but the dequant tax is zero) and only activates the
/// dequant-then-GEMM path for long prefills where the GEMM size
/// makes amortising worth it. A proper fix is either a dequant
/// cache or a fused dequant+gemm cuda kernel — both larger projects.
const QUANT_PREFILL_M_THRESHOLD: usize = 64;
impl Module for MaybeQuantLinear {
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
match self {
Self::Plain(l) => l.forward(x),
Self::Quant(qm) => {
// Decode vs prefill split. `M` is the "rows of x" the
// matmul will iterate over — every dim except the last
// (which is in_features). For decode (`seq_len == 1`
// with batch 1) M is 1; for prefill with L>>1 it's L
// (or B*L).
let dims = x.dims();
let m: usize = dims.iter().take(dims.len() - 1).product();
if m > QUANT_PREFILL_M_THRESHOLD {
// Prefill: dequantize the weight once into f16,
// then run a real cuBLAS-backed GEMM. The cost of
// the dequant is amortised across all M tokens.
// `forward_via_f16` handles the dtype round-trip
// internally (output matches input dtype).
return qm.forward_via_f16(x);
}
// Decode (M <= threshold): use the on-the-fly GGUF
// GEMV kernel via `QMatMul::forward`. That kernel
// requires f32 inputs (it accumulates in f32 from the
// dequantized quant blocks); cast in/out at the
// boundary.
let in_dtype = x.dtype();
let x_f32 = if in_dtype == candle_core::DType::F32 {
x.clone()
} else {
x.to_dtype(candle_core::DType::F32)?
};
let y = qm.forward(&x_f32)?;
if y.dtype() == in_dtype {
Ok(y)
} else {
y.to_dtype(in_dtype)
}
}
}
}
}
/// Helper to build a [`Shard`] hint for a given dimension.
pub(crate) fn shard(dim: usize, rank: u32, world_size: u32) -> Shard {
Shard {
dim,
rank: rank as usize,
world_size: world_size as usize,
}
}
/// Output-dim sharded linear (column-parallel). Holds a
/// [`MaybeQuantLinear`] whose underlying weight is this rank's slice
/// of the full `[out_features, in_features]` tensor along dim 0.
pub struct ColumnParallelLinear {
inner: MaybeQuantLinear,
}
impl ColumnParallelLinear {
/// Load this rank's column-parallel slice from a
/// `ShardedVarBuilder`. The provided `vb` must already be `pp`-ed
/// to the layer's path (e.g. `vb.pp("model.layers.0.self_attn.q_proj")`).
///
/// Backward-compatible variant — no in-situ quantization. For
/// quantized loads, use [`Self::load_with_quant`].
pub fn load(vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
Self::load_with_quant(vb, rank, world_size, None)
}
/// Like [`Self::load`] but quantizes the per-rank weight in-situ
/// when `quant` is `Some(dtype)`, cutting weight memory versus
/// bf16/f16 by a factor that depends on the chosen ggml format.
pub fn load_with_quant(
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
quant: Option<GgmlDType>,
) -> Result<Self> {
let weight = vb
.get_with_hints((), "weight", shard(0, rank, world_size))
.with_context(|| format!("load column-parallel '{}' weight", vb.prefix()))?;
let inner = MaybeQuantLinear::from_weight(weight, quant)
.with_context(|| format!("wrap column-parallel '{}'", vb.prefix()))?;
Ok(Self { inner })
}
}
impl Module for ColumnParallelLinear {
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
self.inner.forward(x)
}
}
/// Input-dim sharded linear (row-parallel).
///
/// Holds a sharded [`MaybeQuantLinear`] plus an `AllReduce` op the
/// forward chains after the local matmul to recover the full activation.
pub struct RowParallelLinear {
inner: MaybeQuantLinear,
#[cfg(feature = "cuda")]
all_reduce: AllReduce,
/// Whether the AllReduce should run. Column-parallel ↔ row-parallel
/// is a pair: the column output is sharded, the row input is
/// sharded, and the AllReduce gives back the full output. For
/// `world_size = 1` the AllReduce is a no-op so we skip it.
needs_reduce: bool,
}
impl RowParallelLinear {
/// Load this rank's row-parallel slice from a `ShardedVarBuilder`.
///
/// Under `cuda`, `comm` is the NCCL communicator the row-parallel
/// `AllReduce` runs against. On non-`cuda` builds the parameter is
/// omitted and forward returns the partial sum, which is the *wrong*
/// answer for inference whenever `world_size > 1` but lets us
/// compile-check the model.
///
/// Backward-compatible variant — no in-situ quantization. For
/// quantized loads, use [`Self::load_with_quant`].
#[cfg(feature = "cuda")]
pub fn load(
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
comm: std::sync::Arc<cudarc::nccl::Comm>,
) -> Result<Self> {
Self::load_with_quant(vb, rank, world_size, comm, None)
}
/// Like [`Self::load`] but quantizes the per-rank weight in-situ.
#[cfg(feature = "cuda")]
pub fn load_with_quant(
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
comm: std::sync::Arc<cudarc::nccl::Comm>,
quant: Option<GgmlDType>,
) -> Result<Self> {
let weight = vb
.get_with_hints((), "weight", shard(1, rank, world_size))
.with_context(|| format!("load row-parallel '{}' weight", vb.prefix()))?;
let inner = MaybeQuantLinear::from_weight(weight, quant)
.with_context(|| format!("wrap row-parallel '{}'", vb.prefix()))?;
Ok(Self {
inner,
all_reduce: AllReduce::new(comm),
needs_reduce: world_size > 1,
})
}
#[cfg(not(feature = "cuda"))]
pub fn load(vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
Self::load_with_quant(vb, rank, world_size, None)
}
#[cfg(not(feature = "cuda"))]
pub fn load_with_quant(
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
quant: Option<GgmlDType>,
) -> Result<Self> {
let weight = vb
.get_with_hints((), "weight", shard(1, rank, world_size))
.with_context(|| format!("load row-parallel '{}' weight", vb.prefix()))?;
let inner = MaybeQuantLinear::from_weight(weight, quant)
.with_context(|| format!("wrap row-parallel '{}'", vb.prefix()))?;
Ok(Self {
inner,
needs_reduce: world_size > 1,
})
}
}
impl Module for RowParallelLinear {
/// Local matmul followed by an `AllReduce` (when `cuda` and
/// `world_size > 1`). On CPU or single-rank, returns the partial
/// output directly — which is *only* correct for `world_size == 1`.
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
let local = self.inner.forward(x)?;
#[cfg(feature = "cuda")]
if self.needs_reduce {
return local.apply_op1_no_bwd(&self.all_reduce);
}
let _ = self.needs_reduce;
Ok(local)
}
}
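Why a column-parallel layer followed by a row-parallel layer needs only one sum at the end can be checked with plain vectors. The sketch below is std-only and does not use candle or NCCL: every name in it (`Matrix`, `shard_range`, `shard_dim0`, `shard_dim1`, `tp_forward`) is hypothetical, ReLU stands in for the real activation, and the in-process loop with an element-wise sum stands in for the `AllReduce`.

```rust
/// Stored as [out_features][in_features], like a Linear weight.
type Matrix = Vec<Vec<f32>>;

/// y = W x for a single input vector.
fn matvec(w: &Matrix, x: &[f32]) -> Vec<f32> {
    w.iter()
        .map(|row| row.iter().zip(x).map(|(a, b)| a * b).sum::<f32>())
        .collect()
}

/// Contiguous slice owned by `rank` when `len` is split evenly.
fn shard_range(len: usize, rank: usize, world_size: usize) -> std::ops::Range<usize> {
    assert!(len % world_size == 0, "dimension must divide evenly");
    let per_rank = len / world_size;
    rank * per_rank..(rank + 1) * per_rank
}

/// Column-parallel shard: keep this rank's output rows (dim 0).
fn shard_dim0(w: &Matrix, rank: usize, world_size: usize) -> Matrix {
    w[shard_range(w.len(), rank, world_size)].to_vec()
}

/// Row-parallel shard: keep this rank's input columns (dim 1).
fn shard_dim1(w: &Matrix, rank: usize, world_size: usize) -> Matrix {
    w.iter()
        .map(|row| row[shard_range(row.len(), rank, world_size)].to_vec())
        .collect()
}

fn relu(v: Vec<f32>) -> Vec<f32> {
    v.into_iter().map(|a| a.max(0.0)).collect()
}

/// Unsharded reference: down(relu(up(x))).
fn full_forward(up: &Matrix, down: &Matrix, x: &[f32]) -> Vec<f32> {
    matvec(down, &relu(matvec(up, x)))
}

/// Each simulated rank computes a partial output from its own shards;
/// summing the partials plays the role of the all-reduce.
fn tp_forward(up: &Matrix, down: &Matrix, x: &[f32], world_size: usize) -> Vec<f32> {
    let mut total = vec![0.0f32; down.len()];
    for rank in 0..world_size {
        let hidden = relu(matvec(&shard_dim0(up, rank, world_size), x));
        let partial = matvec(&shard_dim1(down, rank, world_size), &hidden);
        for (t, p) in total.iter_mut().zip(partial) {
            *t += p;
        }
    }
    total
}
```

Because the activation is element-wise, each rank can apply it to its own slice of the hidden vector without communicating, which is why the only cross-rank step is the final sum.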

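The decode/prefill split in `MaybeQuantLinear::forward` comes down to one number, M, the product of every input dimension except the last. The following std-only sketch isolates that decision. The constant mirrors `QUANT_PREFILL_M_THRESHOLD` from the file above, while `QuantPath` and `choose_path` are hypothetical names introduced here for illustration.

```rust
/// Same value as the threshold in the file above.
const QUANT_PREFILL_M_THRESHOLD: usize = 64;

#[derive(Debug, PartialEq)]
enum QuantPath {
    /// On-the-fly quantized GEMV kernel (decode and short prefills).
    Gemv,
    /// Dequantize the weight to f16, then run a dense GEMM (long prefills).
    DequantGemm,
}

/// `dims` is the input shape; the last dim is `in_features`,
/// and M is the product of all the dims before it.
fn choose_path(dims: &[usize]) -> QuantPath {
    let m: usize = dims[..dims.len().saturating_sub(1)].iter().product();
    if m > QUANT_PREFILL_M_THRESHOLD {
        QuantPath::DequantGemm
    } else {
        QuantPath::Gemv
    }
}
```

A single-token decode step with shape `[1, 1, hidden]` has M = 1 and stays on the GEMV path; a 65-token prefill at batch 1, or two 40-token prompts batched together (M = 80), crosses the threshold.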

@@ -0,0 +1,605 @@
//! Tensor-parallel Qwen3 dense model.
//!
//! Mirrors `candle_transformers::models::qwen3` structurally, but with:
//!
//! - Attention's `q_proj` / `k_proj` / `v_proj` as
//! [`ColumnParallelLinear`] (output sharded along the head dimension —
//! per-rank `num_heads = total/world_size`, ditto for kv heads).
//! - Attention's `o_proj` as [`RowParallelLinear`] (input sharded; the
//! trailing `AllReduce` recovers the full activation).
//! - MLP's `gate_proj` / `up_proj` as [`ColumnParallelLinear`] (sharded
//! along `intermediate_size`).
//! - MLP's `down_proj` as [`RowParallelLinear`].
//! - `embed_tokens`, all `RmsNorm`s, and `lm_head` **replicated** on
//! every rank. The per-rank duplicate weight is bounded and lets us
//! skip the embedding all-gather and the lm-head column-shard +
//! all-gather; both are pure latency optimisations that don't change
//! correctness.
//! - `kv_cache` holds the per-rank slice of K/V already (because they
//! came out of a column-parallel projection). No cache resharding
//! needed across ranks.
//!
//! Divisibility requirement, checked at load time:
//!
//! - `num_attention_heads % world_size == 0`
//! - `num_key_value_heads % world_size == 0`
//! - `intermediate_size % world_size == 0`
//!
//! Anything else bails — the safetensors slice would lose data otherwise.
//! This is the same divisibility-bail pattern that landed in
//! `EricLBuehler/mistral.rs` PR #2054.
//!
//! Replicated tensors (norms, embedding, lm_head) are loaded by asking
//! the `ShardedVarBuilder` for the full tensor via `vb.get(shape, name)`
//! — which defaults to `Shard { world_size: 1 }` and falls through to
//! the unsharded backend path.
use anyhow::{Context, Result, bail};
use candle_core::{DType, Device, IndexOp, Module, Tensor};
use candle_nn::var_builder::ShardedVarBuilder;
use candle_nn::{Activation, Embedding, Linear, RmsNorm, kv_cache::ConcatKvCache};
use candle_transformers::utils::repeat_kv;
use std::sync::Arc;
#[cfg(feature = "cuda")]
use cudarc::nccl::Comm;
use super::tp_linear::{ColumnParallelLinear, RowParallelLinear};
pub use candle_transformers::models::qwen3::Config;
/// Replicated rotary-embedding lookup. Re-implementation of the
/// `pub(crate)` candle equivalent — we can't reach into the upstream
/// type, so the inv-freq / sin / cos construction lives here.
pub(crate) struct Qwen3RotaryEmbedding {
sin: Tensor,
cos: Tensor,
}
impl Qwen3RotaryEmbedding {
pub(crate) fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
let dim = cfg.head_dim;
let max_seq_len = cfg.max_position_embeddings;
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(DType::F32)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
Ok(Self {
sin: freqs.sin()?.to_dtype(dtype)?,
cos: freqs.cos()?.to_dtype(dtype)?,
})
}
fn apply(
&self,
q: &Tensor,
k: &Tensor,
offset: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
let (_, _, seq_len, _) = q.dims4()?;
let cos = self.cos.narrow(0, offset, seq_len)?;
let sin = self.sin.narrow(0, offset, seq_len)?;
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
Ok((q_embed, k_embed))
}
}
/// Helper: load a replicated tensor by asking the ShardedVarBuilder for
/// the full tensor (world_size=1 hint falls through to SimpleBackend).
fn load_replicated<S: Into<candle_core::Shape>>(
vb: &ShardedVarBuilder,
shape: S,
name: &str,
) -> Result<Tensor> {
vb.get(shape, name)
.with_context(|| format!("load replicated '{}/{name}'", vb.prefix()))
}
fn load_rms_norm(vb: &ShardedVarBuilder, size: usize, eps: f64) -> Result<RmsNorm> {
let weight = load_replicated(vb, size, "weight")?;
Ok(RmsNorm::new(weight, eps))
}
/// TP MLP. SwiGLU = `down(silu(gate(x)) * up(x))`.
pub(crate) struct TpQwen3MLP {
gate_proj: ColumnParallelLinear,
up_proj: ColumnParallelLinear,
down_proj: RowParallelLinear,
act_fn: Activation,
}
impl TpQwen3MLP {
#[cfg(feature = "cuda")]
pub fn load(
cfg: &Config,
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
comm: Arc<Comm>,
) -> Result<Self> {
if !cfg.intermediate_size.is_multiple_of(world_size as usize) {
bail!(
"intermediate_size {} not divisible by world_size {}",
cfg.intermediate_size,
world_size
);
}
Ok(Self {
gate_proj: ColumnParallelLinear::load(&vb.pp("gate_proj"), rank, world_size)?,
up_proj: ColumnParallelLinear::load(&vb.pp("up_proj"), rank, world_size)?,
down_proj: RowParallelLinear::load(&vb.pp("down_proj"), rank, world_size, comm)?,
act_fn: cfg.hidden_act,
})
}
#[cfg(not(feature = "cuda"))]
pub fn load(cfg: &Config, vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
if !cfg.intermediate_size.is_multiple_of(world_size as usize) {
bail!(
"intermediate_size {} not divisible by world_size {}",
cfg.intermediate_size,
world_size
);
}
Ok(Self {
gate_proj: ColumnParallelLinear::load(&vb.pp("gate_proj"), rank, world_size)?,
up_proj: ColumnParallelLinear::load(&vb.pp("up_proj"), rank, world_size)?,
down_proj: RowParallelLinear::load(&vb.pp("down_proj"), rank, world_size)?,
act_fn: cfg.hidden_act,
})
}
}
impl Module for TpQwen3MLP {
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
let lhs = x.apply(&self.gate_proj)?.apply(&self.act_fn)?;
let rhs = x.apply(&self.up_proj)?;
(lhs * rhs)?.apply(&self.down_proj)
}
}
/// TP attention. Carries per-rank head counts and the q/k per-head
/// RmsNorms (which are replicated and operate on a flattened B*H*L
/// axis, so the same code path works irrespective of how H was split).
pub(crate) struct TpQwen3Attention {
q_proj: ColumnParallelLinear,
k_proj: ColumnParallelLinear,
v_proj: ColumnParallelLinear,
o_proj: RowParallelLinear,
q_norm: RmsNorm,
k_norm: RmsNorm,
local_num_heads: usize,
local_num_kv_heads: usize,
num_kv_groups: usize,
head_dim: usize,
local_hidden_size: usize,
rotary_emb: Arc<Qwen3RotaryEmbedding>,
kv_cache: ConcatKvCache,
}
impl TpQwen3Attention {
#[cfg(feature = "cuda")]
pub fn load(
cfg: &Config,
rotary_emb: Arc<Qwen3RotaryEmbedding>,
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
comm: Arc<Comm>,
) -> Result<Self> {
Self::load_inner(cfg, rotary_emb, vb, rank, world_size, comm)
}
#[cfg(not(feature = "cuda"))]
pub fn load(
cfg: &Config,
rotary_emb: Arc<Qwen3RotaryEmbedding>,
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
) -> Result<Self> {
Self::load_inner(cfg, rotary_emb, vb, rank, world_size)
}
fn load_inner(
cfg: &Config,
rotary_emb: Arc<Qwen3RotaryEmbedding>,
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
#[cfg(feature = "cuda")] comm: Arc<Comm>,
) -> Result<Self> {
if cfg.use_sliding_window {
bail!("sliding window is not yet supported in the TP path");
}
if cfg.attention_bias {
bail!("attention_bias=true is not supported by ColumnParallel/RowParallelLinear yet");
}
let ws = world_size as usize;
if !cfg.num_attention_heads.is_multiple_of(ws) {
bail!(
"num_attention_heads {} not divisible by world_size {}",
cfg.num_attention_heads,
world_size
);
}
if !cfg.num_key_value_heads.is_multiple_of(ws) {
bail!(
"num_key_value_heads {} not divisible by world_size {}",
cfg.num_key_value_heads,
world_size
);
}
let head_dim = cfg.head_dim;
let local_num_heads = cfg.num_attention_heads / ws;
let local_num_kv_heads = cfg.num_key_value_heads / ws;
let num_kv_groups = local_num_heads / local_num_kv_heads;
let local_hidden_size = head_dim * local_num_heads;
let q_proj = ColumnParallelLinear::load(&vb.pp("q_proj"), rank, world_size)?;
let k_proj = ColumnParallelLinear::load(&vb.pp("k_proj"), rank, world_size)?;
let v_proj = ColumnParallelLinear::load(&vb.pp("v_proj"), rank, world_size)?;
#[cfg(feature = "cuda")]
let o_proj = RowParallelLinear::load(&vb.pp("o_proj"), rank, world_size, comm)?;
#[cfg(not(feature = "cuda"))]
let o_proj = RowParallelLinear::load(&vb.pp("o_proj"), rank, world_size)?;
let q_norm = load_rms_norm(&vb.pp("q_norm"), head_dim, cfg.rms_norm_eps)?;
let k_norm = load_rms_norm(&vb.pp("k_norm"), head_dim, cfg.rms_norm_eps)?;
// dim=2 because we cat along the seq axis of (B, H, L, D) tensors.
let kv_cache = ConcatKvCache::new(2);
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
q_norm,
k_norm,
local_num_heads,
local_num_kv_heads,
num_kv_groups,
head_dim,
local_hidden_size,
rotary_emb,
kv_cache,
})
}
pub fn forward(
&mut self,
x: &Tensor,
attn_mask: Option<&Tensor>,
offset: usize,
) -> candle_core::Result<Tensor> {
let (b, l, _) = x.dims3()?;
// 1. Projections (column-parallel → output is sharded).
let q = self.q_proj.forward(x)?;
let k = self.k_proj.forward(x)?;
let v = self.v_proj.forward(x)?;
// 2. Reshape: (B, L, H, D) → (B, H, L, D).
let q = q
.reshape((b, l, self.local_num_heads, self.head_dim))?
.transpose(1, 2)?;
let k = k
.reshape((b, l, self.local_num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let v = v
.reshape((b, l, self.local_num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
// 3. Per-head RmsNorm (replicated weight, flat input).
let q_flat = q.flatten(0, 2)?;
let k_flat = k.flatten(0, 2)?;
let q_flat = self.q_norm.forward(&q_flat)?;
let k_flat = self.k_norm.forward(&k_flat)?;
let q = q_flat.reshape((b, self.local_num_heads, l, self.head_dim))?;
let k = k_flat.reshape((b, self.local_num_kv_heads, l, self.head_dim))?;
// 4. Rotary.
let (q, k) = self.rotary_emb.apply(&q, &k, offset)?;
// 5. Accumulate KV.
let (k, v) = self.kv_cache.append(&k, &v)?;
// 6. GQA repeat_kv on the rank-local K/V.
let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;
// 7. Attention scores.
let scale = 1.0 / (self.head_dim as f64).sqrt();
let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
if let Some(m) = attn_mask {
scores = scores.broadcast_add(m)?;
}
let probs = candle_nn::ops::softmax_last_dim(&scores)?;
let ctx = probs.matmul(&v)?;
// 8. Output projection (row-parallel → AllReduce inside).
ctx.transpose(1, 2)?
.reshape((b, l, self.local_hidden_size))?
.apply(&self.o_proj)
}
pub fn clear_kv_cache(&mut self) {
self.kv_cache.reset();
}
}
struct TpDecoderLayer {
self_attn: TpQwen3Attention,
mlp: TpQwen3MLP,
ln1: RmsNorm,
ln2: RmsNorm,
}
impl TpDecoderLayer {
#[cfg(feature = "cuda")]
fn load(
cfg: &Config,
rotary_emb: Arc<Qwen3RotaryEmbedding>,
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
comm: Arc<Comm>,
) -> Result<Self> {
let self_attn = TpQwen3Attention::load(
cfg,
rotary_emb,
&vb.pp("self_attn"),
rank,
world_size,
comm.clone(),
)?;
let mlp = TpQwen3MLP::load(cfg, &vb.pp("mlp"), rank, world_size, comm)?;
let ln1 = load_rms_norm(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?;
let ln2 = load_rms_norm(
&vb.pp("post_attention_layernorm"),
cfg.hidden_size,
cfg.rms_norm_eps,
)?;
Ok(Self {
self_attn,
mlp,
ln1,
ln2,
})
}
#[cfg(not(feature = "cuda"))]
fn load(
cfg: &Config,
rotary_emb: Arc<Qwen3RotaryEmbedding>,
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
) -> Result<Self> {
let self_attn =
TpQwen3Attention::load(cfg, rotary_emb, &vb.pp("self_attn"), rank, world_size)?;
let mlp = TpQwen3MLP::load(cfg, &vb.pp("mlp"), rank, world_size)?;
let ln1 = load_rms_norm(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?;
let ln2 = load_rms_norm(
&vb.pp("post_attention_layernorm"),
cfg.hidden_size,
cfg.rms_norm_eps,
)?;
Ok(Self {
self_attn,
mlp,
ln1,
ln2,
})
}
fn forward(
&mut self,
x: &Tensor,
mask: Option<&Tensor>,
offset: usize,
) -> candle_core::Result<Tensor> {
let h = self.ln1.forward(x)?;
let h = self.self_attn.forward(&h, mask, offset)?;
let x = (x + h)?;
let h2 = self.ln2.forward(&x)?;
let h2 = h2.apply(&self.mlp)?;
x + h2
}
fn clear_kv_cache(&mut self) {
self.self_attn.clear_kv_cache();
}
}
/// Base TP Qwen3 transformer — embedding, decoder stack, final norm.
/// The lm_head sits on top in [`TpQwen3ForCausalLM`].
pub struct TpQwen3Model {
embed_tokens: Embedding,
layers: Vec<TpDecoderLayer>,
norm: RmsNorm,
device: Device,
dtype: DType,
}
impl TpQwen3Model {
#[cfg(feature = "cuda")]
pub fn load(
cfg: &Config,
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
comm: Arc<Comm>,
) -> Result<Self> {
let dtype = vb.dtype();
let device = vb.device().clone();
let rotary = Arc::new(Qwen3RotaryEmbedding::new(dtype, cfg, &device)?);
let embed_vb = vb.pp("model.embed_tokens");
let embed_weight = load_replicated(&embed_vb, (cfg.vocab_size, cfg.hidden_size), "weight")?;
let embed_tokens = Embedding::new(embed_weight, cfg.hidden_size);
let vb_l = vb.pp("model.layers");
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
for i in 0..cfg.num_hidden_layers {
layers.push(TpDecoderLayer::load(
cfg,
rotary.clone(),
&vb_l.pp(i),
rank,
world_size,
comm.clone(),
)?);
}
let norm = load_rms_norm(&vb.pp("model.norm"), cfg.hidden_size, cfg.rms_norm_eps)?;
Ok(Self {
embed_tokens,
layers,
norm,
device,
dtype,
})
}
#[cfg(not(feature = "cuda"))]
pub fn load(cfg: &Config, vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
let dtype = vb.dtype();
let device = vb.device().clone();
let rotary = Arc::new(Qwen3RotaryEmbedding::new(dtype, cfg, &device)?);
let embed_vb = vb.pp("model.embed_tokens");
let embed_weight = load_replicated(&embed_vb, (cfg.vocab_size, cfg.hidden_size), "weight")?;
let embed_tokens = Embedding::new(embed_weight, cfg.hidden_size);
let vb_l = vb.pp("model.layers");
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
for i in 0..cfg.num_hidden_layers {
layers.push(TpDecoderLayer::load(
cfg,
rotary.clone(),
&vb_l.pp(i),
rank,
world_size,
)?);
}
let norm = load_rms_norm(&vb.pp("model.norm"), cfg.hidden_size, cfg.rms_norm_eps)?;
Ok(Self {
embed_tokens,
layers,
norm,
device,
dtype,
})
}
pub fn embed_weight(&self) -> &Tensor {
self.embed_tokens.embeddings()
}
pub fn clear_kv_cache(&mut self) {
for l in &mut self.layers {
l.clear_kv_cache();
}
}
fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> candle_core::Result<Tensor> {
let minf = f32::NEG_INFINITY;
let mask: Vec<_> = (0..tgt)
.flat_map(|i| (0..(tgt + offset)).map(move |j| if j <= i + offset { 0. } else { minf }))
.collect();
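// Build a single (1, 1, tgt, tgt + offset) mask and broadcast it over
// the batch: `mask` holds tgt * (tgt + offset) values, so shaping it
// as (b, ..) directly fails whenever b > 1.
Tensor::from_slice(&mask, (1, 1, tgt, tgt + offset), &self.device)?
.to_dtype(self.dtype)?
.broadcast_as((b, 1, tgt, tgt + offset))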
}
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
let (b, l) = input.dims2()?;
let mut h = self.embed_tokens.forward(input)?;
let causal = if l == 1 {
None
} else {
Some(self.causal_mask(b, l, offset)?)
};
for layer in &mut self.layers {
h = layer.forward(&h, causal.as_ref(), offset)?;
}
self.norm.forward(&h)
}
}
/// TP Qwen3 with a (replicated) language-model head on top.
pub struct TpQwen3ForCausalLM {
base: TpQwen3Model,
lm_head: Linear,
}
impl TpQwen3ForCausalLM {
#[cfg(feature = "cuda")]
pub fn load(
cfg: &Config,
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
comm: Arc<Comm>,
) -> Result<Self> {
let base = TpQwen3Model::load(cfg, vb, rank, world_size, comm)?;
let lm_head = build_lm_head(cfg, vb, &base)?;
Ok(Self { base, lm_head })
}
#[cfg(not(feature = "cuda"))]
pub fn load(cfg: &Config, vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
let base = TpQwen3Model::load(cfg, vb, rank, world_size)?;
let lm_head = build_lm_head(cfg, vb, &base)?;
Ok(Self { base, lm_head })
}
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
let (_, l) = input.dims2()?;
let hidden = self.base.forward(input, offset)?;
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
}
pub fn clear_kv_cache(&mut self) {
self.base.clear_kv_cache();
}
pub fn device(&self) -> &Device {
&self.base.device
}
pub fn dtype(&self) -> DType {
self.base.dtype
}
}
fn build_lm_head(cfg: &Config, vb: &ShardedVarBuilder, base: &TpQwen3Model) -> Result<Linear> {
if cfg.tie_word_embeddings {
Ok(Linear::new(base.embed_weight().clone(), None))
} else {
let weight = load_replicated(
&vb.pp("lm_head"),
(cfg.vocab_size, cfg.hidden_size),
"weight",
)?;
Ok(Linear::new(weight, None))
}
}
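To make the mask layout in `causal_mask` concrete, here is a minimal, dependency-free sketch that produces the same row-major values for a single sequence. The name `causal_mask_values` is hypothetical and not part of the commit; it only mirrors the `j <= i + offset` rule and involves no candle tensor.

```rust
/// Hypothetical helper (not in the commit): row-major values of a
/// `(tgt, tgt + offset)` causal mask. Query row `i` sits at absolute
/// position `i + offset`, so it may attend to key columns `0..=i + offset`;
/// every later column is masked with negative infinity.
fn causal_mask_values(tgt: usize, offset: usize) -> Vec<f32> {
    (0..tgt)
        .flat_map(|i| {
            (0..tgt + offset).map(move |j| if j <= i + offset { 0.0 } else { f32::NEG_INFINITY })
        })
        .collect()
}
```

With `tgt = 2` and `offset = 1` this gives two rows of three entries: the first row masks only the last column, and the second row masks nothing.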

File diff suppressed because it is too large.

@@ -0,0 +1,502 @@
//! Entry point for `neuron --worker`.
//!
//! The worker reads one newline-delimited JSON `WorkerRequest` from
//! stdin per loop iteration, dispatches synchronously, and writes
//! exactly one `WorkerResponse` JSON line to stdout. tracing goes to
//! stderr so it doesn't collide with the RPC stream.
//!
//! NCCL operations (`Init`, `NcclSanityCheck`) and model lifecycle ops
//! (`LoadDenseShard`, `GenerateStep`, `ClearKvCache`, `UnloadModel`)
//! are real when built with the `cuda` feature; without it they reply
//! with `Error{kind="cuda_feature_not_enabled"}` so the leader can tell
//! the difference between a misconfigured build and a genuine NCCL or
//! model failure.
use anyhow::Result;
use std::collections::HashMap;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use super::nccl_state::NcclState;
use super::rpc::{WorkerRequest, WorkerResponse};
#[cfg(feature = "cuda")]
use super::tp_qwen3::TpQwen3ForCausalLM;
#[cfg(feature = "cuda")]
use super::tp_qwen3_5::TpQwen3_5ForCausalLM;
/// Worker-side discriminator over the architectures we can load via
/// `LoadDenseShard`. Mirrors `super::TpLeaderModel` on the leader
/// side — the dispatch happens on the `model_type` extracted from the
/// config JSON.
#[cfg(feature = "cuda")]
enum WorkerModel {
Qwen3(TpQwen3ForCausalLM),
Qwen3_5(TpQwen3_5ForCausalLM),
}
#[cfg(feature = "cuda")]
impl WorkerModel {
fn forward(
&mut self,
input: &candle_core::Tensor,
offset: usize,
) -> candle_core::Result<candle_core::Tensor> {
match self {
WorkerModel::Qwen3(m) => m.forward(input, offset),
WorkerModel::Qwen3_5(m) => m.forward(input, offset),
}
}
fn clear_kv_cache(&mut self) {
match self {
WorkerModel::Qwen3(m) => m.clear_kv_cache(),
WorkerModel::Qwen3_5(m) => m.clear_kv_cache(),
}
}
fn device(&self) -> &candle_core::Device {
match self {
WorkerModel::Qwen3(m) => m.device(),
WorkerModel::Qwen3_5(m) => m.device(),
}
}
}
#[derive(Debug, Clone, Copy)]
pub struct WorkerConfig {
pub rank: u32,
pub world_size: u32,
pub cuda_device: u32,
}
/// Drive the worker RPC loop until `Shutdown` or EOF on stdin.
pub async fn run(config: WorkerConfig) -> Result<()> {
tracing::info!(
rank = config.rank,
world_size = config.world_size,
cuda_device = config.cuda_device,
"tp worker starting"
);
let mut state = WorkerState::new(config);
let stdin = tokio::io::stdin();
let mut reader = BufReader::new(stdin).lines();
let mut stdout = tokio::io::stdout();
while let Some(line) = reader.next_line().await? {
if line.trim().is_empty() {
continue;
}
let req: WorkerRequest = match serde_json::from_str(&line) {
Ok(r) => r,
Err(e) => {
let resp = WorkerResponse::Error {
kind: "bad_request".into(),
message: format!("parse {line:?}: {e}"),
};
write_response(&mut stdout, &resp).await?;
continue;
}
};
let resp = state.handle(req).await;
let is_bye = matches!(resp, WorkerResponse::Bye);
write_response(&mut stdout, &resp).await?;
if is_bye {
break;
}
}
tracing::info!(rank = config.rank, "tp worker exiting");
Ok(())
}
async fn write_response(stdout: &mut tokio::io::Stdout, resp: &WorkerResponse) -> Result<()> {
let mut line = serde_json::to_string(resp)?;
line.push('\n');
stdout.write_all(line.as_bytes()).await?;
stdout.flush().await?;
Ok(())
}
/// One rank's local state. Owns the rank's NCCL communicator (via
/// `NcclState`) and the rank's shard of every loaded model.
struct WorkerState {
config: WorkerConfig,
nccl: NcclState,
/// Loaded model shards keyed by `model_id`. Each entry wraps the
/// rank's TP architecture handle (Qwen3 or Qwen3-Next) — the
/// column/row-parallel layers hold an `Arc<Comm>` cloned from
/// `nccl`. Cuda-only: the underlying types reference cudarc types
/// that don't exist without the cuda feature.
#[cfg(feature = "cuda")]
models: HashMap<String, WorkerModel>,
/// Placeholder so the non-cuda build keeps the same field name set
/// and `WorkerState::new` reads the same on both.
#[cfg(not(feature = "cuda"))]
#[allow(dead_code)]
models: HashMap<String, ()>,
}
impl WorkerState {
fn new(config: WorkerConfig) -> Self {
Self {
config,
nccl: NcclState::new(),
models: HashMap::new(),
}
}
async fn handle(&mut self, req: WorkerRequest) -> WorkerResponse {
match req {
WorkerRequest::Ping => WorkerResponse::Pong {
rank: self.config.rank,
world_size: self.config.world_size,
cuda_device: self.config.cuda_device,
},
WorkerRequest::Init { comm_id } => self.nccl.init(self.config, &comm_id),
WorkerRequest::NcclSanityCheck => self.nccl.sanity_check(),
WorkerRequest::LoadDenseShard {
model_id,
config_json,
safetensors_paths,
quant,
} => self.handle_load_dense_shard(model_id, config_json, safetensors_paths, quant),
WorkerRequest::GenerateStep {
model_id,
tokens,
offset,
} => self.handle_generate_step(&model_id, tokens, offset),
WorkerRequest::ClearKvCache { model_id } => self.handle_clear_kv_cache(&model_id),
WorkerRequest::UnloadModel { model_id } => self.handle_unload_model(&model_id),
WorkerRequest::Shutdown => WorkerResponse::Bye,
}
}
#[cfg(feature = "cuda")]
fn handle_load_dense_shard(
&mut self,
model_id: String,
config_json: String,
safetensors_paths: Vec<String>,
quant: Option<String>,
) -> WorkerResponse {
use crate::harness::arch::qwen3_5 as qwen3_5_arch;
use candle_core::{DType, Device};
use candle_nn::var_builder::ShardedSafeTensors;
use candle_transformers::models::qwen3 as qwen3_dense;
use std::path::PathBuf;
let quant_dtype = match parse_quant_string(quant.as_deref()) {
Ok(q) => q,
Err(e) => {
return WorkerResponse::Error {
kind: "bad_request".into(),
message: format!("parse quant: {e}"),
};
}
};
if self.models.contains_key(&model_id) {
return WorkerResponse::Error {
kind: "already_loaded".into(),
message: format!("model '{model_id}' already loaded on this rank"),
};
}
let comm = match self.nccl.comm() {
Some(c) => c,
None => {
return WorkerResponse::Error {
kind: "nccl_not_initialised".into(),
message: "LoadDenseShard requires Init to have completed first".into(),
};
}
};
// Peek at model_type so we know which architecture to build.
let model_type = serde_json::from_str::<serde_json::Value>(&config_json)
.ok()
.as_ref()
.and_then(|v| v.get("model_type"))
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let device = match Device::new_cuda(self.config.cuda_device as usize) {
Ok(d) => d,
Err(e) => {
return WorkerResponse::Error {
kind: "cuda_unavailable".into(),
message: format!("Device::new_cuda({}) failed: {e}", self.config.cuda_device),
};
}
};
let paths: Vec<PathBuf> = safetensors_paths.into_iter().map(PathBuf::from).collect();
// SAFETY: same invariant as the single-GPU dense path — the HF
// cache files are treated as immutable while the mmap is held.
let vb = match unsafe { ShardedSafeTensors::var_builder(&paths, DType::BF16, &device) } {
Ok(v) => v,
Err(e) => {
return WorkerResponse::Error {
kind: "load_failed".into(),
message: format!("ShardedSafeTensors::var_builder: {e}"),
};
}
};
// Separate mmap of the same paths for the direct fused-region
// loader in `fused_load`. Linux's page cache shares the
// underlying pages between the two mmaps; the cost is one
// extra set of safetensors-header parses.
let mmap = match unsafe { candle_core::safetensors::MmapedSafetensors::multi(&paths) } {
Ok(m) => m,
Err(e) => {
return WorkerResponse::Error {
kind: "load_failed".into(),
message: format!("MmapedSafetensors::multi: {e}"),
};
}
};
let loaded = match model_type.as_str() {
"qwen3" => {
let cfg: qwen3_dense::Config = match serde_json::from_str(&config_json) {
Ok(c) => c,
Err(e) => {
return WorkerResponse::Error {
kind: "bad_request".into(),
message: format!("parse Qwen3 Config JSON: {e}"),
};
}
};
match TpQwen3ForCausalLM::load(
&cfg,
&vb,
self.config.rank,
self.config.world_size,
comm,
) {
Ok(m) => WorkerModel::Qwen3(m),
Err(e) => {
return WorkerResponse::Error {
kind: "load_failed".into(),
message: format!("TpQwen3ForCausalLM::load: {e:#}"),
};
}
}
}
"qwen3_5" => {
let cfg: qwen3_5_arch::Config = match serde_json::from_str(&config_json) {
Ok(c) => c,
Err(e) => {
return WorkerResponse::Error {
kind: "bad_request".into(),
message: format!("parse Qwen3-Next Config JSON: {e}"),
};
}
};
match TpQwen3_5ForCausalLM::load(
cfg,
&vb,
&mmap,
self.config.rank,
self.config.world_size,
comm,
quant_dtype,
) {
Ok(m) => WorkerModel::Qwen3_5(m),
Err(e) => {
return WorkerResponse::Error {
kind: "load_failed".into(),
message: format!("TpQwen3_5ForCausalLM::load: {e:#}"),
};
}
}
}
other => {
return WorkerResponse::Error {
kind: "unsupported_arch".into(),
message: format!(
"worker: unsupported model_type '{other}' (supported: qwen3, qwen3_5)"
),
};
}
};
self.models.insert(model_id.clone(), loaded);
tracing::info!(
rank = self.config.rank,
model = %model_id,
model_type = %model_type,
"loaded TP shard"
);
WorkerResponse::LoadDenseShardOk
}
#[cfg(not(feature = "cuda"))]
fn handle_load_dense_shard(
&mut self,
_model_id: String,
_config_json: String,
_safetensors_paths: Vec<String>,
_quant: Option<String>,
) -> WorkerResponse {
WorkerResponse::Error {
kind: "cuda_feature_not_enabled".into(),
message: "LoadDenseShard requires --features cuda".into(),
}
}
#[cfg(feature = "cuda")]
fn handle_generate_step(
&mut self,
model_id: &str,
tokens: Vec<u32>,
offset: usize,
) -> WorkerResponse {
use candle_core::Tensor;
let Some(model) = self.models.get_mut(model_id) else {
return WorkerResponse::Error {
kind: "model_not_loaded".into(),
message: format!("model '{model_id}' not loaded on rank {}", self.config.rank),
};
};
let device = model.device().clone();
let input = match Tensor::new(tokens.as_slice(), &device).and_then(|t| t.unsqueeze(0)) {
Ok(t) => t,
Err(e) => {
return WorkerResponse::Error {
kind: "forward_failed".into(),
message: format!("build input tensor: {e}"),
};
}
};
let start = std::time::Instant::now();
tracing::debug!(
rank = self.config.rank,
model = %model_id,
tokens = tokens.len(),
offset,
"worker GenerateStep: forward starting"
);
// Drop the resulting logits — the leader uses its own copy from
// rank 0. The forward's value here is the NCCL collectives it
// issues, which let the leader's rank-0 forward make progress.
if let Err(e) = model.forward(&input, offset) {
tracing::warn!(
rank = self.config.rank,
model = %model_id,
elapsed_ms = start.elapsed().as_millis(),
error = %e,
"worker GenerateStep: forward failed"
);
return WorkerResponse::Error {
kind: "forward_failed".into(),
message: format!("TP forward: {e}"),
};
}
tracing::debug!(
rank = self.config.rank,
model = %model_id,
elapsed_ms = start.elapsed().as_millis(),
"worker GenerateStep: forward done"
);
WorkerResponse::GenerateStepOk
}
#[cfg(not(feature = "cuda"))]
fn handle_generate_step(
&mut self,
_model_id: &str,
_tokens: Vec<u32>,
_offset: usize,
) -> WorkerResponse {
WorkerResponse::Error {
kind: "cuda_feature_not_enabled".into(),
message: "GenerateStep requires --features cuda".into(),
}
}
#[cfg(feature = "cuda")]
fn handle_clear_kv_cache(&mut self, model_id: &str) -> WorkerResponse {
let Some(model) = self.models.get_mut(model_id) else {
return WorkerResponse::Error {
kind: "model_not_loaded".into(),
message: format!("model '{model_id}' not loaded on rank {}", self.config.rank),
};
};
model.clear_kv_cache();
WorkerResponse::KvCacheCleared
}
#[cfg(not(feature = "cuda"))]
fn handle_clear_kv_cache(&mut self, _model_id: &str) -> WorkerResponse {
WorkerResponse::Error {
kind: "cuda_feature_not_enabled".into(),
message: "ClearKvCache requires --features cuda".into(),
}
}
#[cfg(feature = "cuda")]
fn handle_unload_model(&mut self, model_id: &str) -> WorkerResponse {
if self.models.remove(model_id).is_none() {
return WorkerResponse::Error {
kind: "model_not_loaded".into(),
message: format!("model '{model_id}' not loaded on rank {}", self.config.rank),
};
}
tracing::info!(rank = self.config.rank, model = %model_id, "unloaded TP shard");
WorkerResponse::Unloaded
}
#[cfg(not(feature = "cuda"))]
fn handle_unload_model(&mut self, _model_id: &str) -> WorkerResponse {
WorkerResponse::Error {
kind: "cuda_feature_not_enabled".into(),
message: "UnloadModel requires --features cuda".into(),
}
}
}
/// Parse a `ModelSpec.quant` string into a `GgmlDType`. Accepts the
/// common ggml format names. `None` and `Some("")` both map to
/// "no quantization".
///
/// Supported: `q4_0`, `q4_1`, `q5_0`, `q5_1`, `q8_0`, `q8_1`,
/// `q2k`/`q2_k`, `q3k`/`q3_k`, `q4k`/`q4_k`, `q5k`/`q5_k`,
/// `q6k`/`q6_k`, `q8k`/`q8_k`, `f16`, `bf16`, `f32`. The GGUF-style
/// tags `q4_k_m` and `q5_k_m` are accepted as aliases for `q4k` and
/// `q5k`. Matching is case-insensitive and ignores underscores.
#[cfg(feature = "cuda")]
pub(crate) fn parse_quant_string(
s: Option<&str>,
) -> anyhow::Result<Option<candle_core::quantized::GgmlDType>> {
use candle_core::quantized::GgmlDType;
let s = match s {
Some(s) if !s.is_empty() => s,
_ => return Ok(None),
};
let normalised = s.to_ascii_lowercase().replace('_', "");
let dtype = match normalised.as_str() {
"q40" => GgmlDType::Q4_0,
"q41" => GgmlDType::Q4_1,
"q50" => GgmlDType::Q5_0,
"q51" => GgmlDType::Q5_1,
"q80" => GgmlDType::Q8_0,
"q81" => GgmlDType::Q8_1,
"q2k" => GgmlDType::Q2K,
"q3k" => GgmlDType::Q3K,
"q4k" | "q4km" => GgmlDType::Q4K,
"q5k" | "q5km" => GgmlDType::Q5K,
"q6k" => GgmlDType::Q6K,
"q8k" => GgmlDType::Q8K,
"f16" => GgmlDType::F16,
"bf16" => GgmlDType::BF16,
"f32" => GgmlDType::F32,
other => anyhow::bail!(
"unknown quant '{other}' (expected one of: q4_0, q4_1, q5_0, q5_1, q8_0, \
q8_1, q2k, q3k, q4k, q5k, q6k, q8k, f16, bf16, f32)"
),
};
Ok(Some(dtype))
}
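The normalisation in `parse_quant_string` (lowercase, strip underscores, then match) can be tried without candle. The sketch below is an illustration under stated assumptions: `Quant` is a hypothetical four-variant stand-in for `GgmlDType`, and the error type is a plain `String` rather than `anyhow::Error`.

```rust
/// Hypothetical stand-in for candle's `GgmlDType`, reduced to four variants
/// so the sketch compiles without any dependency.
#[allow(non_camel_case_types, dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Quant {
    Q4_0,
    Q4K,
    Q8_0,
    Bf16,
}

/// Same normalisation as `parse_quant_string`: lowercase, drop underscores,
/// then match. `None` and `Some("")` both mean "no quantization".
fn parse_quant(s: Option<&str>) -> Result<Option<Quant>, String> {
    let s = match s {
        Some(s) if !s.is_empty() => s,
        _ => return Ok(None),
    };
    let normalised = s.to_ascii_lowercase().replace('_', "");
    let quant = match normalised.as_str() {
        "q40" => Quant::Q4_0,
        "q4k" | "q4km" => Quant::Q4K,
        "q80" => Quant::Q8_0,
        "bf16" => Quant::Bf16,
        other => return Err(format!("unknown quant '{other}'")),
    };
    Ok(Some(quant))
}
```

Under this scheme a catalogue tag such as `"Q4_K_M"` normalises to `q4km` and resolves to the `Q4K` variant, while an unknown tag surfaces as an error instead of silently falling back to the dense path.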


@@ -1,5 +1,6 @@
 pub mod api;
 pub mod config;
+pub mod cuda;
 pub mod discovery;
 pub mod harness;
 pub mod health;


@@ -1,21 +1,66 @@
-use anyhow::Result;
+use anyhow::{Context, Result};
 use clap::Parser;
-use neuron::{api, config::NeuronConfig, discovery, harness::HarnessRegistry, health, startup};
+use neuron::{
+    api,
+    config::NeuronConfig,
+    discovery,
+    harness::{HarnessRegistry, tp},
+    health, startup,
+};
 use std::sync::Arc;
 use std::time::Instant;
 use tokio::sync::RwLock;
 use tracing_subscriber::EnvFilter;
+/// Top-level CLI. The same binary runs as either the public neuron
+/// daemon (default), a tensor-parallel worker subprocess (when
+/// `--worker` is set, spawned by the leader on the same host), or a
+/// one-shot TP NCCL handshake check (when `--tp-smoke` is set).
 #[derive(Parser)]
 #[command(name = "neuron")]
 #[command(about = "Per-node daemon for cortex inference clusters")]
 #[command(version)]
 struct Args {
-    /// Port to listen on (overrides config file).
+    /// Run in tensor-parallel worker mode. The leader process spawns
+    /// one of these per non-zero NCCL rank and drives it over
+    /// newline-delimited JSON on stdin/stdout. Worker mode skips
+    /// discovery, the HTTP listener, and the health poller — it's a
+    /// pure RPC loop.
+    #[arg(long, default_value_t = false)]
+    worker: bool,
+    /// Run a one-shot TP smoke test: spawn `--tp-size - 1` worker
+    /// subprocesses on `--cuda-devices`, build the NCCL communicator,
+    /// run an `AllReduce` sanity check across every rank, and exit.
+    /// Used to validate the TP plumbing in isolation from model load
+    /// and inference. Diagnostic-only — not exposed through the daemon
+    /// HTTP API.
+    #[arg(long, default_value_t = false)]
+    tp_smoke: bool,
+    /// NCCL rank for worker mode. Ignored when `--worker` is not set.
+    #[arg(long, default_value_t = 0)]
+    rank: u32,
+    /// Total NCCL world size for worker mode or TP smoke mode.
+    #[arg(long, default_value_t = 1)]
+    tp_size: u32,
+    /// CUDA device index for worker mode. Ignored when `--worker` is
+    /// not set.
+    #[arg(long, default_value_t = 0)]
+    cuda_device: u32,
+    /// Comma-separated CUDA device indices for TP smoke mode (one per
+    /// rank, starting with rank 0). Must have `tp_size` entries.
+    #[arg(long, value_delimiter = ',')]
+    cuda_devices: Vec<u32>,
+    /// Port to listen on (overrides config file). Daemon mode only.
     #[arg(short, long)]
     port: Option<u16>,
-    /// Path to the neuron config file.
+    /// Path to the neuron config file. Daemon mode only.
     #[arg(short, long, default_value = "neuron.toml")]
     config: String,
 }
@@ -23,6 +68,7 @@ struct Args {
 #[tokio::main]
 async fn main() -> Result<()> {
     tracing_subscriber::fmt()
+        .with_writer(std::io::stderr)
         .with_env_filter(
             EnvFilter::try_from_default_env()
                 .unwrap_or_else(|_| EnvFilter::new("info,neuron=debug")),
@@ -31,6 +77,78 @@ async fn main() -> Result<()> {
     let args = Args::parse();
+    if args.worker {
+        return tp::worker::run(tp::worker::WorkerConfig {
+            rank: args.rank,
+            world_size: args.tp_size,
+            cuda_device: args.cuda_device,
+        })
+        .await;
+    }
+    if args.tp_smoke {
+        return tp_smoke(args.tp_size, args.cuda_devices).await;
+    }
+    daemon(args).await
+}
+/// One-shot tensor-parallel handshake. Spawns N-1 worker subprocesses
+/// (rank 0 stays in this process), builds the NCCL communicator across
+/// the full world, runs an AllReduce sanity check, and shuts everyone
+/// down. Output is plain log lines on stderr + a final summary on
+/// stdout in `key=value` form so an outer script can parse it.
+async fn tp_smoke(tp_size: u32, cuda_devices: Vec<u32>) -> Result<()> {
+    if tp_size < 2 {
+        anyhow::bail!("--tp-size must be at least 2 (got {tp_size})");
+    }
+    if cuda_devices.len() as u32 != tp_size {
+        anyhow::bail!(
+            "--cuda-devices must list exactly {tp_size} entries (got {})",
+            cuda_devices.len()
+        );
+    }
+    let exe = std::env::current_exe().context("resolve current_exe for worker spawn")?;
+    let leader_device = cuda_devices[0];
+    tracing::info!(
+        tp_size,
+        ?cuda_devices,
+        binary = %exe.display(),
+        "tp-smoke: spawning worker pool"
+    );
+    let mut pool = tp::WorkerPool::spawn(&exe, tp_size, &cuda_devices).await?;
+    tracing::info!("tp-smoke: pinging every worker");
+    let pongs = pool.ping_all().await?;
+    for p in &pongs {
+        tracing::info!(?p, "tp-smoke: pong");
+    }
+    tracing::info!(leader_device, "tp-smoke: initialising NCCL");
+    pool.init_nccl(leader_device).await?;
+    tracing::info!("tp-smoke: running AllReduce sanity check");
+    pool.nccl_sanity_check().await?;
+    tracing::info!("tp-smoke: shutting down pool");
+    pool.shutdown().await?;
+    println!("status=ok");
+    println!("tp_size={tp_size}");
+    println!(
+        "cuda_devices={}",
+        cuda_devices
+            .iter()
+            .map(|d| d.to_string())
+            .collect::<Vec<_>>()
+            .join(",")
+    );
+    Ok(())
+}
+async fn daemon(args: Args) -> Result<()> {
     let cfg = NeuronConfig::load(&args.config).unwrap_or_else(|e| {
         tracing::warn!(path = %args.config, error = %e, "config not found, using defaults");
         NeuronConfig::default()
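The argument checks in `tp_smoke` and the worker flags defined on `Args` can be exercised in isolation. The sketch below is a rough illustration only: `validate_tp_args` mirrors the two `bail!` conditions, while `worker_argvs` is a hypothetical guess at how a pool might derive one command line per non-zero rank (the real `WorkerPool::spawn` is not shown in this diff).

```rust
/// Mirrors the two checks at the top of `tp_smoke`, with `String` errors
/// standing in for `anyhow::bail!`.
fn validate_tp_args(tp_size: u32, cuda_devices: &[u32]) -> Result<(), String> {
    if tp_size < 2 {
        return Err(format!("--tp-size must be at least 2 (got {tp_size})"));
    }
    if cuda_devices.len() != tp_size as usize {
        return Err(format!(
            "--cuda-devices must list exactly {tp_size} entries (got {})",
            cuda_devices.len()
        ));
    }
    Ok(())
}

/// Assumption: one argv per non-zero rank, because rank 0 stays in the
/// leader process. Callers should run `validate_tp_args` first so the
/// `cuda_devices[rank]` lookup cannot go out of bounds.
fn worker_argvs(tp_size: u32, cuda_devices: &[u32]) -> Vec<Vec<String>> {
    (1..tp_size)
        .map(|rank| {
            vec![
                "--worker".to_string(),
                "--rank".to_string(),
                rank.to_string(),
                "--tp-size".to_string(),
                tp_size.to_string(),
                "--cuda-device".to_string(),
                cuda_devices[rank as usize].to_string(),
            ]
        })
        .collect()
}
```

For `tp_size = 3` and devices `[0, 1, 2]` this yields two argument vectors, for ranks 1 and 2, which matches the "N-1 worker subprocesses" wording in the doc comment.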


@@ -0,0 +1,145 @@
//! Stage 7a-i: confirm the TP worker subprocess lifecycle round-trips.
//!
//! Spawns two worker subprocesses via the leader→worker stdio RPC,
//! pings each, and cleanly shuts them down. No CUDA required —
//! `Init` and `NcclSanityCheck` are stubbed in 7a-i, so this test
//! runs on any host the workspace builds on.
use neuron::harness::tp::{WorkerPool, rpc::WorkerResponse};
/// Path to the neuron binary built by cargo for this test process.
/// Cargo populates `CARGO_BIN_EXE_neuron` at compile time for
/// integration tests of a package's binaries; the production path in
/// main.rs resolves the binary with `std::env::current_exe()` instead.
const NEURON_BIN: &str = env!("CARGO_BIN_EXE_neuron");
/// Two ranks, so one subprocess is spawned: rank 0 is in-process and
/// rank 1 is the child. Verify the spawned worker responds to Ping
/// with its own identity, then shut it down cleanly.
#[tokio::test]
async fn test_spawn_ping_shutdown() {
// cuda_devices: rank 0 → device 0 (leader, unused here),
// rank 1 → device 1 (worker; not actually opened in 7a-i).
let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 2, &[0, 1])
.await
.expect("spawn worker pool");
let pongs = pool.ping_all().await.expect("ping all workers");
assert_eq!(pongs.len(), 1, "expected one Pong (rank 1 only)");
match &pongs[0] {
WorkerResponse::Pong {
rank,
world_size,
cuda_device,
} => {
assert_eq!(*rank, 1);
assert_eq!(*world_size, 2);
assert_eq!(*cuda_device, 1);
}
other => panic!("expected Pong, got {other:?}"),
}
pool.shutdown().await.expect("clean shutdown");
}
/// Three ranks (two spawned workers): exercise the loop in `ping_all` / `shutdown`.
#[tokio::test]
async fn test_spawn_three_workers() {
let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 3, &[0, 1, 2])
.await
.expect("spawn worker pool");
let pongs = pool.ping_all().await.expect("ping all workers");
assert_eq!(pongs.len(), 2, "expected two Pongs (ranks 1 and 2)");
for (i, resp) in pongs.iter().enumerate() {
match resp {
WorkerResponse::Pong {
rank,
world_size,
cuda_device,
} => {
let expected_rank = (i + 1) as u32;
assert_eq!(*rank, expected_rank);
assert_eq!(*world_size, 3);
assert_eq!(*cuda_device, expected_rank);
}
other => panic!("expected Pong, got {other:?}"),
}
}
pool.shutdown().await.expect("clean shutdown");
}
/// 7a-ii: without the cuda feature, Init must fail with a clear
/// `cuda_feature_not_enabled` marker rather than silently succeeding.
/// This is the local-dev-box test; the real NCCL handshake is exercised
/// by `tp_worker_lifecycle_cuda.rs` (gated on `cuda-integration`).
#[tokio::test]
async fn test_init_returns_cuda_feature_not_enabled_without_cuda() {
use neuron::harness::tp::rpc::WorkerRequest;
use std::process::Stdio;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::Command;
// Spawn a single worker by hand to send Init directly (the pool's
// public API doesn't expose Init yet — that lands in 7a-ii).
let mut child = Command::new(NEURON_BIN)
.arg("--worker")
.arg("--rank")
.arg("1")
.arg("--tp-size")
.arg("2")
.arg("--cuda-device")
.arg("1")
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::null())
.kill_on_drop(true)
.spawn()
.expect("spawn worker");
let mut stdin = child.stdin.take().expect("stdin");
let stdout = child.stdout.take().expect("stdout");
let mut lines = BufReader::new(stdout).lines();
let req = WorkerRequest::Init {
comm_id: "ff".repeat(128),
};
let mut payload = serde_json::to_string(&req).unwrap();
payload.push('\n');
stdin.write_all(payload.as_bytes()).await.unwrap();
stdin.flush().await.unwrap();
let reply = lines
.next_line()
.await
.expect("read line")
.expect("got line");
let resp: WorkerResponse = serde_json::from_str(&reply).expect("parse reply");
match resp {
WorkerResponse::Error { kind, .. } => {
#[cfg(feature = "cuda")]
{
// With cuda enabled the response depends on whether
// CUDA hardware is actually present. Accept either
// the success contract or a real NCCL failure.
let _ = kind;
}
#[cfg(not(feature = "cuda"))]
assert_eq!(kind, "cuda_feature_not_enabled");
}
WorkerResponse::InitOk => {
// Real NCCL succeeded — only possible with cuda feature
// AND a working NCCL stack AND another rank actually
// joining. Don't fail; just acknowledge.
#[cfg(not(feature = "cuda"))]
panic!("InitOk without cuda feature is impossible");
}
other => panic!("expected Error or InitOk, got {other:?}"),
}
// Clean shutdown.
stdin.write_all(b"{\"op\":\"shutdown\"}\n").await.unwrap();
stdin.flush().await.unwrap();
let _ = lines.next_line().await; // Bye
let _ = child.wait().await;
}
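The request/response discipline these tests exercise (one request line in, exactly one response line out, stop after the shutdown reply) can be sketched with blocking standard-library I/O. This is a simplified illustration, not the worker's implementation: `serve` is a hypothetical name, and the plain words `ping`, `shutdown`, `pong` and `bye` stand in for the JSON `WorkerRequest` / `WorkerResponse` messages.

```rust
use std::io::{BufRead, Write};

/// Hypothetical line-oriented RPC loop: read one request per line, write
/// exactly one reply line, flush, and stop once the shutdown reply is sent.
fn serve<R: BufRead, W: Write>(input: R, mut output: W) -> std::io::Result<()> {
    for line in input.lines() {
        let line = line?;
        // Blank lines are skipped without producing a reply.
        if line.trim().is_empty() {
            continue;
        }
        let reply = match line.trim() {
            "ping" => "pong",
            "shutdown" => "bye",
            // Unparseable input still gets exactly one reply, so the peer
            // never blocks waiting for a line that will not arrive.
            _ => "error:bad_request",
        };
        writeln!(output, "{reply}")?;
        // Flush per reply: the peer is waiting on this line before it
        // sends the next request.
        output.flush()?;
        if reply == "bye" {
            break;
        }
    }
    Ok(())
}
```

Because the loop is generic over `BufRead` and `Write`, it can be driven from an in-memory `Cursor` and a `Vec<u8>` in a unit test, with no subprocess involved.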


@@ -0,0 +1,43 @@
//! Stage 7a-ii: real NCCL handshake across the worker pool.
//!
//! Gated behind the `cuda-integration` feature because it requires
//! libnccl AND multiple CUDA devices on the running host. Run on
//! beast (2× RTX 5090) via:
//!
//! cargo test -p neuron --features cuda-integration \
//! --test tp_worker_lifecycle_cuda
//!
//! Steps: spawn N-1 workers, call `init_nccl`, run `nccl_sanity_check`
//! (every rank `all_reduce`s `1u32` with Sum; expected total =
//! world_size), shut down cleanly.
#![cfg(feature = "cuda-integration")]
use neuron::harness::tp::WorkerPool;
const NEURON_BIN: &str = env!("CARGO_BIN_EXE_neuron");
#[tokio::test]
async fn test_init_and_sanity_check_two_ranks() {
let _ = tracing_subscriber::fmt()
.with_test_writer()
.with_env_filter("info,neuron=debug")
.try_init();
// 2 ranks: leader = rank 0 on device 0, worker = rank 1 on device 1.
let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 2, &[0, 1])
.await
.expect("spawn worker pool");
pool.ping_all().await.expect("ping all workers");
pool.init_nccl(0)
.await
.expect("init_nccl: NCCL handshake across all ranks");
pool.nccl_sanity_check()
.await
.expect("nccl_sanity_check: observed_sum == world_size on all ranks");
pool.shutdown().await.expect("clean shutdown");
}


@@ -2,28 +2,49 @@
 #
 # Copy to /etc/cortex/models.toml and adjust for your environment.
 # Describes how to serve each model. Cortex matches these profiles
-# against discovered neuron topologies for placement decisions.
+# against discovered neuron topologies for placement decisions; the
+# resulting `(catalogue × topology)` set is what `GET /v1/models`
+# returns and what the router can cold-load on demand.
+#
+# Field reference:
+#   id                 - HuggingFace model id, exact match.
+#   harness            - which engine handles inference (currently "candle").
+#   quant              - GGUF quantisation tag for the file in the HF repo
+#                        (e.g. "Q4_K_M"). Omit/empty for the dense
+#                        safetensors path. TP requires dense.
+#   vram_mb            - rough estimate; advisory only, not enforced.
+#   min_devices        - GPU count this profile needs. TP profiles use
+#                        the same value as the tensor-parallel size.
+#   min_device_vram_mb - each device must meet this VRAM floor for the
+#                        neuron to be considered "feasible".
+#   pinned_on          - optional whitelist of neuron names. Non-empty
+#                        narrows feasibility to just those neurons and
+#                        protects the model from LRU eviction there.
+# Tensor-parallel target — needs a neuron with at least 2 large GPUs.
+# The example pins to a specific neuron name; adjust or remove the
+# pinned_on entry for your own fleet.
 [[models]]
-id = "your-org/large-model"
+id = "Qwen/Qwen3.6-27B"
+harness = "candle"
+vram_mb = 54000
+min_devices = 2
+min_device_vram_mb = 24000
+pinned_on = ["your-multi-gpu-neuron"]
+# Mid-size dense model — fits on any single GPU with ≥16 GB VRAM.
+[[models]]
+id = "Qwen/Qwen3-8B"
+harness = "candle"
+vram_mb = 18000
+min_devices = 1
+min_device_vram_mb = 16000
+# Small GGUF quantised — runs on any small GPU.
+[[models]]
+id = "unsloth/Qwen3-0.6B-GGUF"
 harness = "candle"
 quant = "Q4_K_M"
-vram_mb = 19000
-min_devices = 2
-min_device_vram_mb = 10000
-pinned_on = ["gpu-large"]
-[[models]]
-id = "your-org/medium-model"
-harness = "candle"
-quant = "Q6_K"
-vram_mb = 12000
-min_devices = 1
-pinned_on = ["gpu-medium"]
-[[models]]
-id = "your-org/embedding-model"
-harness = "candle"
-quant = "Q8_0"
-vram_mb = 8000
+vram_mb = 500
 min_devices = 1
+min_device_vram_mb = 4000


@@ -19,8 +19,21 @@ name = "candle"
 # Optional tuning for the candle harness.
 [harness.candle]
-# HuggingFace cache directory for model weights. When unset, hf-hub's
-# default (~/.cache/huggingface) is used.
+# HuggingFace cache directory for model weights.
+#
+# Resolution order (first hit wins):
+#   1. `hf_cache` here in this file.
+#   2. `HF_HUB_CACHE` env var — same convention as the Python
+#      `huggingface_hub` library, so an existing cache directory shared
+#      with other tooling can be reused without per-tool config.
+#   3. `HF_HOME` env var (cache appended as `$HF_HOME/hub`).
+#   4. hf-hub's default (`~/.cache/huggingface/hub`).
+#
+# For per-host overrides (e.g. one neuron has an SSD with prefetched
+# weights), prefer a systemd drop-in over editing this file:
+#   /etc/systemd/system/neuron.service.d/local.conf:
+#     [Service]
+#     Environment=HF_HUB_CACHE=/archive/hf-cache
 # hf_cache = "/var/lib/neuron/hf-cache"
 # -- Default models ----------------------------------------------------------
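The resolution order documented above is a first-hit-wins fallback chain, which can be written as a small pure function. This is a sketch under assumptions rather than the harness's actual code: `resolve_hf_cache` and its parameters are hypothetical names, the environment values are passed in explicitly instead of being read with `std::env::var`, and an empty string is treated as "set" because the comment does not say otherwise.

```rust
use std::path::PathBuf;

/// Hypothetical resolver for the documented precedence:
/// config `hf_cache`, then `HF_HUB_CACHE`, then `$HF_HOME/hub`,
/// then the default under the user's home directory.
fn resolve_hf_cache(
    config_hf_cache: Option<&str>,
    hf_hub_cache: Option<&str>,
    hf_home: Option<&str>,
    home_dir: &str,
) -> PathBuf {
    if let Some(path) = config_hf_cache {
        return PathBuf::from(path);
    }
    if let Some(path) = hf_hub_cache {
        return PathBuf::from(path);
    }
    if let Some(home) = hf_home {
        // HF_HOME names the parent directory; the cache lives in `hub/`.
        return PathBuf::from(home).join("hub");
    }
    PathBuf::from(home_dir)
        .join(".cache")
        .join("huggingface")
        .join("hub")
}
```

Passing the inputs as arguments keeps the precedence testable without mutating process-wide environment variables; a real caller would feed in `std::env::var("HF_HUB_CACHE").ok()` and similar.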


@@ -6,7 +6,11 @@
 #
 # Required defines at rpmbuild time:
 #   cortex_version     e.g. "0.1.16"
-#   cortex_prerelease  e.g. "0.1.20260518gitabcdef0" (used as Release)
+#   cortex_prerelease  e.g. "0.1.20260518140530.gitabcdef0"
+#                            ^^^^^^^^^^^^^^^^^^ ^^^^^^^^
+#                            commit time (sec)  commit sha
+#                            (used as Release; the timestamp prefix
+#                            keeps same-day builds strictly ordered.)
 %global _build_id_links none
 %global debug_package %{nil}


@@ -9,7 +9,11 @@
 #   neuron_version     e.g. "0.1.16"
 #   neuron_flavour     e.g. "ada", "blackwell" — matches the CI build
 #                      matrix's compute_cap label.
-#   neuron_prerelease  e.g. "0.1.20260518gitabcdef0" (used as Release)
+#   neuron_prerelease  e.g. "0.1.20260518140530.gitabcdef0"
+#                            ^^^^^^^^^^^^^^^^^^ ^^^^^^^^
+#                            commit time (sec)  commit sha
+#                            (used as Release; the timestamp prefix
+#                            keeps same-day builds strictly ordered.)
 #
 # One flavour can be installed at a time on a given host; flavour
 # packages Conflict with each other.


@@ -71,6 +71,34 @@ ensure_lair_repo() {
     fi
 }
+# Ensure libcudnn.so.9 is resolvable on the remote host so the
+# neuron binary (built with --features cudnn) doesn't fail at startup
+# with "cannot open shared object file: No such file or directory".
+#
+# Probes ldconfig first — if cuDNN was installed manually (.tar/.run
+# install), it'll be cached by ldconfig and we don't touch it.
+# Otherwise adds NVIDIA's RHEL9 CUDA repo (the Fedora 43 CUDA repo
+# doesn't ship cuDNN packages — only the RHEL9 one does) and installs
+# libcudnn9-cuda-13.
+ensure_cudnn_runtime() {
+    local host="$1"
+    if ssh "${host}" "ldconfig -p | grep -q libcudnn.so.9" 2>/dev/null; then
+        return 0
+    fi
+    echo "[${host}] installing cuDNN runtime"
+    if ! ssh "${host}" "test -f /etc/yum.repos.d/cuda-rhel9-x86_64.repo" 2>/dev/null; then
+        if ! ssh "${host}" sudo dnf config-manager addrepo \
+            --from-repofile=https://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo \
+            >/dev/null 2>&1; then
+            echo "[${host}] WARNING: failed to add rhel9 CUDA repo (proceeding anyway)"
+        fi
+    fi
+    if ! ssh "${host}" sudo dnf install -y libcudnn9-cuda-13 >/dev/null 2>&1; then
+        echo "[${host}] WARNING: failed to install libcudnn9-cuda-13"
+        echo "[${host}] neuron may fail to start; install cuDNN manually if so"
+    fi
+}
 # True when the named package needs to be installed or upgraded on the
 # remote host — either it's not present, or a newer version exists in
 # the repo. False only when the installed version is current.
@@ -94,6 +122,38 @@ needs_update() {
fi fi
} }
# True if the named package is currently installed on the remote host.
# Used to decide between `dnf install` (fresh) and `dnf upgrade` (stale):
# dnf5's `install` is a no-op when the package is already present at
# any version — it does NOT auto-upgrade to the latest available — so
# the wrong command silently leaves the host on an old build.
is_installed() {
local host="$1" pkg="$2"
ssh "${host}" "rpm -q ${pkg}" >/dev/null 2>&1
}
# Install or upgrade the named package on the remote, picking the
# right dnf verb based on the installed-or-not state. Returns 0 with
# dnf's combined stdout/stderr captured in __DNF_OUTPUT__ on success,
# and 1 with the same captured output on failure.
__DNF_OUTPUT__=""
install_or_upgrade() {
local host="$1" pkg="$2"
local cmd
if is_installed "${host}" "${pkg}"; then
cmd="upgrade"
else
cmd="install"
fi
if __DNF_OUTPUT__=$(
ssh "${host}" sudo dnf "${cmd}" --refresh --allowerasing -y "${pkg}" 2>&1
); then
return 0
else
return 1
fi
}
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# cortex (gateway) # cortex (gateway)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -108,12 +168,12 @@ if needs_update "${cortex_host}" cortex; then
# under set -e. # under set -e.
if ssh "${cortex_host}" "[ ! -f /usr/lib/systemd/system/cortex.service ] || sudo systemctl stop cortex.service"; then if ssh "${cortex_host}" "[ ! -f /usr/lib/systemd/system/cortex.service ] || sudo systemctl stop cortex.service"; then
echo "[${cortex_host}] stopped cortex service" echo "[${cortex_host}] stopped cortex service"
if dnf_output=$(ssh "${cortex_host}" sudo dnf install --refresh --allowerasing -y cortex 2>&1); then if install_or_upgrade "${cortex_host}" cortex; then
cortex_nvr=$(installed_nvr "${cortex_host}" cortex) cortex_nvr=$(installed_nvr "${cortex_host}" cortex)
echo "[${cortex_host}] installed/upgraded cortex to ${cortex_nvr}" echo "[${cortex_host}] installed/upgraded cortex to ${cortex_nvr}"
else else
echo "[${cortex_host}] failed to install/upgrade cortex:" echo "[${cortex_host}] failed to install/upgrade cortex:"
echo "${dnf_output}" | sed "s/^/[${cortex_host}] /" echo "${__DNF_OUTPUT__}" | sed "s/^/[${cortex_host}] /"
fi fi
else else
echo "[${cortex_host}] failed to stop cortex service" echo "[${cortex_host}] failed to stop cortex service"
@@ -138,6 +198,25 @@ else
echo "[${cortex_host}] failed to sync cortex.toml" echo "[${cortex_host}] failed to sync cortex.toml"
fi fi
# Sync models.toml on the same lifecycle as cortex.toml — operator-owned,
# gitignored, drives /v1/models catalogue × topology resolution.
if [[ -f "${REPO_DIR}/models.toml" ]]; then
if rsync \
--archive \
--compress \
--rsync-path 'sudo rsync' \
--chown root:root \
--chmod 644 \
"${REPO_DIR}/models.toml" \
"${cortex_host}:/etc/cortex/models.toml"; then
echo "[${cortex_host}] sync'd models.toml"
else
echo "[${cortex_host}] failed to sync models.toml"
fi
else
echo "[${cortex_host}] no local models.toml — leaving /etc/cortex/models.toml untouched"
fi
ssh "${cortex_host}" sudo systemctl daemon-reload ssh "${cortex_host}" sudo systemctl daemon-reload
if ssh "${cortex_host}" systemctl is-active --quiet cortex.service; then if ssh "${cortex_host}" systemctl is-active --quiet cortex.service; then
echo "[${cortex_host}] cortex service is active" echo "[${cortex_host}] cortex service is active"
@@ -156,6 +235,7 @@ for entry in "${neuron_entries[@]}"; do
package="helexa-neuron-${neuron_flavour}" package="helexa-neuron-${neuron_flavour}"
ensure_lair_repo "${neuron_host}" ensure_lair_repo "${neuron_host}"
ensure_cudnn_runtime "${neuron_host}"
neuron_nvr=$(installed_nvr "${neuron_host}" "${package}") neuron_nvr=$(installed_nvr "${neuron_host}" "${package}")
if needs_update "${neuron_host}" "${package}"; then if needs_update "${neuron_host}" "${package}"; then
echo "[${neuron_host}] ${package} update available (current: ${neuron_nvr})" echo "[${neuron_host}] ${package} update available (current: ${neuron_nvr})"
@@ -165,7 +245,7 @@ for entry in "${neuron_entries[@]}"; do
# bare helexa-neuron or a different flavour without manual # bare helexa-neuron or a different flavour without manual
# intervention. The Conflicts: clauses in the spec ensure # intervention. The Conflicts: clauses in the spec ensure
# only one flavour is ever resident. # only one flavour is ever resident.
if dnf_output=$(ssh "${neuron_host}" sudo dnf install --refresh --allowerasing -y "${package}" 2>&1); then if install_or_upgrade "${neuron_host}" "${package}"; then
neuron_nvr=$(installed_nvr "${neuron_host}" "${package}") neuron_nvr=$(installed_nvr "${neuron_host}" "${package}")
echo "[${neuron_host}] installed/upgraded ${package} to ${neuron_nvr}" echo "[${neuron_host}] installed/upgraded ${package} to ${neuron_nvr}"
# Ensure firewalld allows neuron port # Ensure firewalld allows neuron port
@@ -177,7 +257,7 @@ for entry in "${neuron_entries[@]}"; do
fi fi
else else
echo "[${neuron_host}] failed to install ${package}:" echo "[${neuron_host}] failed to install ${package}:"
echo "${dnf_output}" | sed "s/^/[${neuron_host}] /" echo "${__DNF_OUTPUT__}" | sed "s/^/[${neuron_host}] /"
fi fi
else else
echo "[${neuron_host}] failed to stop neuron service" echo "[${neuron_host}] failed to stop neuron service"
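The verb selection in `install_or_upgrade` can be tried locally without SSH or dnf. The following is a minimal sketch under stated assumptions: `INSTALLED_PKGS`, `fake_is_installed`, `fake_dnf`, `pick_verb_and_run` and `DNF_OUTPUT` are hypothetical stand-ins invented for illustration, not names from the deploy script. It shows the two ideas the diff relies on: choosing `upgrade` or `install` from the installed state, and capturing combined output in a global on both the success and the failure path.

```shell
#!/usr/bin/env bash
# Local sketch of the verb-selection pattern; no ssh or dnf involved.
# INSTALLED_PKGS, fake_is_installed and fake_dnf are hypothetical stubs.

INSTALLED_PKGS="cortex"   # pretend only cortex is already present

fake_is_installed() {
  local pkg="$1"
  [[ " ${INSTALLED_PKGS} " == *" ${pkg} "* ]]
}

# Stand-in for `dnf <verb> <pkg>`: prints what it would run and fails
# for a package named "broken" so the error path can be observed.
fake_dnf() {
  local verb="$1" pkg="$2"
  echo "dnf ${verb} ${pkg}"
  [[ "${pkg}" != "broken" ]]
}

DNF_OUTPUT=""
pick_verb_and_run() {
  local pkg="$1" verb
  if fake_is_installed "${pkg}"; then
    verb="upgrade"
  else
    verb="install"
  fi
  # An assignment used as the `if` condition carries the exit status of
  # the command substitution, and the output is kept either way.
  if DNF_OUTPUT=$(fake_dnf "${verb}" "${pkg}" 2>&1); then
    return 0
  else
    return 1
  fi
}

pick_verb_and_run cortex && echo "ok: ${DNF_OUTPUT}"
pick_verb_and_run helexa-neuron-ada && echo "ok: ${DNF_OUTPUT}"
pick_verb_and_run broken || echo "failed: ${DNF_OUTPUT}"
```

Because the output lands in a global rather than on stdout, the caller can print it only on failure, which is how the cortex and neuron branches in the diff use the captured dnf output.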

script/tp-smoke.sh Executable file

@@ -0,0 +1,60 @@
#!/usr/bin/env bash
#
# TP smoke test against a deployed neuron host.
#
# SSHes into the target host and runs `neuron --tp-smoke --tp-size N
# --cuda-devices ...` directly — no HTTP API involved. The smoke
# subcommand spawns N-1 worker subprocesses, joins them in an NCCL
# communicator, runs one AllReduce(Sum) of `1u32` across every rank, and
# verifies the observed sum equals world_size on every rank.
#
# This validates the lower-half of the TP stack (NCCL + IPC topology +
# subprocess lifecycle) without touching model load, inference, or HTTP.
# A failure here means the host cannot run any TP model and there is no
# point debugging the higher layers.
#
# Usage:
# script/tp-smoke.sh [host] [tp_size] [cuda_devices]
#
# Defaults:
# host = beast.hanzalova.internal (only fleet host with 2 GPUs)
# tp_size = 2
# cuda_devices = 0,1
set -euo pipefail
HOST="${1:-beast.hanzalova.internal}"
TP_SIZE="${2:-2}"
CUDA_DEVICES="${3:-0,1}"
say() { printf '[%s] %s\n' "${HOST}" "$*" >&2; }
die() { say "FAIL: $*"; exit 1; }
say "running neuron --tp-smoke --tp-size ${TP_SIZE} --cuda-devices ${CUDA_DEVICES}"
# Run as root via sudo because:
# - cuda contexts under a user account require either the nvidia
# uvm/peer devices to be world-readable or the user to be in a
# privileged group (neither is true on stock fc43);
# - the installed binary lives at /usr/bin/neuron with no setuid.
# Running through root is the simplest path that matches how
# systemd-managed neuron sees the GPUs in production.
#
# The smoke command is read-only — it allocates a transient NCCL comm
# and a 1u32 buffer per rank, then tears it all down.
if ! ssh -o BatchMode=yes "${HOST}" \
sudo /usr/bin/neuron \
--tp-smoke \
--tp-size "${TP_SIZE}" \
--cuda-devices "${CUDA_DEVICES}" 2>&1 | tee /tmp/tp-smoke-"${HOST}".log
then
die "tp-smoke exited non-zero (see /tmp/tp-smoke-${HOST}.log)"
fi
# Final stdout line is `status=ok` on success.
if grep -q '^status=ok$' /tmp/tp-smoke-"${HOST}".log; then
say "PASS — NCCL handshake + AllReduce sanity check OK across ${TP_SIZE} ranks"
exit 0
else
die "no status=ok line in output"
fi
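The `if ! ssh ... | tee ...` check in tp-smoke.sh depends on `set -o pipefail`, since otherwise the pipeline reports the status of `tee`. The sketch below illustrates that dependency in isolation; `failing_cmd` and the temporary log file are hypothetical stand-ins for the remote neuron invocation and its log.

```shell
#!/usr/bin/env bash
# Sketch: why `if ! cmd | tee log` needs `set -o pipefail`.
# failing_cmd and the temp log path are hypothetical.

failing_cmd() { echo "rank 0: boom"; return 3; }

log=$(mktemp)

# Without pipefail the pipeline's status is tee's (0), so the failure
# of the left-hand command goes unnoticed.
set +o pipefail
if failing_cmd 2>&1 | tee "${log}" >/dev/null; then
  without="masked"
else
  without="detected"
fi

# With pipefail the last non-zero status in the pipeline is reported.
set -o pipefail
if failing_cmd 2>&1 | tee "${log}" >/dev/null; then
  with="masked"
else
  with="detected"
fi

echo "without pipefail: ${without}"
echo "with pipefail: ${with}"

# The log holds the output either way, so a later grep for a marker
# line (such as `status=ok`) can run against the same file.
if grep -q '^status=ok$' "${log}"; then
  marker="present"
else
  marker="absent"
fi
echo "status=ok marker: ${marker}"
rm -f "${log}"
```

Checking both the exit status and the marker line covers two different failure modes: a non-zero exit, and a run that exits zero without printing the expected final line.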


@@ -9,38 +9,56 @@
# after pushing new neuron builds. # after pushing new neuron builds.
# #
# Usage: # Usage:
# script/validate-neuron.sh [host] [model_id] [quant] # script/validate-neuron.sh [host] [model_id] [quant] [tp_size]
# #
# Defaults: # Defaults:
# host = beast.hanzalova.internal # host = beast.hanzalova.internal
# model_id = unsloth/Qwen3-0.6B-GGUF (official Qwen3-*-GGUF repos # model_id = unsloth/Qwen3-0.6B-GGUF (official Qwen3-*-GGUF repos
# ship Q8_0 only; unsloth's mirror ships the full Q-spectrum # ship Q8_0 only; unsloth's mirror ships the full Q-spectrum
# including Q4_K_M) # including Q4_K_M)
# quant = Q4_K_M # quant = Q4_K_M (empty = dense safetensors path)
# tp_size = unset (= 1 = single-GPU; pass 2 to drive the TP path)
set -euo pipefail set -euo pipefail
HOST="${1:-beast.hanzalova.internal}" HOST="${1:-beast.hanzalova.internal}"
MODEL_ID="${2:-unsloth/Qwen3-0.6B-GGUF}" MODEL_ID="${2:-unsloth/Qwen3-0.6B-GGUF}"
QUANT="${3:-Q4_K_M}" # `${3-Q4_K_M}` (no colon) only uses the default when the arg is
# UNSET — passing an explicit empty string drives the dense path.
QUANT="${3-Q4_K_M}"
# tp_size > 1 forces the dense path (TP requires safetensors) and adds
# `tensor_parallel: N` to the load payload. The harness picks device
# indices 0..N-1 by default; override by passing NEURON_DEVICES="0,1,..."
# in the environment.
TP_SIZE="${4-1}"
PORT="${NEURON_PORT:-13131}" PORT="${NEURON_PORT:-13131}"
BASE="http://${HOST}:${PORT}" BASE="http://${HOST}:${PORT}"
# Reasoning probe — concrete, low-temperature answer that small models # Reasoning probe — concrete, low-temperature answer that small models
# can still get right. "Paris" is a strong signal of basic competence # can still get right. "Paris" is a strong signal of basic competence
# beyond gibberish. # beyond gibberish.
PROBE_PROMPT='What is the capital of France? Respond with the city name only, no punctuation.' PROBE_PROMPT='What is the capital of Georgia (Caucasus)? Respond with the city name only, no punctuation.'
EXPECT_SUBSTR='Paris' EXPECT_SUBSTR='Tbilisi'
MAX_TOKENS=32 # Qwen3 prepends <think>...</think> reasoning before the answer when the
# chat template enables thinking mode, which eats most of a small token
# budget. 256 leaves enough room for thinking + final answer.
MAX_TOKENS=256
# /models/load is synchronous — neuron blocks the response until the # /models/load is synchronous — neuron blocks the response until the
# hf-hub download + GGUF parse + tensor materialisation is done. A # hf-hub download + (GGUF parse or safetensors mmap) + tensor
# fresh 0.6B-Q4_K_M is ~400 MB; on a slow link or cold cache that's # materialisation is done. Small GGUF (0.6B-Q4_K_M, ~400 MB) is
# easily a minute. Pick a generous ceiling. # typically a minute on a warm cache, several on a cold one. A
LOAD_TIMEOUT=600 # Qwen3.6-class dense model is ~54 GB and can easily take an hour to
INFER_TIMEOUT=120 # download cold over a residential link, so default high. Override
# with NEURON_LOAD_TIMEOUT=N (seconds) for smaller targets if you'd
# rather fail fast.
LOAD_TIMEOUT="${NEURON_LOAD_TIMEOUT:-3600}"
INFER_TIMEOUT="${NEURON_INFER_TIMEOUT:-120}"
say() { printf '[%s] %s\n' "${HOST}" "$*"; } # Status messages go to stderr so command substitutions like
# `raw=$(run_probe)` capture only the function's intended return value
# (an HTTP body), not the progress chatter.
say() { printf '[%s] %s\n' "${HOST}" "$*" >&2; }
die() { say "FAIL: $*"; exit 1; } die() { say "FAIL: $*"; exit 1; }
probe_health() { probe_health() {
@@ -49,7 +67,11 @@ probe_health() {
} }
list_loaded_ids() { list_loaded_ids() {
curl --silent --fail "${BASE}/models" | yq -r '.[].id' # The manifest is YAML and uses yq; HTTP responses are JSON and use
# jq directly. pip-yq parses input as YAML by default, which trips
# on JSON content that happens to look like YAML aliases (chatcmpl
# ids, escaped quotes inside `<think>...</think>` blocks, etc.).
curl --silent --fail "${BASE}/models" | jq -r '.[].id'
} }
is_loaded() { is_loaded() {
@@ -57,18 +79,38 @@ is_loaded() {
} }
trigger_load() { trigger_load() {
say "POST /models/load ${MODEL_ID} (quant=${QUANT}, device=[0])" # Build the per-rank CUDA device list as a JSON array. Either
# honour NEURON_DEVICES (`0,1,2`) verbatim or default to
# `[0, 1, ..., tp_size - 1]`.
local devices_json
if [[ -n "${NEURON_DEVICES:-}" ]]; then
devices_json=$(jq -n -c --arg s "${NEURON_DEVICES}" \
'$s | split(",") | map(tonumber)')
else
devices_json=$(jq -n -c --argjson n "${TP_SIZE}" '[range(0; $n)]')
fi
say "POST /models/load ${MODEL_ID} (quant=${QUANT:-<dense>}, tp=${TP_SIZE}, devices=${devices_json})"
say " (synchronous; may take a minute on first run while HF downloads)" say " (synchronous; may take a minute on first run while HF downloads)"
# Build the payload via jq so optional fields are omitted entirely
# when not in use. `tensor_parallel` is dropped when tp_size == 1;
# `quant` is dropped when empty. Both can coexist: tp_size > 1 +
# ISQ quant (q5k/q8_0/etc.) loads safetensors and quantizes the
# per-rank shard at load time. GGUF quants (Q4_K_M) are incompatible
# with TP; this script does not check for that combination and
# leaves it to the harness to reject at load time.
local payload local payload
payload=$(cat <<EOF local base
{ base=$(jq -n -c \
"model_id": "${MODEL_ID}", --arg id "${MODEL_ID}" \
"harness": "candle", --argjson devices "${devices_json}" \
"quant": "${QUANT}", '{model_id: $id, harness: "candle", devices: $devices}')
"devices": [0] if [[ -n "${QUANT}" ]]; then
} base=$(echo "${base}" | jq -c --arg q "${QUANT}" '. + {quant: $q}')
EOF fi
) if (( TP_SIZE > 1 )); then
base=$(echo "${base}" | jq -c --argjson tp "${TP_SIZE}" '. + {tensor_parallel: $tp}')
fi
payload="${base}"
# --write-out captures the response code on a separate line so we # --write-out captures the response code on a separate line so we
# can surface a real diagnostic instead of relying on --fail. # can surface a real diagnostic instead of relying on --fail.
local resp http_code body local resp http_code body
@@ -88,7 +130,7 @@ EOF
run_probe() { run_probe() {
say "POST /v1/chat/completions (probe: ${PROBE_PROMPT})" say "POST /v1/chat/completions (probe: ${PROBE_PROMPT})"
local payload local payload
payload=$(yq -n -c \ payload=$(jq -n -c \
--arg model "${MODEL_ID}" \ --arg model "${MODEL_ID}" \
--arg content "${PROBE_PROMPT}" \ --arg content "${PROBE_PROMPT}" \
--argjson tokens "${MAX_TOKENS}" \ --argjson tokens "${MAX_TOKENS}" \
@@ -124,10 +166,15 @@ fi
raw=$(run_probe) raw=$(run_probe)
echo "---" echo "---"
echo "${raw}" | yq -r '.' # Dump the raw JSON. Don't pipe through `yq -r '.'` — yq's default
# YAML output mode chokes on JSON strings that contain `<` (and the
# `<think>` markers Qwen3 emits during reasoning are a perfect
# example). The targeted `jq -r '.path'` call below parses the body
# as JSON directly, so no YAML re-emit is involved.
echo "${raw}"
echo "---" echo "---"
content=$(echo "${raw}" | yq -r '.choices[0].message.content // empty') content=$(echo "${raw}" | jq -r '.choices[0].message.content // empty')
if [[ -z "${content}" ]]; then if [[ -z "${content}" ]]; then
die "no content in chat completion response" die "no content in chat completion response"
fi fi
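The switch from `${3:-Q4_K_M}` to `${3-Q4_K_M}` in validate-neuron.sh changes what an explicit empty argument means. The sketch below shows the difference in isolation; `pick_quant_no_colon` and `pick_quant_colon` are hypothetical helpers written only for this illustration.

```shell
#!/usr/bin/env bash
# Sketch of `${N-default}` versus `${N:-default}`.
# pick_quant_* are hypothetical helpers, not part of the script.

pick_quant_no_colon() { echo "${1-Q4_K_M}"; }    # default only when unset
pick_quant_colon()    { echo "${1:-Q4_K_M}"; }   # default when unset or empty

unset_nc=$(pick_quant_no_colon)        # no argument at all
empty_nc=$(pick_quant_no_colon "")     # explicit empty string
empty_c=$(pick_quant_colon "")         # explicit empty string

echo "no-colon, unset: '${unset_nc}'"
echo "no-colon, empty: '${empty_nc}'"
echo "colon,    empty: '${empty_c}'"
```

With the no-colon form, passing `""` as the third argument keeps `QUANT` empty, which is what lets the script omit the `quant` field and drive the dense path; the colon form would silently substitute `Q4_K_M` again.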