4 Commits

Author SHA1 Message Date
e874c3483d fix(rpm): explicitly Provides user(name) to satisfy systemd unit Requires
Some checks failed
CI / Build cortex SRPM (push) Has been cancelled
CI / Build neuron SRPM (push) Has been cancelled
CI / Publish cortex to COPR (push) Has been cancelled
CI / Publish neuron to COPR (push) Has been cancelled
CI / Bump version in source (push) Has been cancelled
CI / Format, lint, build, test (push) Has been cancelled
Diagnosing the persistent "Nothing to do" on v0.1.10 surfaced that
removing %attr(,,name) from %files wasn't enough. systemd-rpm-macros
ships its own rpm dep generator (/usr/lib/rpm/systemd.req) that parses
User=/Group= directives from every .service file the package ships
and emits Requires: user(NAME)/group(NAME) accordingly.

Rpmbuild log from v0.1.10 shows these Requires are still emitted even
after the %attr removal. Meanwhile the sysusers provides-generator
emits group(NAME) in both unversioned and versioned forms, but only
a versioned user(NAME) = <base64> when the u-line has GECOS/home/shell
fields. The asymmetry leaves Requires: user(NAME) unresolvable.

Add explicit Provides: user(NAME) back to both specs, with a comment
documenting the actual cause (systemd unit parsing, not file attrs)
so the next person touching these specs doesn't repeat the mistake.

Why monsoon didn't hit this: it creates its user in %pre via
groupadd/useradd (not sysusers.d), so no Provides are generated at
all — matching the Requires: user(monsoon) by luck of the rpm solver
treating unknown symbols as soft-fails for that path. Ours went through
the sysusers Provides code path and hit the asymmetry instead.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-16 15:30:55 +03:00
2caaae018a ci: migrate rpm changelog generation to reusable action
Replace the local .gitea/scripts/generate-rpm-changelog.sh with the
shared composite action at https://git.lair.cafe/actions/rpm-changelog@v1.
Behaviour is identical — collect commits since the previous v* tag,
filter bump-version and merge noise, prepend a dated entry to the
spec — but the logic now lives in one place that other projects can
consume.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-16 15:23:45 +03:00
18d00001cf ci: auto-generate rpm changelog entry per release
On every tag push, build a %changelog entry from the git log since
the previous v* tag and prepend it to each spec. Stops the initial
entry from drifting further and catches bogus-date / stale-version
warnings automatically since the generated date always matches the
day the CI runs.

The generator drops "chore: bump version" commits (bot-authored,
noisy in user-facing changelogs) and merge commits. Author defaults
to the gitea-actions identity but can be overridden via
CHANGELOG_AUTHOR env var if a human release is desired.

Requires fetch-depth: 0 on checkout so git describe can see prior
tags and git log can reach them.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-16 15:04:36 +03:00
ad1442c096 fix(rpm): correct weekday in changelog entry
April 15 2026 was a Wednesday, not Tuesday. rpmbuild validates the
day-of-week against the date and warns on mismatch.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-16 14:58:40 +03:00
77 changed files with 611 additions and 14318 deletions

View File

@@ -1,342 +0,0 @@
name: build-prerelease
# Manually-dispatched workflow that builds CUDA-flavoured neuron binaries
# (and a single cortex binary), packages each as a Fedora RPM, signs
# them, and publishes to the `unstable` channel at rpm.lair.cafe.
#
# Trigger from the Gitea UI: Actions → build-prerelease → Run workflow.
# Optionally provide a `ref` to build from a non-default branch.
#
# The published packages are versioned as e.g.
# helexa-neuron-blackwell-0.1.16-0.1.20260518T140530.gitabcdef0.fc43.x86_64
# ^^^^^^^^^^^^^^^^^^ ^^^^^^^^
# commit time (s) commit sha
# so they sort BELOW the eventual 0.1.16-1 stable release, and so two
# commits on the same day are still strictly ordered by their commit
# timestamps (rather than by RPM-vercmp's alpha-vs-digit precedence
# on the SHA fragment).
on:
# Auto-build on every push to main so the unstable channel tracks
# head without a manual dispatch step.
push:
branches: [main]
# Manual dispatch still available to build from a non-main ref.
workflow_dispatch:
inputs:
ref:
description: "Git ref to build (branch / tag / commit). Defaults to the workflow's branch."
required: false
default: ""
concurrency:
# Share the group with ci.yml so the two workflows can't run
# concurrently on the same `rust` runner (act reuses the workspace
# cache and races destroy each other's build files mid-compile).
# cancel-in-progress=false → workflows queue; if a newer push lands,
# the older run is still picked up by ci.yml's own ref-keyed
# concurrency (same group, queued).
group: cortex-runner-pool-${{ github.ref }}
cancel-in-progress: false
env:
CARGO_INCREMENTAL: "0"
jobs:
prepare:
name: Resolve version stamps
runs-on: rust
outputs:
version: ${{ steps.info.outputs.version }}
release: ${{ steps.info.outputs.release }}
short_sha: ${{ steps.info.outputs.short_sha }}
commit_timestamp: ${{ steps.info.outputs.commit_timestamp }}
steps:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
fetch-depth: 0
- id: info
run: |
set -eux
VERSION=$(awk -F\" '/^version[[:space:]]*=/ { print $2; exit }' Cargo.toml)
SHORT_SHA=$(git rev-parse --short=7 HEAD)
# Second-precise commit timestamp gives the release stamp a
# strictly monotonic numeric prefix. The earlier %Y%m%d-only
# form let same-day builds be ordered by RPM's rpmvercmp
# rules over the SHA, which is non-chronological — e.g.
# "git602e8e1" sorts newer than "gitf9f5fa4" purely because
# rpmvercmp ranks digit-prefixed segments above alpha ones.
# The SHA stays only as a debug identifier; sort order is
# decided entirely by the timestamp.
COMMIT_TIMESTAMP=$(git log -1 --format=%cd --date=format:%Y%m%d%H%M%S HEAD)
RELEASE="0.1.${COMMIT_TIMESTAMP}.git${SHORT_SHA}"
echo "version=${VERSION}" >> "$GITHUB_OUTPUT"
echo "release=${RELEASE}" >> "$GITHUB_OUTPUT"
echo "short_sha=${SHORT_SHA}" >> "$GITHUB_OUTPUT"
echo "commit_timestamp=${COMMIT_TIMESTAMP}" >> "$GITHUB_OUTPUT"
build-cortex:
name: Build cortex binary
needs: prepare
# runner-rust image already provides rust/cargo/clippy/rustfmt via
# dnf — no rustup install step needed.
runs-on: rust
steps:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
- name: Build cortex (release)
run: cargo build --release -p cortex-cli
- name: Stage binary
run: |
mkdir --parents artifacts
cp target/release/cortex artifacts/cortex
./artifacts/cortex --version || true
- uses: actions/upload-artifact@v3
with:
name: cortex-fc43
path: artifacts/cortex
retention-days: 1
build-neuron:
name: Build neuron-${{ matrix.flavour }}
needs: prepare
strategy:
fail-fast: false
matrix:
include:
- flavour: ampere
compute_cap: "86"
runner: cuda-13.0
cuda_home: /usr/local/cuda-13.0
build_jobs: 8
nvcc_threads: 4
cargo_features: "cuda cudnn flash-attn"
- flavour: ada
compute_cap: "89"
runner: cuda-13.0
cuda_home: /usr/local/cuda-13.0
build_jobs: 8
nvcc_threads: 4
cargo_features: "cuda cudnn flash-attn"
- flavour: blackwell
compute_cap: "120"
runner: cuda-13.0
cuda_home: /usr/local/cuda-13.0
build_jobs: 8
nvcc_threads: 4
cargo_features: "cuda cudnn flash-attn"
runs-on: ${{ matrix.runner }}
steps:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
- name: Build neuron with CUDA (${{ matrix.flavour }})
run: |
set -eux
export PATH="${{ matrix.cuda_home }}/bin:${PATH}"
export LD_LIBRARY_PATH="${{ matrix.cuda_home }}/targets/x86_64-linux/lib:${{ matrix.cuda_home }}/lib64:${LD_LIBRARY_PATH:-}"
export LIBRARY_PATH="${{ matrix.cuda_home }}/targets/x86_64-linux/lib:${{ matrix.cuda_home }}/lib64:${LIBRARY_PATH:-}"
cargo build --release -p neuron --features "${{ matrix.cargo_features }}"
env:
CUDA_COMPUTE_CAP: ${{ matrix.compute_cap }}
CARGO_BUILD_JOBS: ${{ matrix.build_jobs }}
NVCC_THREADS: ${{ matrix.nvcc_threads }}
- name: Stage binary
run: |
mkdir --parents artifacts
cp target/release/neuron artifacts/neuron-${{ matrix.flavour }}
file "artifacts/neuron-${{ matrix.flavour }}"
- uses: actions/upload-artifact@v3
with:
name: neuron-${{ matrix.flavour }}-fc43
path: artifacts/neuron-${{ matrix.flavour }}
retention-days: 1
package-cortex:
name: Package cortex RPM
needs: [prepare, build-cortex]
runs-on: rpm
steps:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
- uses: actions/download-artifact@v3
with:
name: cortex-fc43
path: artifacts/
- name: Build RPM
run: |
set -eux
rm -f ~/.rpmmacros
rpmdev-setuptree
cp artifacts/cortex ~/rpmbuild/SOURCES/
cp data/cortex.service ~/rpmbuild/SOURCES/
cp data/cortex-sysusers.conf ~/rpmbuild/SOURCES/
cp data/cortex-firewalld.xml ~/rpmbuild/SOURCES/
cp cortex.example.toml ~/rpmbuild/SOURCES/
cp models.example.toml ~/rpmbuild/SOURCES/
cp LICENSE ~/rpmbuild/SOURCES/
rpmbuild -bb rpm/cortex-prerelease.spec \
--define "cortex_version ${{ needs.prepare.outputs.version }}" \
--define "cortex_prerelease ${{ needs.prepare.outputs.release }}" \
--undefine dist \
--define "dist .fc43"
- uses: actions/upload-artifact@v3
with:
name: rpm-cortex-fc43
path: ~/rpmbuild/RPMS/x86_64/*.rpm
retention-days: 7
package-neuron:
name: Package helexa-neuron-${{ matrix.flavour }} RPM
needs: [prepare, build-neuron]
runs-on: rpm
strategy:
fail-fast: false
matrix:
include:
- flavour: ampere
- flavour: ada
- flavour: blackwell
steps:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
- uses: actions/download-artifact@v3
with:
name: neuron-${{ matrix.flavour }}-fc43
path: artifacts/
- name: Build RPM
run: |
set -eux
rm -f ~/.rpmmacros
rpmdev-setuptree
cp artifacts/neuron-${{ matrix.flavour }} ~/rpmbuild/SOURCES/
cp data/neuron.service ~/rpmbuild/SOURCES/
cp data/neuron-sysusers.conf ~/rpmbuild/SOURCES/
cp data/neuron-firewalld.xml ~/rpmbuild/SOURCES/
cp neuron.example.toml ~/rpmbuild/SOURCES/
cp LICENSE ~/rpmbuild/SOURCES/
rpmbuild -bb rpm/helexa-neuron-prerelease.spec \
--define "neuron_version ${{ needs.prepare.outputs.version }}" \
--define "neuron_flavour ${{ matrix.flavour }}" \
--define "neuron_prerelease ${{ needs.prepare.outputs.release }}" \
--undefine dist \
--define "dist .fc43"
- uses: actions/upload-artifact@v3
with:
name: rpm-neuron-${{ matrix.flavour }}-fc43
path: ~/rpmbuild/RPMS/x86_64/*.rpm
retention-days: 7
publish:
name: Publish to rpm.lair.cafe (unstable)
needs: [package-cortex, package-neuron]
runs-on: rpm
concurrency:
group: rpm-publish
cancel-in-progress: false
env:
RPM_REPO_HOST: oolon.kosherinata.internal
FEDORA_VERSION: "43"
steps:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
- name: Download all built RPMs
uses: actions/download-artifact@v3
with:
path: rpms/
pattern: rpm-*-fc43
- name: Flatten RPM artifacts
run: |
set -eux
find rpms/ -name '*.rpm' -exec mv --target-directory=rpms/ {} +
find rpms/ -mindepth 1 -type d -empty -delete
ls -la rpms/
- name: Check for sequoia-sq
run: |
if ! command -v sq &> /dev/null; then
echo "ERROR: sequoia-sq is not installed. Install with: sudo dnf install sequoia-sq"
exit 1
fi
- name: Import signing key
env:
# Pass secrets via env so values stay out of the rendered shell
# script (which Gitea includes in step logs). Template
# expansion of ${{ secrets.X }} inside `run:` writes the literal
# value into the script and depends on Gitea's log masker to
# scrub it — fragile for multi-line keys.
RPM_SIGNING_KEY: ${{ secrets.RPM_SIGNING_KEY }}
RPM_SIGNING_KEY_ID: ${{ secrets.RPM_SIGNING_KEY_ID }}
run: |
echo "$RPM_SIGNING_KEY" | gpg --batch --import
fpr=$(gpg --batch --with-colons --list-keys "$RPM_SIGNING_KEY_ID" | awk -F: '/^fpr:/ { print $10; exit }')
echo "${fpr}:6:" | gpg --batch --import-ownertrust
sed "s/@GPG_NAME@/$RPM_SIGNING_KEY_ID/" rpm/rpmmacros > ~/.rpmmacros
- name: Sign RPMs
run: |
set -eux
for rpm in rpms/*.rpm; do
echo "signing ${rpm}..."
rpm --addsign "${rpm}"
done
- name: Set up SSH for rsync
run: |
install --directory --mode 700 ~/.ssh
echo "${RSYNC_SSH_KEY}" | install --mode 600 /dev/stdin ~/.ssh/id_ed25519
env:
RSYNC_SSH_KEY: ${{ secrets.RSYNC_SSH_KEY }}
- name: Test SSH connectivity
run: |
ssh -o StrictHostKeyChecking=accept-new "gitea_ci@${RPM_REPO_HOST}" exit
- name: Ensure unstable repo directory exists
run: |
ssh "gitea_ci@${RPM_REPO_HOST}" \
"mkdir --parents /var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable"
- name: Sync RPMs to unstable repo
run: |
rsync \
--archive \
--verbose \
--chmod D755,F644 \
rpms/*.rpm \
"gitea_ci@${RPM_REPO_HOST}:/var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable/"
- name: Update unstable repo metadata
run: |
ssh "gitea_ci@${RPM_REPO_HOST}" \
"cd /var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable && createrepo_c --update ."
- name: Generate packages.json manifest
run: |
scp script/generate-packages-json.py "gitea_ci@${RPM_REPO_HOST}:/tmp/"
ssh "gitea_ci@${RPM_REPO_HOST}" \
"python3 /tmp/generate-packages-json.py \
--repodata-dir /var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable/repodata \
--output /var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable/packages.json \
--base-url https://rpm.lair.cafe/fedora/${FEDORA_VERSION}/x86_64/unstable"

View File

@@ -7,16 +7,6 @@ on:
pull_request:
branches: [main]
# Share a concurrency group with build-prerelease.yml so the two
# workflows don't race on the same `rust` runner workspace (act's
# /root/.cache/act/<hash>/hostexecutor/ is shared across concurrent
# jobs and one job's checkout step nukes another's in-flight build
# files). cancel-in-progress=false → they queue; same-ref pushes
# coalesce per workflow via cancel-in-progress on each.
concurrency:
group: cortex-runner-pool-${{ github.ref }}
cancel-in-progress: false
env:
CARGO_INCREMENTAL: "0"
RUSTC_WRAPPER: sccache
@@ -26,42 +16,53 @@ env:
SCCACHE_S3_USE_SSL: "false"
AWS_ACCESS_KEY_ID: ${{ secrets.SCCACHE_S3_ACCESS_KEY }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.SCCACHE_S3_SECRET_KEY }}
# fmt, clippy, and test all run in parallel on the same `rust` runner
# and would otherwise share /root/.cache/act/<hash>/hostexecutor/target/,
# racing each other's cargo temp files (.tmpXXXXXX) and failing builds
# mid-compile. Give each job its own target directory so the invocations
# don't collide. sccache still backs the actual rustc cache, so the
# rebuild penalty is small.
CARGO_TARGET_DIR: target-${{ github.job }}
jobs:
fmt:
name: Format
runs-on: rust
check:
name: Format, lint, build, test
runs-on: fedora
steps:
- uses: actions/checkout@v4
- run: cargo fmt --check --all
clippy:
name: Clippy
runs-on: rust
steps:
- uses: actions/checkout@v4
- run: cargo clippy --workspace -- -D warnings
- run: sccache --show-stats
- name: Cache cargo registry and target
uses: actions/cache@v4
with:
path: |
~/.cargo/bin
~/.cargo/registry/index
~/.cargo/registry/cache
~/.cargo/git/db
target
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
restore-keys: |
${{ runner.os }}-cargo-
test:
name: Test
runs-on: rust
steps:
- uses: actions/checkout@v4
- run: cargo test --workspace
- run: sccache --show-stats
- name: Ensure sccache with S3 support
env:
RUSTC_WRAPPER: ""
run: |
if sccache --version 2>/dev/null && sccache --show-stats 2>/dev/null; then
echo "sccache with S3 support already installed"
else
cargo install sccache --features s3 --locked
fi
- name: Check formatting
run: cargo fmt --check --all
- name: Clippy
run: cargo clippy --workspace -- -D warnings
- name: Test
run: cargo test --workspace
- name: Show sccache stats
run: sccache --show-stats
srpm-cortex:
name: Build cortex SRPM
runs-on: rpm
needs: [fmt, clippy, test]
runs-on: fedora
needs: check
if: startsWith(github.ref, 'refs/tags/v')
steps:
- uses: actions/checkout@v4
@@ -120,8 +121,8 @@ jobs:
srpm-neuron:
name: Build neuron SRPM
runs-on: rpm
needs: [fmt, clippy, test]
runs-on: fedora
needs: check
if: startsWith(github.ref, 'refs/tags/v')
steps:
- uses: actions/checkout@v4
@@ -138,37 +139,37 @@ jobs:
run: |
VERSION="${{ steps.version.outputs.VERSION }}"
sed -i '/\[workspace\.package\]/,/\[/{ s/^version = ".*"/version = "'"${VERSION}"'"/ }' Cargo.toml
sed -i "s/^Version:.*/Version: ${VERSION}/" helexa-neuron.spec
sed -i "s/^Version:.*/Version: ${VERSION}/" neuron.spec
- name: Generate changelog entry
uses: https://git.lair.cafe/actions/rpm-changelog@v1
with:
spec: helexa-neuron.spec
spec: neuron.spec
version: ${{ steps.version.outputs.VERSION }}
- name: Generate source tarball
run: |
set -ex
VERSION="${{ steps.version.outputs.VERSION }}"
tar czf /tmp/helexa-neuron-${VERSION}.tar.gz \
--transform "s,^\.,helexa-neuron-${VERSION}," \
tar czf /tmp/neuron-${VERSION}.tar.gz \
--transform "s,^\.,neuron-${VERSION}," \
--exclude='./target' \
--exclude='./.git' \
--exclude='*.tar.gz' \
--exclude='*.src.rpm' \
.
mv /tmp/helexa-neuron-${VERSION}.tar.gz .
mv /tmp/neuron-${VERSION}.tar.gz .
- name: Vendor Rust dependencies
run: |
VERSION="${{ steps.version.outputs.VERSION }}"
cargo vendor vendor/
tar czf helexa-neuron-${VERSION}-vendor.tar.gz vendor/
tar czf neuron-${VERSION}-vendor.tar.gz vendor/
rm -rf vendor/
- name: Build SRPM
run: |
rpmbuild -bs helexa-neuron.spec \
rpmbuild -bs neuron.spec \
--define "_sourcedir $(pwd)" \
--define "_srcrpmdir $(pwd)"
@@ -180,7 +181,7 @@ jobs:
copr-cortex:
name: Publish cortex to COPR
runs-on: fedora-43
runs-on: fedora
needs: srpm-cortex
steps:
- name: Download SRPM
@@ -191,13 +192,13 @@ jobs:
- name: Publish to COPR
uses: https://git.lair.cafe/actions/copr-publish@v1
with:
project: helexa/helexa
project: helexa/cortex
srpm: "*.src.rpm"
copr-config: ${{ secrets.COPR_CONFIG }}
copr-neuron:
name: Publish neuron to COPR
runs-on: fedora-43
runs-on: fedora
needs: srpm-neuron
steps:
- name: Download SRPM
@@ -208,53 +209,31 @@ jobs:
- name: Publish to COPR
uses: https://git.lair.cafe/actions/copr-publish@v1
with:
project: helexa/helexa
project: helexa/neuron
srpm: "*.src.rpm"
copr-config: ${{ secrets.COPR_CONFIG }}
bump-version:
name: Bump version in source
runs-on: rust
runs-on: fedora
needs: [copr-cortex, copr-neuron]
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Determine version
id: version
run: echo "VERSION=${GITHUB_REF#refs/tags/v}" >> "$GITHUB_OUTPUT"
- name: Stamp version
run: |
VERSION="${{ steps.version.outputs.VERSION }}"
sed -i '/\[workspace\.package\]/,/\[/{ s/^version = ".*"/version = "'"${VERSION}"'"/ }' Cargo.toml
sed -i "s/^Version:.*/Version: ${VERSION}/" cortex.spec
sed -i "s/^Version:.*/Version: ${VERSION}/" helexa-neuron.spec
cargo check --workspace 2>/dev/null || true
- name: Generate cortex changelog entry
uses: https://git.lair.cafe/actions/rpm-changelog@v1
with:
spec: cortex.spec
version: ${{ steps.version.outputs.VERSION }}
- name: Generate helexa-neuron changelog entry
uses: https://git.lair.cafe/actions/rpm-changelog@v1
with:
spec: helexa-neuron.spec
version: ${{ steps.version.outputs.VERSION }}
- name: Commit and push
- name: Stamp version and push
env:
GITEA_TOKEN: ${{ secrets.GITEA_TOKEN }}
run: |
VERSION="${{ steps.version.outputs.VERSION }}"
VERSION="${GITHUB_REF#refs/tags/v}"
sed -i '/\[workspace\.package\]/,/\[/{ s/^version = ".*"/version = "'"${VERSION}"'"/ }' Cargo.toml
sed -i "s/^Version:.*/Version: ${VERSION}/" cortex.spec
sed -i "s/^Version:.*/Version: ${VERSION}/" neuron.spec
cargo check --workspace 2>/dev/null || true
git config user.name "Gitea Actions"
git config user.email "actions@git.lair.cafe"
git add Cargo.toml Cargo.lock cortex.spec helexa-neuron.spec
git add Cargo.toml Cargo.lock cortex.spec neuron.spec
if git diff --cached --quiet; then
echo "Nothing to commit for ${VERSION}"
echo "Version already at ${VERSION}"
else
git commit -m "chore: bump version to ${VERSION}"
git remote set-url origin "https://gitea-actions:${GITEA_TOKEN}@git.lair.cafe/helexa/cortex.git"

2
.gitignore vendored
View File

@@ -4,6 +4,4 @@
.idea/
.vscode/
cortex.toml
models.toml
doc/plan/*
/target-cuda/

116
CLAUDE.md
View File

@@ -125,8 +125,7 @@ automatically. Clippy warnings must be resolved, not suppressed with
- One or more GPU nodes running mistral.rs on port 8080
- Optionally a metrics-only node (no GPU) for Prometheus/Grafana
- Each node runs `mistralrs serve` on port 8080
- Gateway listens on port 31313 (API) and 31314 (metrics)
- neuron listens on port 13131 on each GPU host
- Gateway listens on port 8000 (API) and 9100 (metrics)
- TLS terminated at gateway or via nginx; internal traffic is plaintext over WireGuard
## Conventions
@@ -381,7 +380,7 @@ processes (one process per loaded model, each on its own port).
## neuron API
neuron exposes an HTTP API on port 13131 that cortex polls and calls.
neuron exposes an HTTP API on port 9090 that cortex polls and calls.
```
GET /discovery
@@ -425,8 +424,8 @@ endpoint. cortex.toml shrinks to:
```toml
[gateway]
listen = "0.0.0.0:31313"
metrics_listen = "0.0.0.0:31314"
listen = "0.0.0.0:8000"
metrics_listen = "0.0.0.0:9100"
[eviction]
strategy = "lru"
@@ -434,15 +433,15 @@ defrag_after_cycles = 50
[[neurons]]
name = "beast"
endpoint = "http://beast.hanzalova.internal:13131"
endpoint = "http://beast.hanzalova.internal:9090"
[[neurons]]
name = "benjy"
endpoint = "http://benjy.hanzalova.internal:13131"
endpoint = "http://benjy.kosherinata.internal:9090"
[[neurons]]
name = "quadbrat"
endpoint = "http://quadbrat.hanzalova.internal:13131"
endpoint = "http://quadbrat.hanzalova.internal:9090"
```
On startup and periodically, cortex calls `GET /discovery` and
@@ -522,7 +521,7 @@ cortex/
│ │ └── metrics.rs # prometheus exporter (unchanged)
│ ├── neuron/ # node plane (replaces cortex-agent)
│ │ └── src/
│ │ ├── main.rs # binary entrypoint, axum server on :13131
│ │ ├── main.rs # binary entrypoint, axum server on :9090
│ │ ├── discovery.rs # nvidia-smi, device enumeration
│ │ ├── health.rs # runtime GPU polling
│ │ ├── api.rs # HTTP handlers for /discovery, /models, etc.
@@ -596,65 +595,70 @@ placement matching can be added incrementally.
Completed. Both packages have RPM specs, systemd units, and example configs.
CI builds parallel SRPMs on tag push and publishes to separate COPR repos.
- `cortex.spec` — installs the `cortex` binary. Package name keeps the
short `cortex` because no Fedora package collides with it.
- `helexa-neuron.spec` — installs the `neuron` binary under package name
`helexa-neuron`. Renamed from bare `neuron` to avoid collision with
Fedora's NEURON neural-simulation package
(https://src.fedoraproject.org/rpms/neuron); binary, systemd unit,
system user, and config dir all stay named `neuron` since those are
project-local contexts.
- `cortex.spec` `helexa/cortex` COPR: binary, systemd unit, config files
- `neuron.spec``helexa/neuron` COPR: binary, systemd unit, config
- `data/cortex.service`, `data/neuron.service` — systemd units
- `cortex.example.toml`, `neuron.example.toml`, `models.example.toml`
- CI: parallel `srpm-cortex` + `srpm-neuron` jobs, then parallel COPR
publish to a single project `helexa/helexa` hosting both packages.
- CI: parallel `srpm-cortex` + `srpm-neuron` jobs, then parallel COPR publish
Install:
```sh
dnf copr enable helexa/helexa
dnf install cortex # gateway host
dnf install helexa-neuron # GPU nodes
dnf copr enable helexa/cortex && dnf install cortex # gateway host
dnf copr enable helexa/neuron && dnf install neuron # GPU nodes
```
## 2026-05-18 addendum: candle-native pivot
### Phase 11: llama.cpp harness stub
Phases 11 (llama.cpp harness) and 12 (mistral.rs COPR) below are
**superseded**. The project no longer treats mistral.rs or llama.cpp as
dependencies — both are conceptually out of scope. neuron becomes a
candle-native inference daemon, with `Harness` retained as an
internal seam for adding future engines (vision/audio/diffusion) but
its only implementation being in-process candle.
**Goal:** Prove the harness abstraction works with a second engine.
The full staged plan for this pivot lives at
`~/.claude/plans/create-a-more-aggressive-calm-naur.md`. Summary:
**Steps:**
1. `crates/neuron/src/harness/llamacpp.rs` — implement the `Harness`
trait for llama.cpp's `llama-server`.
- `start()` — launch `llama-server` with the correct model path,
`--port`, `--n-gpu-layers`, `--tensor-split` args. Track the
child process.
- `stop()` — send SIGTERM to the child process.
- `list_models()` — llama-server serves one model per process, so
return a single-element list.
- `load_model()` — start a new llama-server process for this model.
- `unload_model()` — stop the process.
- `inference_endpoint()` — return `http://localhost:{assigned_port}`.
2. Port allocation: neuron assigns ports from a range (e.g. 8100-8199)
to llama-server instances.
3. Register in `HarnessRegistry` when configured:
```toml
[[harnesses]]
name = "llamacpp"
binary = "/usr/local/bin/llama-server"
port_range = [8100, 8199]
```
4. Tests: mock llama-server (simple HTTP server returning canned
responses), test load/unload/endpoint lifecycle.
- **Stage 1 (this commit):** delete `mistralrs.rs` and `llamacpp.rs`,
scaffold inert `CandleHarness`, drop `endpoint`/`systemd_unit` from
`HarnessConfig`, default no-op `start`/`stop` on the `Harness` trait.
- **Stages 24:** wire up candle model load/unload (quantized Qwen3
first), add OpenAI-compatible inference endpoint in neuron, then SSE
streaming.
- **Stages 56:** load-on-activation (default models in config) and
unload-on-deactivation (graceful shutdown).
- **Stages 78:** multi-GPU tensor parallelism and broader model/quant
coverage.
**Done when:** A model with `harness = "llamacpp"` in `models.toml` can
be loaded and served through cortex. Tests pass with mock llama-server.
Sections of this document that describe mistral.rs HTTP behaviour
("mistral.rs API gotchas") are retained as historical context for
Phases 110 — they document what was true while the project depended
on mistral.rs. They do not describe current behaviour.
### Phase 12 (lower priority): mistral.rs COPR packaging
---
**Goal:** Fedora RPMs for mistral.rs built against specific CUDA versions.
### Phase 11 (superseded): llama.cpp harness stub
**Steps:**
1. `mistralrs-cuda.spec` — RPM spec that clones a pinned mistral.rs git
tag, builds with `--features cuda`, links against the system CUDA
toolkit. Produces `mistralrs-cuda13-server` (CUDA 13.x / sm_120) and
`mistralrs-cuda12-server` (CUDA 12.x / sm_89). Install binary to
`/usr/local/bin/mistralrs`.
2. COPR build config: enable the NVIDIA CUDA repo as a build dependency.
Pin the CUDA toolkit version in `BuildRequires`.
3. Gitea Actions or manual workflow: bump the mistral.rs tag in the spec,
trigger COPR rebuild.
4. neuron's mistralrs harness config references which binary/package
provides the mistral.rs binary. neuron could warn at startup if the
installed mistral.rs CUDA version doesn't match the discovered driver.
~~Originally planned as a second engine to prove the harness
abstraction.~~ Replaced by the candle harness work in the 2026-05-18
addendum above. llama.cpp's any-model/any-hardware breadth is no
longer in scope for helexa.
**Done when:** `dnf install mistralrs-cuda13-server` on beast provides a
working `mistralrs` binary built for Blackwell GPUs. `dnf install
mistralrs-cuda12-server` on benjy provides one built for Ada GPUs.
### Phase 12 (superseded): mistral.rs COPR packaging
~~Originally planned to ship CUDA-versioned mistral.rs RPMs.~~ Replaced
by the candle harness work in the 2026-05-18 addendum above. With
mistral.rs out of the dependency tree, there is nothing to package.
This is a separate repo/spec — not part of the cortex workspace — but
tightly coupled operationally. Track it as a sibling project.

1616
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -8,7 +8,7 @@ members = [
]
[workspace.package]
version = "0.1.16"
version = "0.1.8"
edition = "2024"
license = "GPL-3.0-or-later"
repository = "https://git.lair.cafe/helexa/cortex"
@@ -27,7 +27,7 @@ serde = { version = "1", features = ["derive"] }
serde_json = "1"
toml = "0.8"
# http client (for proxying to neuron backends)
# http client (for proxying to mistralrs backends)
reqwest = { version = "0.12", features = ["json", "stream"] }
# observability

105
README.md
View File

@@ -1,23 +1,22 @@
# cortex
A Rust reverse-proxy and fleet management layer for multi-node GPU inference
clusters. Cortex sits in front of one or more `neuron` daemons (each running
candle-based inference on a local GPU host) and presents a unified OpenAI +
Anthropic compatible API surface.
A Rust reverse-proxy and fleet management layer for multi-node
[mistral.rs](https://github.com/EricLBuehler/mistral.rs) inference clusters.
## Problem
Running local LLMs across multiple GPU nodes (different VRAM tiers, different
model affinities) requires a unified API surface that:
- Presents a **single `/v1/models` catalogue** merging every model that can be
served by any neuron in the fleet.
- **Routes requests** to the correct node based on where a model is loaded
(or can be loaded), handling cold-load and eviction transparently.
- Manages **model lifecycle** load on demand, unload cold models, pin
critical ones — by calling each neuron's `/models/{load,unload}` API.
- Presents a **single `/v1/models` catalogue** merging every model across every
node.
- **Routes requests** to the correct node based on where a model is loaded (or
*can* be loaded).
- Manages **model lifecycle** — unload cold models, reload on demand, pin
critical ones — using the mistral.rs
`/v1/models/{unload,reload,status}` HTTP API (PR #1828+).
- Translates between **OpenAI and Anthropic** request/response envelopes so
every client speaks whichever dialect it prefers.
every client in the homelab speaks whichever dialect it prefers.
- Captures **per-request metrics** (tokens, tok/s, TTFT, latency) and exposes
them as Prometheus counters/histograms.
@@ -31,17 +30,18 @@ model affinities) requires a unified API surface that:
└────────────────┴──────┬───────┴───────────────┘
┌──────────▼──────────┐
cortex
│ (cortex-gateway) │
│ cortex │
(cortex-gateway)
│ │
│ Router · Metrics │
│ Evictor · Translate│
└──┬──────┬────────┬──┘
│ │ │
┌──────────▼┐ ┌──▼─────┐ ┌▼──────────┐
neuron │ │ neuron │ │ neuron
:13131 │ │ :13131 │ │ :13131
candle │ │ candle │ │ candle
gpu-large │ │gpu-med │ │ gpu-small
mistralrs │ │mistral │ │ mistralrs
serve │ │rs serve│ │ serve
│ :8080 │ │ :8080 │ │ :8080 │
└───────────┘ └────────┘ └───────────┘
private network (.internal)
```
@@ -50,48 +50,70 @@ model affinities) requires a unified API surface that:
| Crate | Purpose |
|---|---|
| `cortex-core` | Shared types: config, node/model state, metrics, OpenAI/Anthropic envelopes, harness trait, discovery types |
| `cortex-gateway` | Axum HTTP server: proxy, router, evictor, poller, metrics exporter |
| `neuron` | Per-node daemon: GPU discovery, in-process candle inference, model lifecycle API |
| `cortex-core` | Shared types: config, node/model state, metrics, OpenAI/Anthropic request/response envelopes |
| `cortex-gateway` | Axum HTTP server: proxy, router, evictor, metrics exporter |
| `cortex-agent` | Per-node sidecar: polls local mistralrs, reports to gateway, handles restart/defrag |
| `cortex-cli` | CLI entrypoint (`cortex serve`, `cortex status`, etc.) |
## Node setup
Each GPU node runs `neuron` (listening on `:13131`). Neuron uses
huggingface/candle for in-process inference — there is no external
inference subprocess to manage.
Each GPU node runs `mistralrs serve` with a multi-model config. Models are
declared but start **unloaded** — mistral.rs lazy-loads on first request and
the gateway can explicitly unload/reload via the HTTP API.
The neuron RPM (`helexa-neuron`) ships a systemd unit:
Example node systemd unit:
```sh
dnf copr enable helexa/helexa
dnf install helexa-neuron
systemctl enable --now neuron
```ini
# /etc/systemd/system/mistralrs.service
[Unit]
Description=mistral.rs inference server
After=network-online.target
Wants=network-online.target
[Service]
Type=simple
ExecStart=/usr/local/bin/mistralrs serve \
--from-config /etc/mistralrs/config.toml \
--port 8080
Restart=on-failure
RestartSec=5
Environment=CUDA_VISIBLE_DEVICES=0,1
[Install]
WantedBy=multi-user.target
```
## Gateway config
```toml
# /etc/cortex/cortex.toml
# cortex.toml
[gateway]
listen = "0.0.0.0:31313"
metrics_listen = "0.0.0.0:31314"
listen = "0.0.0.0:8000"
metrics_listen = "0.0.0.0:9100"
[eviction]
strategy = "lru" # lru | priority
defrag_after_cycles = 50
[[neurons]]
name = "beast"
endpoint = "http://beast.internal:13131"
[[nodes]]
name = "gpu-large"
endpoint = "http://gpu-large.internal:8080"
vram_mb = 49_152 # e.g. 2x RTX 4090
pinned = ["your-org/large-model"]
[[neurons]]
name = "benjy"
endpoint = "http://benjy.internal:13131"
[[nodes]]
name = "gpu-medium"
endpoint = "http://gpu-medium.internal:8080"
vram_mb = 24_576 # e.g. RTX 4090
pinned = ["your-org/medium-model"]
[[nodes]]
name = "gpu-small"
endpoint = "http://gpu-small.internal:8080"
vram_mb = 12_288 # e.g. RTX 3060
pinned = ["your-org/embedding-model"]
```
Model placement profiles live in `models.toml` — see `models.example.toml`.
## Building
```sh
@@ -109,20 +131,19 @@ cargo clippy --workspace -- -D warnings # warnings are errors
cargo test --workspace # all tests must pass
```
Tagged releases (`v*`) additionally build SRPMs for both `cortex` and
`helexa-neuron` and publish to COPR.
Tagged releases (`v*`) additionally build an SRPM and publish to COPR.
## Running
```sh
# start the gateway
cortex serve --config /etc/cortex/cortex.toml
cortex serve --config cortex.toml
# check fleet status
cortex status
# list all models across nodes
curl http://localhost:31313/v1/models
curl http://localhost:8000/v1/models
```
## License

View File

@@ -1,30 +0,0 @@
# Helexa fleet manifest.
#
# Drives rolling deploys via script/deploy.sh and serves as the source
# of truth for which hosts run cortex vs neuron, and which CUDA
# compute-capability flavour each neuron host needs.
#
# Flavour ↔ NVIDIA generation ↔ compute cap:
# ampere sm_86 (RTX 30 series — e.g. 3060)
# ada sm_89 (RTX 40 series — e.g. 4090)
# blackwell sm_120 (RTX 50 series — e.g. 5090)
#
# The flavour determines which RPM is installed on a given neuron host:
# helexa-neuron-<flavour>. Only one flavour may be installed at a time
# (the packages Conflict: with each other).
cortex:
host: hanzalova.internal
neurons:
- host: beast.hanzalova.internal
flavour: blackwell
gpu: "2x RTX 5090"
- host: benjy.hanzalova.internal
flavour: ada
gpu: "RTX 4090"
- host: quadbrat.hanzalova.internal
flavour: ampere
gpu: "RTX 3060"

View File

@@ -3,22 +3,22 @@
# Copy to cortex.toml and adjust for your environment.
#
# Environment variable overrides use CORTEX_ prefix with __ separators:
# CORTEX_GATEWAY__LISTEN=0.0.0.0:31313
# CORTEX_GATEWAY__LISTEN=0.0.0.0:9000
[gateway]
listen = "0.0.0.0:31313"
metrics_listen = "0.0.0.0:31314"
listen = "0.0.0.0:8000"
metrics_listen = "0.0.0.0:9100"
[eviction]
strategy = "lru"
# Restart neurons after this many load/unload cycles to defragment VRAM.
# Restart mistralrs after this many load/unload cycles to defragment VRAM.
# Set to 0 to disable.
defrag_after_cycles = 50
# -- Nodes ---------------------------------------------------------------
# Each [[nodes]] entry declares a neuron daemon in the fleet.
# Models are discovered by polling the neuron's /models endpoint.
# Pinned models (see models.toml) are never evicted.
# Each [[nodes]] entry declares a mistral.rs instance in the fleet.
# Models are discovered by polling the node's /v1/models endpoint.
# Pinned models are never evicted.
[[nodes]]
name = "gpu-large"

View File

@@ -1,5 +1,5 @@
Name: cortex
Version: 0.1.16
Version: 0.1.8
Release: 1%{?dist}
Summary: Inference gateway for multi-node GPU clusters
@@ -21,7 +21,6 @@ BuildRequires: systemd-rpm-macros
Requires(pre): shadow-utils
Requires: systemd
Requires: firewalld-filesystem
# systemd-rpm-macros ships a unit dep generator that parses User=/Group=
# from our .service file and emits Requires: user(cortex)/group(cortex).
@@ -57,7 +56,6 @@ cargo build --release -p cortex-cli
install -Dm755 target/release/cortex %{buildroot}%{_bindir}/cortex
install -Dm644 data/cortex.service %{buildroot}%{_unitdir}/cortex.service
install -Dm644 data/cortex-sysusers.conf %{buildroot}%{_sysusersdir}/cortex.conf
install -Dm644 data/cortex-firewalld.xml %{buildroot}%{_prefix}/lib/firewalld/services/cortex.xml
install -dm755 %{buildroot}%{_sysconfdir}/cortex
install -Dm644 cortex.example.toml %{buildroot}%{_sysconfdir}/cortex/cortex.toml
install -Dm644 models.example.toml %{buildroot}%{_sysconfdir}/cortex/models.toml
@@ -74,53 +72,16 @@ install -Dm644 models.example.toml %{buildroot}%{_sysconfdir}/cortex/models.toml
%postun
%systemd_postun_with_restart cortex.service
%posttrans
# Migration: older cortex packages shipped the firewalld service as
# `helexa-cortex` and (in some build streams) with wrong port numbers
# (9301/9302/9304). Operators who enabled that legacy service in their
# zone end up with the wrong-port override taking precedence over the
# vendor `cortex.xml` now in /usr/lib/firewalld/services/. Clean up the
# stale /etc/ override here and migrate any zone bindings to the new
# service name.
if [ -f /etc/firewalld/services/helexa-cortex.xml ]; then
rm -f /etc/firewalld/services/helexa-cortex.xml
fi
if [ -x /usr/bin/firewall-cmd ] && /usr/bin/firewall-cmd --state >/dev/null 2>&1; then
# Drop the legacy service name from every zone where it was enabled
# and add the new `cortex` service in its place. Operators who never
# ran firewall-cmd against either name see no zone change.
for zone in $(/usr/bin/firewall-cmd --get-active-zones 2>/dev/null \
| awk '!/^[[:space:]]/ {print $1}'); do
if /usr/bin/firewall-cmd --permanent --zone="$zone" --query-service=helexa-cortex >/dev/null 2>&1; then
/usr/bin/firewall-cmd --permanent --zone="$zone" --remove-service=helexa-cortex >/dev/null 2>&1 || :
/usr/bin/firewall-cmd --permanent --zone="$zone" --add-service=cortex >/dev/null 2>&1 || :
fi
done
/usr/bin/firewall-cmd --reload >/dev/null 2>&1 || :
fi
:
%files
%license LICENSE
%doc README.md
%{_bindir}/cortex
%{_unitdir}/cortex.service
%{_sysusersdir}/cortex.conf
%{_prefix}/lib/firewalld/services/cortex.xml
%dir %{_sysconfdir}/cortex
%config(noreplace) %{_sysconfdir}/cortex/cortex.toml
%config(noreplace) %{_sysconfdir}/cortex/models.toml
%changelog
* Thu Apr 16 2026 Gitea Actions <actions@git.lair.cafe> - 0.1.16-1
- chore: ignore local deploy script
- chore: move default ports out of common-collision ranges
- ci: drop actions/cache for cargo registry and target
* Thu Apr 16 2026 Gitea Actions <actions@git.lair.cafe> - 0.1.14-1
- ci: publish both packages to a single helexa/helexa COPR project
- fix(rpm): rename neuron package to helexa-neuron
- ci: commit generated %changelog entries back to main
* Wed Apr 15 2026 Rob Thijssen <grenade@rob.tn> - 0.1.0-1
- Initial package

View File

@@ -5,7 +5,7 @@ use tracing_subscriber::EnvFilter;
#[derive(Parser)]
#[command(name = "cortex")]
#[command(about = "Unified inference gateway for multi-node GPU clusters")]
#[command(about = "Unified inference gateway for multi-node mistral.rs clusters")]
#[command(version)]
struct Cli {
#[command(subcommand)]
@@ -23,7 +23,7 @@ enum Commands {
/// Print the fleet status (models, nodes, health).
Status {
/// Gateway API endpoint to query.
#[arg(short, long, default_value = "http://localhost:31313")]
#[arg(short, long, default_value = "http://localhost:8000")]
endpoint: String,
},
}

View File

@@ -2,7 +2,7 @@
//!
//! These mirror the `/v1/messages` format used by the Anthropic API.
//! The gateway accepts these, translates to OpenAI format, proxies to
//! the inference backend (neuron), then translates the response back.
//! mistral.rs, then translates the response back.
use serde::{Deserialize, Serialize};
use serde_json::Value;

View File

@@ -1,6 +1,5 @@
//! Model catalogue — profiles describing how to serve each model.
use crate::discovery::DeviceInfo;
use serde::{Deserialize, Serialize};
use std::path::Path;
@@ -65,103 +64,4 @@ impl ModelCatalogue {
.iter()
.any(|p| p.id == model_id && p.pinned_on.contains(&neuron_name.to_string()))
}
/// Find a profile by model id.
pub fn get(&self, model_id: &str) -> Option<&ModelProfile> {
self.models.iter().find(|p| p.id == model_id)
}
}
impl ModelProfile {
/// True iff this profile's placement constraints can be satisfied
/// by the named neuron with the given device topology.
///
/// Constraints checked:
/// - `pinned_on`: non-empty → neuron must be on the list.
/// - `min_devices`: neuron must have at least this many devices.
/// - `min_device_vram_mb`: at least `min_devices` of the neuron's
/// devices must each meet this VRAM floor.
pub fn is_feasible_on(&self, neuron_name: &str, devices: &[DeviceInfo]) -> bool {
if !self.pinned_on.is_empty() && !self.pinned_on.iter().any(|n| n == neuron_name) {
return false;
}
if (devices.len() as u32) < self.min_devices {
return false;
}
if let Some(min_vram) = self.min_device_vram_mb {
let big_enough = devices
.iter()
.filter(|d| d.vram_total_mb >= min_vram)
.count() as u32;
if big_enough < self.min_devices {
return false;
}
}
true
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::discovery::DeviceInfo;
fn device(idx: u32, vram_mb: u64) -> DeviceInfo {
DeviceInfo {
index: idx,
name: format!("DEV-{idx}"),
vram_total_mb: vram_mb,
compute_capability: "8.6".into(),
}
}
fn profile() -> ModelProfile {
ModelProfile {
id: "Qwen/Qwen3.6-27B".into(),
harness: "candle".into(),
quant: None,
vram_mb: Some(45_000),
min_devices: 2,
min_device_vram_mb: Some(24_000),
pinned_on: vec![],
}
}
#[test]
fn feasible_when_two_devices_meet_vram_floor() {
let p = profile();
let devices = [device(0, 32_000), device(1, 32_000)];
assert!(p.is_feasible_on("beast", &devices));
}
#[test]
fn infeasible_when_only_one_device() {
let p = profile();
let devices = [device(0, 64_000)];
assert!(!p.is_feasible_on("benjy", &devices));
}
#[test]
fn infeasible_when_one_device_underspec() {
let p = profile();
let devices = [device(0, 32_000), device(1, 12_000)];
assert!(!p.is_feasible_on("mixed", &devices));
}
#[test]
fn pinned_on_excludes_other_neurons() {
let mut p = profile();
p.pinned_on = vec!["beast".into()];
let devices = [device(0, 32_000), device(1, 32_000)];
assert!(p.is_feasible_on("beast", &devices));
assert!(!p.is_feasible_on("benjy", &devices));
}
#[test]
fn no_vram_floor_just_needs_min_devices() {
let mut p = profile();
p.min_device_vram_mb = None;
let devices = [device(0, 1_000), device(1, 1_000)];
assert!(p.is_feasible_on("anywhere", &devices));
}
}

View File

@@ -22,9 +22,9 @@ fn default_models_path() -> String {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewaySettings {
/// Address to listen on for API requests (e.g. "0.0.0.0:31313")
/// Address to listen on for API requests (e.g. "0.0.0.0:8000")
pub listen: String,
/// Address to listen on for Prometheus metrics (e.g. "0.0.0.0:31314")
/// Address to listen on for Prometheus metrics (e.g. "0.0.0.0:9100")
pub metrics_listen: String,
}
@@ -50,7 +50,7 @@ pub enum EvictionStrategy {
pub struct NeuronEndpoint {
/// Human-readable node name (e.g. "beast")
pub name: String,
/// Base URL of the neuron daemon (e.g. "http://beast.internal:13131")
/// Base URL of the neuron daemon (e.g. "http://beast.internal:9090")
pub endpoint: String,
}
@@ -70,8 +70,8 @@ impl Default for GatewayConfig {
fn default() -> Self {
Self {
gateway: GatewaySettings {
listen: "0.0.0.0:31313".into(),
metrics_listen: "0.0.0.0:31314".into(),
listen: "0.0.0.0:8000".into(),
metrics_listen: "0.0.0.0:9100".into(),
},
eviction: EvictionSettings {
strategy: EvictionStrategy::Lru,

View File

@@ -9,13 +9,13 @@ use async_trait::async_trait;
use serde::{Deserialize, Serialize};
/// Configuration for a harness instance on a neuron.
///
/// All current harnesses are in-process (candle); per-harness tuning
/// (cache paths, device policies, etc.) lives in dedicated config
/// blocks rather than on this struct.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HarnessConfig {
pub name: String,
/// Base URL of the harness (e.g. "http://localhost:8080" for mistral.rs).
pub endpoint: Option<String>,
/// Systemd unit name, if the harness is managed via systemd.
pub systemd_unit: Option<String>,
}
/// Health status of a harness process.
@@ -47,24 +47,16 @@ pub struct ModelInfo {
}
/// What an inference harness must do, from neuron's perspective.
///
/// All current harnesses are in-process — they share neuron's address
/// space and lifecycle. `start`/`stop` therefore default to no-ops; a
/// future process-supervising harness would override them.
#[async_trait]
pub trait Harness: Send + Sync {
/// Human-readable name (e.g. "candle").
/// Human-readable name (e.g. "mistralrs", "llamacpp", "comfyui").
fn name(&self) -> &str;
/// Start the harness. Default no-op for in-process harnesses.
async fn start(&self, _config: &HarnessConfig) -> Result<()> {
Ok(())
}
/// Start the harness process if it is not already running.
async fn start(&self, config: &HarnessConfig) -> Result<()>;
/// Stop the harness. Default no-op for in-process harnesses.
async fn stop(&self) -> Result<()> {
Ok(())
}
/// Stop the harness process gracefully.
async fn stop(&self) -> Result<()>;
/// Health check. Returns the harness process status.
async fn health(&self) -> HarnessHealth;

View File

@@ -1,4 +1,3 @@
use crate::discovery::DiscoveryResponse;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
@@ -7,19 +6,13 @@ use std::collections::HashMap;
#[derive(Debug, Clone)]
pub struct NodeState {
pub name: String,
/// Base URL of the neuron daemon (e.g. "http://beast.internal:13131").
/// Base URL of the neuron daemon (e.g. "http://beast.internal:9090").
pub endpoint: String,
pub healthy: bool,
pub models: HashMap<String, ModelEntry>,
/// Number of load/unload cycles since last process restart.
pub lifecycle_cycles: u32,
pub last_poll: Option<DateTime<Utc>>,
/// Result of the most recent successful `GET /discovery` against
/// this neuron. Cached forever once obtained — device topology is
/// invariant for a given neuron process. `None` until the first
/// successful poll. Used by the router and `/v1/models` to do
/// catalogue × topology feasibility checks.
pub discovery: Option<DiscoveryResponse>,
}
/// A model registered on a node, with its runtime status.
@@ -43,32 +36,12 @@ pub enum ModelStatus {
}
/// Unified model entry as exposed by the gateway's `/v1/models` endpoint.
///
/// The first four fields (`id`, `object`, `created`, `owned_by`) match
/// OpenAI's `/v1/models` shape verbatim, so existing OpenAI-aware
/// tooling deserialises this without custom code. The remaining fields
/// are helexa-specific extensions — OpenAI clients ignore unknown
/// fields and other consumers can read them for placement / debugging.
/// Includes which node(s) host this model and their status.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CortexModelEntry {
pub id: String,
/// Always `"model"` per OpenAI's contract.
pub object: String,
/// Unix-second timestamp; cortex stamps this at response time.
pub created: u64,
/// OpenAI's "publisher" field — `"helexa"` for everything we serve.
pub owned_by: String,
/// True if any neuron currently has this model loaded. False for
/// catalogue entries that are feasible but not yet loaded.
pub loaded: bool,
/// Neurons whose discovered topology can satisfy this model's
/// catalogue placement constraints. Empty for models that are
/// loaded somewhere but not present in the catalogue (cortex has
/// no feasibility opinion on those).
pub feasible_on: Vec<String>,
/// Where this model is actually loaded right now. Subset of (or
/// disjoint from) `feasible_on` depending on whether the catalogue
/// covers this model.
/// Which nodes have this model (and their status).
pub locations: Vec<ModelLocation>,
}

View File

@@ -3,7 +3,7 @@
//! These are a subset sufficient for chat completions (streaming + non-streaming).
//! Fields not relevant to proxying are captured as `serde_json::Value` via
//! `#[serde(flatten)]` so we forward them without needing to enumerate every
//! extension field a backend might support.
//! extension field mistral.rs supports.
use serde::{Deserialize, Serialize};
use serde_json::Value;
@@ -22,7 +22,7 @@ pub struct ChatCompletionRequest {
pub max_tokens: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
/// All other fields (tools, response_format, backend extensions, etc.)
/// All other fields (tools, response_format, mistral.rs extensions, etc.)
#[serde(flatten)]
pub extra: Value,
}

View File

@@ -24,7 +24,6 @@ tokio-stream.workspace = true
eventsource-stream.workspace = true
bytes = "1"
urlencoding = "2"
url = "2"
[dev-dependencies]
tokio = { workspace = true, features = ["test-util"] }

View File

@@ -34,30 +34,12 @@ async fn chat_completions(
) -> Response {
let model_id = match extract_model(&body) {
Some(m) => m,
None => {
tracing::warn!(
handler = "chat_completions",
"rejected: missing 'model' field in request body"
);
return error_response(400, "missing 'model' field in request body");
}
None => return error_response(400, "missing 'model' field in request body"),
};
let route = match router::resolve(&fleet, &model_id).await {
Ok(r) => r,
Err(e) => {
tracing::warn!(
handler = "chat_completions",
model = %model_id,
error = %e,
"route resolve failed"
);
// RouteError's Display strings are short and informative
// ("model 'X' not found...", "no healthy nodes available")
// — fine to surface to the caller. The warn above carries
// any extra context for operators.
return error_response(404, &e.to_string());
}
Err(e) => return error_response(404, &e.to_string()),
};
touch_model(&fleet, &route.node_name, &model_id).await;
@@ -81,30 +63,12 @@ async fn completions(
) -> Response {
let model_id = match extract_model(&body) {
Some(m) => m,
None => {
tracing::warn!(
handler = "completions",
"rejected: missing 'model' field in request body"
);
return error_response(400, "missing 'model' field in request body");
}
None => return error_response(400, "missing 'model' field in request body"),
};
let route = match router::resolve(&fleet, &model_id).await {
Ok(r) => r,
Err(e) => {
tracing::warn!(
handler = "completions",
model = %model_id,
error = %e,
"route resolve failed"
);
// RouteError's Display strings are short and informative
// ("model 'X' not found...", "no healthy nodes available")
// — fine to surface to the caller. The warn above carries
// any extra context for operators.
return error_response(404, &e.to_string());
}
Err(e) => return error_response(404, &e.to_string()),
};
touch_model(&fleet, &route.node_name, &model_id).await;
@@ -121,14 +85,7 @@ async fn anthropic_messages(
// Parse as Anthropic request.
let anth_req: cortex_core::anthropic::MessagesRequest = match serde_json::from_slice(&body) {
Ok(r) => r,
Err(e) => {
tracing::warn!(
handler = "anthropic_messages",
error = %e,
"rejected: invalid Anthropic request body"
);
return error_response(400, "invalid Anthropic request body");
}
Err(e) => return error_response(400, &format!("invalid Anthropic request: {e}")),
};
let model_id = anth_req.model.clone();
@@ -138,32 +95,12 @@ async fn anthropic_messages(
let openai_req = cortex_core::translate::anthropic_to_openai(anth_req);
let openai_body = match serde_json::to_vec(&openai_req) {
Ok(b) => Bytes::from(b),
Err(e) => {
tracing::error!(
handler = "anthropic_messages",
model = %model_id,
error = %e,
"internal: failed to serialise translated OpenAI request"
);
return error_response(500, "internal translation error");
}
Err(e) => return error_response(500, &format!("translation error: {e}")),
};
let route = match router::resolve(&fleet, &model_id).await {
Ok(r) => r,
Err(e) => {
tracing::warn!(
handler = "anthropic_messages",
model = %model_id,
error = %e,
"route resolve failed"
);
// RouteError's Display strings are short and informative
// ("model 'X' not found...", "no healthy nodes available")
// — fine to surface to the caller. The warn above carries
// any extra context for operators.
return error_response(404, &e.to_string());
}
Err(e) => return error_response(404, &e.to_string()),
};
touch_model(&fleet, &route.node_name, &model_id).await;
@@ -196,25 +133,14 @@ async fn anthropic_messages(
Ok(resp) => resp,
Err(e) => {
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
// forward_request already warn'd with the wire-level
// detail; no need to log again here.
e.into_response()
}
}
} else {
// Non-streaming: proxy, buffer full response, translate back to Anthropic.
let target_url = format!("{}/v1/chat/completions", route.endpoint);
tracing::info!(
handler = "anthropic_messages",
model = %model_id,
node = %route.node_name,
url = %target_url,
cold_start = route.cold_start,
"proxying request"
);
let upstream_resp = fleet
.http_client
.post(&target_url)
.post(format!("{}/v1/chat/completions", route.endpoint))
.body(openai_body)
.header("content-type", "application/json")
.send()
@@ -224,49 +150,22 @@ async fn anthropic_messages(
Ok(r) => r,
Err(e) => {
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
tracing::warn!(
handler = "anthropic_messages",
model = %model_id,
node = %route.node_name,
url = %target_url,
error = %e,
"upstream request failed (network)"
);
return error_response(502, "upstream request failed");
return error_response(502, &format!("upstream request failed: {e}"));
}
};
let upstream_status = upstream_resp.status();
if !upstream_status.is_success() {
if !upstream_resp.status().is_success() {
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
let status = upstream_status.as_u16();
let status = upstream_resp.status().as_u16();
let body = upstream_resp.text().await.unwrap_or_default();
let body_snippet = body.chars().take(512).collect::<String>();
tracing::warn!(
handler = "anthropic_messages",
model = %model_id,
node = %route.node_name,
url = %target_url,
status,
body = %body_snippet,
"upstream returned non-2xx"
);
return error_response(status, &format!("upstream returned {status}"));
return error_response(status, &format!("upstream error: {body}"));
}
let body_bytes = match upstream_resp.bytes().await {
Ok(b) => b,
Err(e) => {
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
tracing::warn!(
handler = "anthropic_messages",
model = %model_id,
node = %route.node_name,
url = %target_url,
error = %e,
"failed to read upstream response body"
);
return error_response(502, "failed to read upstream response");
return error_response(502, &format!("failed to read upstream response: {e}"));
}
};
@@ -275,20 +174,7 @@ async fn anthropic_messages(
Ok(r) => r,
Err(e) => {
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
let body_snippet = String::from_utf8_lossy(&body_bytes)
.chars()
.take(512)
.collect::<String>();
tracing::warn!(
handler = "anthropic_messages",
model = %model_id,
node = %route.node_name,
url = %target_url,
error = %e,
body = %body_snippet,
"failed to parse upstream response as OpenAI ChatCompletionResponse"
);
return error_response(502, "malformed upstream response");
return error_response(502, &format!("failed to parse upstream response: {e}"));
}
};
@@ -299,62 +185,12 @@ async fn anthropic_messages(
}
}
/// `GET /v1/models` — union of (catalogue × topology feasibility) and
/// (currently loaded somewhere). The result is what the fleet *could*
/// serve, not just what's already loaded — so OpenAI-compatible tools
/// see every model the operator has provisioned, and cortex
/// transparently cold-loads the first time one is requested.
/// `GET /v1/models` — aggregate models from all nodes.
async fn list_models(State(fleet): State<Arc<CortexState>>) -> Json<Value> {
use std::collections::HashMap;
let now = Utc::now().timestamp() as u64;
let nodes = fleet.nodes.read().await;
let catalogue = &fleet.catalogue;
let mut model_map: std::collections::HashMap<String, CortexModelEntry> =
std::collections::HashMap::new();
let mut entries: HashMap<String, CortexModelEntry> = HashMap::new();
// Pass 1: catalogue × topology. For every catalogue profile, find
// healthy neurons whose discovered devices satisfy the profile.
// Catalogue-defined models surface here even if nothing has loaded
// them yet — that's the point of the unified endpoint.
for profile in &catalogue.models {
let mut feasible_on = Vec::new();
for node in nodes.values() {
if !node.healthy {
continue;
}
let Some(disc) = node.discovery.as_ref() else {
continue;
};
if profile.is_feasible_on(&node.name, &disc.devices) {
feasible_on.push(node.name.clone());
}
}
if feasible_on.is_empty() {
// The catalogue lists this model but no neuron's topology
// matches — surface it as not-loaded with no feasible
// location. Hides nothing; lets operators see why a
// configured model isn't reachable.
feasible_on.clear();
}
entries.insert(
profile.id.clone(),
CortexModelEntry {
id: profile.id.clone(),
object: "model".into(),
created: now,
owned_by: "helexa".into(),
loaded: false,
feasible_on,
locations: Vec::new(),
},
);
}
// Pass 2: layer the actually-loaded state on top. For each
// (node, model) entry, attach a ModelLocation. If the model isn't
// in the catalogue, create a new CortexModelEntry from scratch —
// cortex doesn't refuse to surface a manually-loaded model just
// because the operator didn't enumerate it in models.toml.
for node in nodes.values() {
for (model_id, entry) in &node.models {
let location = ModelLocation {
@@ -362,30 +198,19 @@ async fn list_models(State(fleet): State<Arc<CortexState>>) -> Json<Value> {
status: entry.status,
vram_estimate_mb: entry.vram_estimate_mb,
};
let was_loaded = matches!(entry.status, cortex_core::node::ModelStatus::Loaded);
entries
model_map
.entry(model_id.clone())
.and_modify(|e| {
e.locations.push(location.clone());
if was_loaded {
e.loaded = true;
}
})
.and_modify(|e| e.locations.push(location.clone()))
.or_insert_with(|| CortexModelEntry {
id: model_id.clone(),
object: "model".into(),
created: now,
owned_by: "helexa".into(),
loaded: was_loaded,
// Not in catalogue — cortex has no opinion on
// feasibility; leave empty.
feasible_on: Vec::new(),
locations: vec![location],
});
}
}
let data: Vec<Value> = entries.values().map(|e| json!(e)).collect();
let data: Vec<Value> = model_map.values().map(|e| json!(e)).collect();
Json(json!({
"object": "list",
"data": data,
@@ -440,9 +265,6 @@ async fn proxy_with_metrics(
}
Err(e) => {
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
// proxy::forward_request already warn'd with wire-level
// detail (target URL, error, status). ProxyError::into_response
// now returns a generic message — no body leak.
e.into_response()
}
}

View File

@@ -3,7 +3,6 @@
use crate::state::CortexState;
use chrono::Utc;
use cortex_core::discovery::DiscoveryResponse;
use cortex_core::harness::ModelInfo;
use cortex_core::node::{ModelEntry, ModelStatus};
use std::sync::Arc;
@@ -26,59 +25,7 @@ pub async fn poll_once(fleet: &CortexState) {
}
}
/// One-shot fetch of `GET /discovery`. Cached on the NodeState forever
/// after the first success — topology is invariant for a given neuron
/// process. Skipped when the cache is already populated.
async fn maybe_poll_discovery(fleet: &CortexState, name: &str, endpoint: &str) {
{
let nodes = fleet.nodes.read().await;
match nodes.get(name) {
Some(n) if n.discovery.is_some() => return,
_ => {}
}
}
let url = format!("{endpoint}/discovery");
let resp = match fleet
.http_client
.get(&url)
.timeout(Duration::from_secs(5))
.send()
.await
{
Ok(r) if r.status().is_success() => r,
Ok(r) => {
tracing::debug!(node = name, status = %r.status(), "discovery probe non-success");
return;
}
Err(e) => {
tracing::debug!(node = name, error = %e, "discovery probe unreachable");
return;
}
};
match resp.json::<DiscoveryResponse>().await {
Ok(d) => {
let mut nodes = fleet.nodes.write().await;
if let Some(node) = nodes.get_mut(name) {
tracing::info!(
node = name,
hostname = %d.hostname,
devices = d.devices.len(),
"discovery cached"
);
node.discovery = Some(d);
}
}
Err(e) => {
tracing::warn!(node = name, error = %e, "failed to parse /discovery response");
}
}
}
async fn poll_neuron(fleet: &CortexState, name: &str, endpoint: &str) {
// Topology first — cheap once cached, and the router needs it to
// route requests against catalogue entries that aren't loaded yet.
maybe_poll_discovery(fleet, name, endpoint).await;
let url = format!("{endpoint}/models");
let result = fleet

View File

@@ -1,4 +1,4 @@
//! Streaming HTTP reverse proxy to neuron backends.
//! Streaming HTTP reverse proxy to mistral.rs backends.
//!
//! For streaming requests, SSE chunks are forwarded as they arrive.
//! The proxy captures timing information for metrics but does not
@@ -12,13 +12,6 @@ use axum::response::{IntoResponse, Response};
use reqwest::Client;
/// Proxy a request body to the resolved backend node and stream the response.
///
/// Logging contract: every call emits exactly one structured event at
/// info / warn level for operator visibility, regardless of outcome.
/// Network-level failures and non-2xx upstream statuses are warn'd here
/// (closest to the wire); the user-facing response carries only the
/// status code and a generic message — implementation detail (body,
/// error chain) lives in the log, never in the API surface.
pub async fn forward_request(
client: &Client,
route: &RouteDecision,
@@ -44,33 +37,10 @@ pub async fn forward_request(
req_builder = req_builder.header(key, value);
}
let upstream_resp = match req_builder.send().await {
Ok(r) => r,
Err(e) => {
tracing::warn!(
node = %route.node_name,
url = %url,
error = %e,
"proxy: upstream request failed (network)"
);
return Err(ProxyError::Upstream(e));
}
};
let upstream_resp = req_builder.send().await.map_err(ProxyError::Upstream)?;
let upstream_status = upstream_resp.status();
if !upstream_status.is_success() {
// Streaming body — can't snippet without breaking the stream
// pass-through. Log status + URL; the client still gets the
// upstream status, just without the leaked body.
tracing::warn!(
node = %route.node_name,
url = %url,
status = upstream_status.as_u16(),
"proxy: upstream returned non-2xx"
);
}
let status = StatusCode::from_u16(upstream_status.as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
let status =
StatusCode::from_u16(upstream_resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
let resp_headers = upstream_resp.headers().clone();
let stream = upstream_resp.bytes_stream();
@@ -82,37 +52,28 @@ pub async fn forward_request(
response = response.header(key, value);
}
response.body(body).map_err(|e| {
tracing::warn!(
node = %route.node_name,
url = %url,
error = %e,
"proxy: failed to build response"
);
ProxyError::ResponseBuild(e.to_string())
})
response
.body(body)
.map_err(|e| ProxyError::ResponseBuild(e.to_string()))
}
#[derive(Debug, thiserror::Error)]
pub enum ProxyError {
#[error("upstream request failed")]
#[error("upstream request failed: {0}")]
Upstream(reqwest::Error),
#[error("failed to build response")]
#[error("failed to build response: {0}")]
ResponseBuild(String),
}
impl IntoResponse for ProxyError {
fn into_response(self) -> Response {
let (status, message) = match &self {
ProxyError::Upstream(_) => (StatusCode::BAD_GATEWAY, "upstream request failed"),
ProxyError::ResponseBuild(_) => (
StatusCode::INTERNAL_SERVER_ERROR,
"failed to build response",
),
let status = match &self {
ProxyError::Upstream(_) => StatusCode::BAD_GATEWAY,
ProxyError::ResponseBuild(_) => StatusCode::INTERNAL_SERVER_ERROR,
};
let body = serde_json::json!({
"error": {
"message": message,
"message": self.to_string(),
"type": "proxy_error",
}
});

View File

@@ -2,21 +2,13 @@
//!
//! Given a model ID from an inbound request, determine which node should
//! handle it. Priority:
//! 1. Node where the model is currently `Loaded` → use it.
//! 2. Node where the model is `Unloaded` → use it; neuron's existing
//! lazy-load behaviour will reload before serving the request.
//! 3. Model is in the catalogue → pick a feasible neuron, call
//! `POST /models/load`, wait for the load to complete, then
//! proxy. First-request cold-load latency is acceptable per the
//! unified-endpoint contract.
//! 4. Not in catalogue, not loaded anywhere → 404.
//! 1. Node where the model is currently `Loaded`
//! 2. Node where the model is `Unloaded` (will lazy-load on request)
//! 3. Error: model not found on any node
use crate::state::CortexState;
use cortex_core::catalogue::ModelProfile;
use cortex_core::harness::ModelSpec;
use cortex_core::node::ModelStatus;
use std::sync::Arc;
use std::time::Duration;
/// The routing decision: which node endpoint to proxy the request to.
#[derive(Debug, Clone)]
@@ -24,31 +16,18 @@ pub struct RouteDecision {
pub node_name: String,
/// The inference endpoint to proxy to (from neuron's /models/{id}/endpoint).
pub endpoint: String,
/// Whether the model will need to load (cold start). Set to true
/// when we proxied to an `Unloaded` node (lazy load on neuron) or
/// when we just triggered an explicit cold-load via the catalogue
/// path.
/// Whether the model will need to load (cold start).
pub cold_start: bool,
}
#[derive(Debug, thiserror::Error)]
pub enum RouteError {
#[error("model '{0}' not found on any node and not in catalogue")]
#[error("model '{0}' not found on any node")]
ModelNotFound(String),
#[error("no healthy nodes available")]
NoHealthyNodes,
#[error("failed to resolve inference endpoint for model '{0}' on node '{1}'")]
EndpointResolveFailed(String, String),
#[error(
"model '{model_id}' is in the catalogue but no healthy neuron's topology satisfies its constraints"
)]
NoFeasibleNeuron { model_id: String },
#[error("cold-load of '{model_id}' on '{node}' failed: {message}")]
ColdLoadFailed {
model_id: String,
node: String,
message: String,
},
}
/// Resolve which node should serve a request for the given model.
@@ -57,231 +36,42 @@ pub async fn resolve(
fleet: &Arc<CortexState>,
model_id: &str,
) -> Result<RouteDecision, RouteError> {
// Snapshot loaded / unloaded state from the poller cache.
let (loaded_route, unloaded_route, any_healthy) = {
let (node_name, neuron_endpoint, cold_start) = {
let nodes = fleet.nodes.read().await;
let mut loaded_route = None;
let mut unloaded_route = None;
let mut any_healthy = false;
let mut loaded_candidate = None;
let mut unloaded_candidate = None;
for node in nodes.values() {
if !node.healthy {
continue;
}
any_healthy = true;
if let Some(entry) = node.models.get(model_id) {
match entry.status {
ModelStatus::Loaded | ModelStatus::Reloading => {
loaded_route = Some((node.name.clone(), node.endpoint.clone(), false));
loaded_candidate = Some((node.name.clone(), node.endpoint.clone(), false));
break;
}
ModelStatus::Unloaded => {
if unloaded_route.is_none() {
unloaded_route = Some((node.name.clone(), node.endpoint.clone(), true));
if unloaded_candidate.is_none() {
unloaded_candidate =
Some((node.name.clone(), node.endpoint.clone(), true));
}
}
}
}
}
(loaded_route, unloaded_route, any_healthy)
};
if !any_healthy {
return Err(RouteError::NoHealthyNodes);
}
// Priority 1: already loaded.
if let Some((node_name, neuron_endpoint, cold_start)) = loaded_route {
return finish(fleet, &node_name, &neuron_endpoint, model_id, cold_start).await;
}
// Priority 2: known to neuron but unloaded (neuron's lazy load).
if let Some((node_name, neuron_endpoint, cold_start)) = unloaded_route {
return finish(fleet, &node_name, &neuron_endpoint, model_id, cold_start).await;
}
// Priority 3: catalogue × topology cold-load.
if let Some(profile) = fleet.catalogue.get(model_id) {
let (node_name, neuron_endpoint) = pick_feasible_neuron(fleet, profile).await?;
cold_load(fleet, &node_name, &neuron_endpoint, profile).await?;
return finish(fleet, &node_name, &neuron_endpoint, model_id, true).await;
}
Err(RouteError::ModelNotFound(model_id.to_string()))
}
/// Pick a healthy neuron whose discovered topology satisfies the
/// profile. Preference order:
/// 1. A neuron from `profile.pinned_on` that is healthy + feasible.
/// 2. Otherwise, any healthy + feasible neuron, stable by name.
async fn pick_feasible_neuron(
fleet: &Arc<CortexState>,
profile: &ModelProfile,
) -> Result<(String, String), RouteError> {
let nodes = fleet.nodes.read().await;
let mut candidates: Vec<(String, String, bool)> = Vec::new();
for node in nodes.values() {
if !node.healthy {
continue;
}
let Some(disc) = node.discovery.as_ref() else {
continue;
};
if !profile.is_feasible_on(&node.name, &disc.devices) {
continue;
}
let pinned = profile.pinned_on.iter().any(|n| n == &node.name);
candidates.push((node.name.clone(), node.endpoint.clone(), pinned));
}
candidates.sort_by(|a, b| {
b.2.cmp(&a.2) // pinned first (true > false)
.then(a.0.cmp(&b.0))
});
let pick = candidates.into_iter().next();
pick.map(|(n, e, _)| (n, e))
.ok_or_else(|| RouteError::NoFeasibleNeuron {
model_id: profile.id.clone(),
})
}
/// Issue `POST {endpoint}/models/load` for this profile on this neuron,
/// blocking until the load completes (neuron's load endpoint is
/// synchronous — it returns 200 once VRAM is materialised). On success
/// also inserts a `Loaded` entry into the local NodeState cache so the
/// caller's subsequent endpoint lookup sees the new model without
/// waiting for the next poll cycle.
async fn cold_load(
fleet: &Arc<CortexState>,
node_name: &str,
neuron_endpoint: &str,
profile: &ModelProfile,
) -> Result<(), RouteError> {
let spec = profile_to_spec(fleet, node_name, profile).await;
let url = format!("{neuron_endpoint}/models/load");
tracing::info!(model = %profile.id, node = node_name, "cold-loading via /models/load");
// Generous timeout: a fresh download + safetensors mmap + device
// copy for a 30B-class dense model can comfortably exceed 5 min on
// a slow link. The HTTP client's own default already covers most
// of this; pin a longer per-request bound just here.
let resp = match fleet
.http_client
.post(&url)
.timeout(Duration::from_secs(1800))
.json(&spec)
.send()
.await
{
Ok(r) => r,
Err(e) => {
return Err(RouteError::ColdLoadFailed {
model_id: profile.id.clone(),
node: node_name.to_string(),
message: format!("HTTP request failed: {e}"),
});
}
};
let status = resp.status();
if !status.is_success() {
let body = resp.text().await.unwrap_or_default();
// Neuron returns 400 "already loaded" when two concurrent
// requests race the same model. Treat that as success — both
// requests effectively achieved the same end state.
if body.contains("already loaded") {
tracing::info!(
model = %profile.id,
node = node_name,
"cold-load saw 'already loaded' — treating as success"
);
} else {
return Err(RouteError::ColdLoadFailed {
model_id: profile.id.clone(),
node: node_name.to_string(),
message: format!("HTTP {status}: {body}"),
});
}
} else {
tracing::info!(model = %profile.id, node = node_name, "cold-load returned 200");
}
// Warm the cache: insert a Loaded ModelEntry so the next
// resolve() finds the model without waiting for the poll loop.
{
let mut nodes = fleet.nodes.write().await;
if let Some(node) = nodes.get_mut(node_name) {
node.models.insert(
profile.id.clone(),
cortex_core::node::ModelEntry {
id: profile.id.clone(),
status: ModelStatus::Loaded,
last_accessed: Some(chrono::Utc::now()),
vram_estimate_mb: profile.vram_mb,
},
);
}
}
Ok(())
}
/// Translate a `ModelProfile` to a `ModelSpec` neuron's /models/load
/// accepts. Devices are picked from the neuron's discovered topology —
/// the first `min_devices` indices that meet `min_device_vram_mb`.
async fn profile_to_spec(
fleet: &Arc<CortexState>,
node_name: &str,
profile: &ModelProfile,
) -> ModelSpec {
let devices = {
let nodes = fleet.nodes.read().await;
let mut picked: Vec<u32> = Vec::new();
if let Some(node) = nodes.get(node_name)
&& let Some(disc) = &node.discovery
{
let min_vram = profile.min_device_vram_mb.unwrap_or(0);
for d in &disc.devices {
if d.vram_total_mb >= min_vram {
picked.push(d.index);
if picked.len() as u32 >= profile.min_devices {
break;
}
}
loaded_candidate.or(unloaded_candidate).ok_or_else(|| {
if nodes.values().any(|n| n.healthy) {
RouteError::ModelNotFound(model_id.to_string())
} else {
RouteError::NoHealthyNodes
}
}
if picked.is_empty() {
// Fall back to a 0..min_devices default; pick_feasible_neuron
// already verified the topology satisfies the constraints,
// so this only fires if discovery raced or was lost.
(0..profile.min_devices).collect()
} else {
picked
}
})?
};
let tensor_parallel = if profile.min_devices > 1 {
Some(profile.min_devices)
} else {
None
};
ModelSpec {
model_id: profile.id.clone(),
harness: profile.harness.clone(),
quant: profile.quant.clone(),
tensor_parallel,
devices: Some(devices),
}
}
/// Resolve neuron's `/models/{id}/endpoint` to its inference URL and
/// build the final `RouteDecision`. Shared by all three priority
/// branches above.
async fn finish(
fleet: &Arc<CortexState>,
node_name: &str,
neuron_endpoint: &str,
model_id: &str,
cold_start: bool,
) -> Result<RouteDecision, RouteError> {
// Ask the neuron for the inference endpoint for this model.
let endpoint_url = format!(
"{}/models/{}/endpoint",
neuron_endpoint,
@@ -299,82 +89,13 @@ async fn finish(
_ => None,
};
let raw = inference_endpoint.ok_or_else(|| {
RouteError::EndpointResolveFailed(model_id.to_string(), node_name.to_string())
let endpoint = inference_endpoint.ok_or_else(|| {
RouteError::EndpointResolveFailed(model_id.to_string(), node_name.clone())
})?;
// Rewrite loopback inference URLs to use the configured neuron host.
// Neuron's default bind_url is `http://localhost:13131` (it can't
// reliably know its own externally-resolvable name). Cortex sees a
// URL that's only meaningful from the neuron host's own perspective;
// proxying directly to localhost from a different cortex host would
// hit nothing. Keep neuron's port and path (a future harness could
// serve inference on a different port than the management API), but
// swap the host for the one in cortex.toml.
let endpoint = rewrite_loopback_host(&raw, neuron_endpoint).unwrap_or(raw);
Ok(RouteDecision {
node_name: node_name.to_string(),
node_name,
endpoint,
cold_start,
})
}
/// If `inference_url`'s host is a loopback name (localhost / 127.0.0.1 /
/// 0.0.0.0 / ::1), return a copy with the host replaced by
/// `neuron_endpoint`'s host. Otherwise return None and the caller falls
/// back to the inference URL as-is.
fn rewrite_loopback_host(inference_url: &str, neuron_endpoint: &str) -> Option<String> {
let inf = url::Url::parse(inference_url).ok()?;
let inf_host = inf.host_str()?;
let is_loopback = matches!(inf_host, "localhost" | "127.0.0.1" | "0.0.0.0" | "::1");
if !is_loopback {
return None;
}
let neuron = url::Url::parse(neuron_endpoint).ok()?;
let new_host = neuron.host_str()?;
let mut out = inf.clone();
out.set_host(Some(new_host)).ok()?;
// url::Url::to_string normalises an empty path to "/", which then
// breaks downstream callers that do format!("{endpoint}/v1/...")
// and produce a double slash. The proxy URL is treated as a base
// string that the caller appends paths to, so strip the trailing
// slash here.
let s = out.to_string();
Some(s.trim_end_matches('/').to_string())
}
#[cfg(test)]
mod tests {
use super::rewrite_loopback_host;
#[test]
fn rewrites_localhost_keeps_port_and_path() {
let out = rewrite_loopback_host(
"http://localhost:13131",
"http://beast.hanzalova.internal:13131",
);
assert_eq!(
out.as_deref(),
Some("http://beast.hanzalova.internal:13131")
);
}
#[test]
fn rewrites_loopback_with_distinct_inference_port() {
let out = rewrite_loopback_host("http://127.0.0.1:8080", "http://beast.lan:13131");
assert_eq!(out.as_deref(), Some("http://beast.lan:8080"));
}
#[test]
fn leaves_non_loopback_alone() {
let out = rewrite_loopback_host("http://other.host:1234", "http://beast.lan:13131");
assert_eq!(out, None);
}
#[test]
fn malformed_inference_url_returns_none() {
let out = rewrite_loopback_host("not a url", "http://beast.lan:13131");
assert_eq!(out, None);
}
}

View File

@@ -26,7 +26,6 @@ impl CortexState {
models: HashMap::new(),
lifecycle_cycles: 0,
last_poll: None,
discovery: None,
},
);
}

View File

@@ -22,7 +22,6 @@ use tokio::net::TcpListener;
/// - GET /models/:id/endpoint (returns the inference URL)
/// - POST /models/unload (accepts unload requests)
/// - GET /v1/chat/completions + POST /v1/chat/completions (inference)
///
/// Returns the neuron base URL.
pub async fn spawn_mock_neuron() -> String {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
@@ -55,7 +54,7 @@ pub async fn spawn_mock_neuron() -> String {
async fn mock_neuron_list_models() -> Json<Value> {
Json(json!([
{"id": "test-model", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": 8000}
{"id": "test-model", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": 8000}
]))
}

View File

@@ -12,8 +12,8 @@ use std::sync::Arc;
async fn test_poller_discovers_models() {
// Mock neuron reports 2 models via /models endpoint (neuron format).
let mock_url = common::spawn_mock_neuron_with_models(json!([
{"id": "model-a", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": 8000},
{"id": "model-b", "harness": "candle", "status": "unloaded", "devices": [], "vram_used_mb": null}
{"id": "model-a", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": 8000},
{"id": "model-b", "harness": "mistralrs", "status": "unloaded", "devices": [], "vram_used_mb": null}
]))
.await;
@@ -63,8 +63,8 @@ async fn test_poller_discovers_models() {
#[tokio::test]
async fn test_poller_updates_gateway_models_endpoint() {
let mock_url = common::spawn_mock_neuron_with_models(json!([
{"id": "model-x", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": null},
{"id": "model-y", "harness": "candle", "status": "loaded", "devices": [1], "vram_used_mb": null}
{"id": "model-x", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": null},
{"id": "model-y", "harness": "mistralrs", "status": "loaded", "devices": [1], "vram_used_mb": null}
]))
.await;
@@ -152,8 +152,8 @@ async fn test_poller_marks_unreachable_node_unhealthy() {
#[tokio::test]
async fn test_poller_removes_stale_models() {
let mock_url = common::spawn_mock_neuron_with_models(json!([
{"id": "keep-me", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": null},
{"id": "drop-me", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": null}
{"id": "keep-me", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": null},
{"id": "drop-me", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": null}
]))
.await;
@@ -183,7 +183,7 @@ async fn test_poller_removes_stale_models() {
// New mock with only one model.
let new_mock_url = common::spawn_mock_neuron_with_models(json!([
{"id": "keep-me", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": null}
{"id": "keep-me", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": null}
]))
.await;

View File

@@ -51,18 +51,18 @@ async fn test_streaming_sse_passthrough() {
}
assert!(
chunks.len() > chunk_count,
"expected more than {} chunks (got {}): {:?}",
chunk_count,
chunks.len() >= chunk_count + 1,
"expected at least {} chunks (got {}): {:?}",
chunk_count + 1,
chunks.len(),
chunks,
);
assert_eq!(chunks.last().unwrap(), "[DONE]");
for (i, chunk) in chunks.iter().enumerate().take(chunk_count) {
for i in 0..chunk_count {
let chunk_json: serde_json::Value =
serde_json::from_str(chunk).expect("chunk should be valid JSON");
serde_json::from_str(&chunks[i]).expect("chunk should be valid JSON");
assert_eq!(
chunk_json["choices"][0]["delta"]["content"],
format!("token{i}")

View File

@@ -12,36 +12,6 @@ path = "src/lib.rs"
name = "neuron"
path = "src/main.rs"
[features]
default = []
# Enables CUDA acceleration in candle and the cudarc/nccl bindings the
# TP worker pool uses. Without this feature, candle compiles for CPU
# only, Device::new_cuda calls fall back to CPU, and TP Init/sanity
# requests return Error{kind="cuda_feature_not_enabled"}.
cuda = [
"candle-core/cuda",
"candle-core/nccl",
"candle-nn/cuda",
"candle-transformers/cuda",
"dep:cudarc",
"dep:half",
"dep:cudaforge",
]
# Use cuDNN for convolution / attention kernels. Requires CUDA.
cudnn = [
"cuda",
"candle-core/cudnn",
"candle-nn/cudnn",
"candle-transformers/cudnn",
]
# FlashAttention kernels. Requires CUDA.
flash-attn = [
"cuda",
"candle-transformers/flash-attn",
]
# Reserved for GPU-only integration tests in later stages.
cuda-integration = ["cuda"]
[dependencies]
cortex-core.workspace = true
tokio.workspace = true
@@ -54,44 +24,9 @@ tracing-subscriber.workspace = true
anyhow.workspace = true
async-trait.workspace = true
clap.workspace = true
thiserror.workspace = true
futures.workspace = true
tokio-stream.workspace = true
figment.workspace = true
toml.workspace = true
# candle for in-process inference. CUDA support is gated behind the
# crate's `cuda` feature (default off) so the workspace builds on
# non-CUDA hosts and CI runners.
candle-core = "0.10.2"
candle-nn = "0.10.2"
candle-transformers = "0.10.2"
# Direct dep on cudarc (matching candle's transitive version) so the
# TP worker pool can call cudarc::nccl::{Comm, Id} directly. Gated on
# the `cuda` feature; same toolchain requirement as candle's CUDA path.
cudarc = { version = "0.19", optional = true, default-features = false, features = ["nccl", "cuda-version-from-build-system"] }
# Used by the AllReduce CustomOp1 to type-dispatch on bf16/f16 candle
# storages. Matches candle-core's pinned major version to avoid double-
# compiling the `half` crate at conflicting versions.
half = { version = "2.5", optional = true }
tokenizers = { version = "0.22", default-features = false, features = ["onig"] }
hf-hub = { version = "0.4", features = ["tokio"] }
# Direct dep on `safetensors` (re-exported by candle but its `TensorView`
# / `slice::IndexOp` types are public-but-not-re-exported). Used by the
# tp `fused_load` module to read per-rank slices of fused QKV tensors
# without materialising the full tensor on device.
safetensors = "0.7"
[dev-dependencies]
tokio = { workspace = true, features = ["test-util"] }
reqwest.workspace = true
[build-dependencies]
# Used by `build.rs` to compile `src/cuda/*.cu` into `libneuroncuda.a`
# under the `cuda` feature. Matches mistralrs's upstream build setup
# (their `mistralrs-core/build.rs` uses the same constructor).
cudaforge = { version = "0.1", optional = true }
[package.metadata.docs.rs]
# Skip the CUDA path on docs.rs (it lacks nvcc).
no-default-features = true

View File

@@ -1,66 +0,0 @@
//! Build script: compile the CUDA kernels in `src/cuda/*.cu` into a
//! static library and link it under the `cuda` feature.
//!
//! Patterned on `EricLBuehler/mistral.rs::mistralrs-core/build.rs` —
//! same `cudaforge::KernelBuilder` invocation, same NVCC flag set.
fn main() {
#[cfg(feature = "cuda")]
{
use std::path::PathBuf;
println!("cargo:rerun-if-changed=build.rs");
println!("cargo:rerun-if-changed=src/cuda/");
let build_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap());
let mut builder = cudaforge::KernelBuilder::new()
.source_glob("src/cuda/*.cu")
.out_dir(&build_dir)
.arg("-std=c++17")
.arg("-O3")
.arg("-U__CUDA_NO_HALF_OPERATORS__")
.arg("-U__CUDA_NO_HALF_CONVERSIONS__")
.arg("-U__CUDA_NO_HALF2_OPERATORS__")
.arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
.arg("--expt-relaxed-constexpr")
.arg("--expt-extended-lambda")
.arg("--use_fast_math")
.arg("--compiler-options")
.arg("-fPIC");
// sm_<80 doesn't have bf16 intrinsics for WMMA — gate the
// bf16-only kernels off in that case. (Mirrors upstream.)
if let Some(compute_cap) = builder.get_compute_cap()
&& compute_cap < 80
{
builder = builder.arg("-DNO_BF16_KERNEL");
}
let target = std::env::var("TARGET").unwrap();
let out_file = if target.contains("msvc") {
build_dir.join("neuroncuda.lib")
} else {
build_dir.join("libneuroncuda.a")
};
builder
.build_lib(out_file)
.expect("neuron cuda build failed");
println!("cargo:rustc-link-search={}", build_dir.display());
println!("cargo:rustc-link-lib=neuroncuda");
println!("cargo:rustc-link-lib=dylib=cudart");
if target.contains("msvc") {
// No extra runtime library needed.
} else if target.contains("apple")
|| target.contains("freebsd")
|| target.contains("openbsd")
{
println!("cargo:rustc-link-lib=dylib=c++");
} else if target.contains("android") {
println!("cargo:rustc-link-lib=dylib=c++_shared");
} else {
println!("cargo:rustc-link-lib=dylib=stdc++");
}
}
}

View File

@@ -1,33 +1,23 @@
//! HTTP API handlers for the neuron daemon.
use crate::harness::HarnessRegistry;
use crate::harness::candle::{CandleHarness, InferenceError};
use crate::health::HealthCache;
use axum::Router;
use axum::extract::{Path, State};
use axum::http::StatusCode;
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::{IntoResponse, Json};
use axum::routing::{get, post};
use cortex_core::discovery::{DiscoveryResponse, HealthResponse};
use cortex_core::harness::ModelSpec;
use cortex_core::openai::ChatCompletionRequest;
use futures::stream::{self, StreamExt};
use serde_json::{Value, json};
use std::convert::Infallible;
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio_stream::wrappers::ReceiverStream;
/// Shared state for the neuron HTTP server.
pub struct NeuronState {
pub discovery: DiscoveryResponse,
pub health_cache: Arc<HealthCache>,
pub registry: RwLock<HarnessRegistry>,
/// Typed handle to the candle harness for inference routes. Cached at
/// startup so `/v1/chat/completions` doesn't have to hold the registry
/// read lock or perform dyn-Trait dispatch per request.
pub candle: Option<Arc<CandleHarness>>,
}
/// Build the neuron API router.
@@ -39,7 +29,6 @@ pub fn neuron_routes() -> Router<Arc<NeuronState>> {
.route("/models/load", post(load_model))
.route("/models/unload", post(unload_model))
.route("/models/{model_id}/endpoint", get(model_endpoint))
.route("/v1/chat/completions", post(chat_completions))
}
async fn discovery_handler(State(state): State<Arc<NeuronState>>) -> Json<DiscoveryResponse> {
@@ -56,7 +45,7 @@ async fn list_models(State(state): State<Arc<NeuronState>>) -> impl IntoResponse
Ok(models) => Json(json!(models)).into_response(),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("{e:#}")})),
Json(json!({"error": e.to_string()})),
)
.into_response(),
}
@@ -69,22 +58,11 @@ async fn load_model(
let registry = state.registry.read().await;
match registry.load_model(&spec).await {
Ok(()) => Json(json!({"status": "loaded"})).into_response(),
Err(e) => {
// Log the full anyhow chain server-side so journalctl shows
// the underlying failure (hf-hub timeout, permission denied,
// disk full, etc.) without needing to inspect the HTTP
// response body separately.
tracing::warn!(
model = %spec.model_id,
error = %format!("{e:#}"),
"load_model failed"
);
(
StatusCode::BAD_REQUEST,
Json(json!({"error": format!("{e:#}")})),
)
.into_response()
}
Err(e) => (
StatusCode::BAD_REQUEST,
Json(json!({"error": e.to_string()})),
)
.into_response(),
}
}
@@ -106,11 +84,7 @@ async fn unload_model(
let registry = state.registry.read().await;
match registry.unload_model(&model_id).await {
Ok(()) => Json(json!({"status": "unloaded"})).into_response(),
Err(e) => (
StatusCode::NOT_FOUND,
Json(json!({"error": format!("{e:#}")})),
)
.into_response(),
Err(e) => (StatusCode::NOT_FOUND, Json(json!({"error": e.to_string()}))).into_response(),
}
}
@@ -128,61 +102,3 @@ async fn model_endpoint(
.into_response(),
}
}
/// OpenAI-compatible chat completions. Dispatches to streaming SSE when
/// `stream: true` is set on the request; otherwise returns a single
/// `ChatCompletionResponse`.
async fn chat_completions(
State(state): State<Arc<NeuronState>>,
Json(req): Json<ChatCompletionRequest>,
) -> impl IntoResponse {
let Some(candle) = state.candle.as_ref().map(Arc::clone) else {
return (
StatusCode::SERVICE_UNAVAILABLE,
Json(json!({"error": "candle harness not enabled on this neuron"})),
)
.into_response();
};
if req.stream.unwrap_or(false) {
match candle.chat_completion_stream(req).await {
Ok(rx) => {
// Each chunk → one SSE `data: {json}` line. After the
// channel closes, append the OpenAI [DONE] terminator.
let body_stream = ReceiverStream::new(rx).map(|chunk| {
let body = serde_json::to_string(&chunk).unwrap_or_default();
Ok::<_, Infallible>(Event::default().data(body))
});
let done_stream =
stream::once(async { Ok::<_, Infallible>(Event::default().data("[DONE]")) });
Sse::new(body_stream.chain(done_stream))
.keep_alive(KeepAlive::default())
.into_response()
}
Err(InferenceError::ModelNotLoaded(id)) => (
StatusCode::NOT_FOUND,
Json(json!({"error": format!("model '{id}' not loaded on this neuron")})),
)
.into_response(),
Err(InferenceError::Other(e)) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("{e:#}")})),
)
.into_response(),
}
} else {
match candle.chat_completion(req).await {
Ok(resp) => Json(resp).into_response(),
Err(InferenceError::ModelNotLoaded(id)) => (
StatusCode::NOT_FOUND,
Json(json!({"error": format!("model '{id}' not loaded on this neuron")})),
)
.into_response(),
Err(InferenceError::Other(e)) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("{e:#}")})),
)
.into_response(),
}
}
}

View File

@@ -1,12 +1,12 @@
//! Neuron configuration loaded from neuron.toml.
use cortex_core::harness::{HarnessConfig, ModelSpec};
use cortex_core::harness::HarnessConfig;
use figment::{
Figment,
providers::{Env, Format, Toml},
};
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use std::path::Path;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NeuronConfig {
@@ -14,35 +14,10 @@ pub struct NeuronConfig {
pub port: u16,
#[serde(default)]
pub harnesses: Vec<HarnessConfig>,
/// Per-harness configuration. Currently only `candle` is recognised.
#[serde(default)]
pub harness: HarnessSettings,
/// Models to auto-load when the neuron service activates. Each entry
/// is loaded sequentially before the HTTP listener binds. A failure
/// on any single entry logs a warning and proceeds — broken entries
/// don't prevent the rest of the fleet from starting.
#[serde(default)]
pub default_models: Vec<ModelSpec>,
}
/// Settings for individual harness implementations. Each harness owns
/// its own sub-table so users only configure the harnesses they enable.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct HarnessSettings {
#[serde(default)]
pub candle: CandleHarnessConfig,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CandleHarnessConfig {
/// HuggingFace cache directory for model weights.
/// When unset, defers to hf-hub's default (~/.cache/huggingface).
#[serde(default)]
pub hf_cache: Option<PathBuf>,
}
fn default_port() -> u16 {
13131
9090
}
impl NeuronConfig {
@@ -58,10 +33,8 @@ impl NeuronConfig {
impl Default for NeuronConfig {
fn default() -> Self {
Self {
port: 13131,
port: 9090,
harnesses: vec![],
harness: HarnessSettings::default(),
default_models: vec![],
}
}
}

View File

@@ -1,84 +0,0 @@
//! FFI declarations for the CUDA kernels in `gdn.cu`.
//!
//! Subset of `EricLBuehler/mistral.rs::mistralrs-core/src/cuda/ffi.rs`
//! covering only the Gated DeltaNet kernels we currently use. Other
//! kernels in the upstream file (MoE GEMM, top-k, Mamba selective
//! scan, etc.) would land here too as we absorb them.
//!
//! All function declarations are MIT-licensed from upstream and
//! unchanged apart from this header.
use std::ffi::c_void;
#[allow(dead_code)]
unsafe extern "C" {
// GDN (Gated Delta Net) kernels for qwen3_5 / Qwen3-Next.
pub(crate) fn gated_delta_rule_recurrence(
q: *const f32,
k: *const f32,
v: *const f32,
g: *const f32,
beta: *const f32,
state: *mut f32,
output: *mut f32,
bh: i32,
seq_len: i32,
k_dim: i32,
v_dim: i32,
stream: i64,
);
/// Chunked GDN recurrence for prefill (processes tokens in BT=64 chunks).
pub(crate) fn chunked_gated_delta_rule_recurrence(
q: *const f32,
k: *const f32,
v: *const f32,
g: *const f32,
beta: *const f32,
state: *mut f32,
output: *mut f32,
bh: i32,
seq_len: i32,
k_dim: i32,
v_dim: i32,
stream: i64,
);
pub(crate) fn causal_conv1d_update(
x: *const c_void,
weight: *const c_void,
conv_state: *mut c_void,
output: *mut c_void,
batch_size: i32,
conv_dim: i32,
kernel_size: i32,
dtype: i32,
stream: i64,
);
pub(crate) fn causal_conv1d_full(
x: *const c_void,
weight: *const c_void,
conv_state_out: *mut c_void,
output: *mut c_void,
batch_size: i32,
conv_dim: i32,
seq_len: i32,
kernel_size: i32,
dtype: i32,
stream: i64,
);
pub(crate) fn fused_gdn_gating(
b: *const c_void,
a: *const c_void,
a_log: *const f32,
dt_bias: *const f32,
beta_out: *mut c_void,
g_out: *mut c_void,
total_elements: i32,
num_heads: i32,
dtype: i32,
stream: i64,
);
}

View File

@@ -1,711 +0,0 @@
// Gated DeltaNet CUDA kernels for Qwen3-Next (`model_type = "qwen3_5"`).
//
// Ported verbatim from `EricLBuehler/mistral.rs` under MIT terms.
// Upstream path: mistralrs-core/src/cuda/gdn.cu. Local edits in this
// file are limited to this banner; the kernels are unchanged so a
// diff against upstream stays minimal.
//
// Five kernels exposed via `extern "C"` shims at the bottom:
// - gated_delta_rule_recurrence (per-token decode)
// - chunked_gated_delta_rule_recurrence (BT=64 chunked prefill)
// - causal_conv1d_update (single-token conv decode)
// - causal_conv1d_full (multi-token conv prefill)
// - fused_gdn_gating (beta = sigmoid(b);
// g = -exp(A_log) * softplus(a + dt_bias))
#include "cuda_bf16.h"
#include "cuda_fp16.h"
#include <cmath>
#include <cstdint>
#include <cuda_runtime.h>
// ============================================================================
// Kernel 1: gated_delta_rule_recurrence (optimized)
//
// V-tiled recurrence with compile-time K dimension for register residency.
// Grid: (ceil(V/BV), B*H), Block: (BV,). Each thread owns BK registers of
// state. Shared memory holds k_buf and q_buf (2*BK floats).
//
// Optimizations over naive version:
// - Template BK -> float s[BK] lives in true registers (1 cycle vs ~30)
// - #pragma unroll on all k-loops -> full ILP
// - Fused decay+kv_mem pass and fused state_update+output pass
// - __fmaf_rn intrinsics for guaranteed fused multiply-add
// - BV=64 threads -> 2 warps, 6 blocks/SM on Ampere
//
// q,k: [BH, S, K] v: [BH, S, V] g,beta: [BH, S]
// state: [BH, K, V] (in/out) output: [BH, S, V]
// ============================================================================
// Optimized kernel: BK known at compile time -> registers + full unrolling
template <int BK, int BV>
__global__ void gated_delta_rule_recurrence_kernel_tiled(
const float *__restrict__ q, // [BH, S, K]
const float *__restrict__ k, // [BH, S, K]
const float *__restrict__ v, // [BH, S, V]
const float *__restrict__ g, // [BH, S]
const float *__restrict__ beta, // [BH, S]
float *__restrict__ state, // [BH, K, V]
float *__restrict__ output, // [BH, S, V]
int seq_len, int v_dim) {
const int v_tile = blockIdx.x; // which V-tile
const int bh = blockIdx.y; // batch*head index
const int tid = threadIdx.x; // thread within tile [0, BV)
const int v_idx = v_tile * BV + tid; // global V index
if (v_idx >= v_dim)
return;
// Pointers for this (batch, head)
const float *q_bh = q + bh * seq_len * BK;
const float *k_bh = k + bh * seq_len * BK;
const float *v_bh = v + bh * seq_len * v_dim;
const float *g_bh = g + bh * seq_len;
const float *beta_bh = beta + bh * seq_len;
float *state_bh = state + bh * BK * v_dim;
float *out_bh = output + bh * seq_len * v_dim;
// Shared memory: k_buf[BK] + q_buf[BK]
__shared__ float k_buf[BK];
__shared__ float q_buf[BK];
// Load state column into registers — BK is compile-time, so this is
// a true register array (not spilled to local memory)
float s[BK];
#pragma unroll
for (int j = 0; j < BK; j++) {
s[j] = state_bh[j * v_dim + v_idx];
}
for (int t = 0; t < seq_len; t++) {
// Collaboratively load k_t into shared memory
// BK / BV loads per thread (e.g. 128/64 = 2)
#pragma unroll
for (int j = tid; j < BK; j += BV) {
k_buf[j] = k_bh[t * BK + j];
}
__syncthreads();
// Load scalars for this timestep
float decay = expf(g_bh[t]);
float beta_t = beta_bh[t];
float v_t = v_bh[t * v_dim + v_idx];
// Fused pass 1: decay state + compute kv_mem
float kv_mem = 0.0f;
#pragma unroll
for (int j = 0; j < BK; j++) {
s[j] *= decay;
kv_mem = __fmaf_rn(s[j], k_buf[j], kv_mem);
}
// Delta rule
float delta = (v_t - kv_mem) * beta_t;
// Collaboratively load q_t into shared memory
#pragma unroll
for (int j = tid; j < BK; j += BV) {
q_buf[j] = q_bh[t * BK + j];
}
__syncthreads();
// Fused pass 2: update state + compute output
float y_t = 0.0f;
#pragma unroll
for (int j = 0; j < BK; j++) {
s[j] = __fmaf_rn(k_buf[j], delta, s[j]);
y_t = __fmaf_rn(s[j], q_buf[j], y_t);
}
out_bh[t * v_dim + v_idx] = y_t;
__syncthreads();
}
// Write state back
#pragma unroll
for (int j = 0; j < BK; j++) {
state_bh[j * v_dim + v_idx] = s[j];
}
}
// Fallback kernel: runtime k_dim, still V-tiled for occupancy
template <int BV, int MAX_K>
__global__ void gated_delta_rule_recurrence_kernel_fallback(
const float *__restrict__ q, const float *__restrict__ k,
const float *__restrict__ v, const float *__restrict__ g,
const float *__restrict__ beta, float *__restrict__ state,
float *__restrict__ output, int seq_len, int k_dim, int v_dim) {
const int v_tile = blockIdx.x;
const int bh = blockIdx.y;
const int tid = threadIdx.x;
const int v_idx = v_tile * BV + tid;
if (v_idx >= v_dim)
return;
const float *q_bh = q + bh * seq_len * k_dim;
const float *k_bh = k + bh * seq_len * k_dim;
const float *v_bh = v + bh * seq_len * v_dim;
const float *g_bh = g + bh * seq_len;
const float *beta_bh = beta + bh * seq_len;
float *state_bh = state + bh * k_dim * v_dim;
float *out_bh = output + bh * seq_len * v_dim;
extern __shared__ float shared[];
float *k_buf = shared;
float *q_buf = shared + k_dim;
float s[MAX_K];
for (int j = 0; j < k_dim; j++) {
s[j] = state_bh[j * v_dim + v_idx];
}
for (int t = 0; t < seq_len; t++) {
for (int j = tid; j < k_dim; j += BV) {
k_buf[j] = k_bh[t * k_dim + j];
}
__syncthreads();
float decay = expf(g_bh[t]);
float beta_t = beta_bh[t];
float v_t = v_bh[t * v_dim + v_idx];
float kv_mem = 0.0f;
for (int j = 0; j < k_dim; j++) {
s[j] *= decay;
kv_mem = __fmaf_rn(s[j], k_buf[j], kv_mem);
}
float delta = (v_t - kv_mem) * beta_t;
for (int j = tid; j < k_dim; j += BV) {
q_buf[j] = q_bh[t * k_dim + j];
}
__syncthreads();
float y_t = 0.0f;
for (int j = 0; j < k_dim; j++) {
s[j] = __fmaf_rn(k_buf[j], delta, s[j]);
y_t = __fmaf_rn(s[j], q_buf[j], y_t);
}
out_bh[t * v_dim + v_idx] = y_t;
__syncthreads();
}
for (int j = 0; j < k_dim; j++) {
state_bh[j * v_dim + v_idx] = s[j];
}
}
extern "C" void gated_delta_rule_recurrence(const float *q, const float *k,
const float *v, const float *g,
const float *beta, float *state,
float *output, int bh, int seq_len,
int k_dim, int v_dim,
int64_t stream) {
const cudaStream_t custream = (cudaStream_t)stream;
if (k_dim == 128) {
// Fast path for Qwen3-Next (k_dim=128)
constexpr int BK = 128;
constexpr int BV = 64;
dim3 grid((v_dim + BV - 1) / BV, bh);
dim3 block(BV);
gated_delta_rule_recurrence_kernel_tiled<BK, BV>
<<<grid, block, 0, custream>>>(q, k, v, g, beta, state, output, seq_len,
v_dim);
} else if (k_dim == 64) {
// Fast path for models with k_dim=64
constexpr int BK = 64;
constexpr int BV = 64;
dim3 grid((v_dim + BV - 1) / BV, bh);
dim3 block(BV);
gated_delta_rule_recurrence_kernel_tiled<BK, BV>
<<<grid, block, 0, custream>>>(q, k, v, g, beta, state, output, seq_len,
v_dim);
} else {
// Fallback for other k_dim values (runtime loop, still V-tiled)
constexpr int BV = 64;
constexpr int MAX_K = 256;
dim3 grid((v_dim + BV - 1) / BV, bh);
dim3 block(BV);
size_t smem = 2 * k_dim * sizeof(float);
gated_delta_rule_recurrence_kernel_fallback<BV, MAX_K>
<<<grid, block, smem, custream>>>(q, k, v, g, beta, state, output,
seq_len, k_dim, v_dim);
}
}
// ============================================================================
// Kernel 1b: chunked_gated_delta_rule_recurrence (prefill optimization)
//
// Processes prefill tokens in BT-token chunks instead of one at a time.
// Within each chunk: parallel prefix sum of g, cooperative kk_dot computation,
// forward substitution (triangular solve), output computation, and state
// update.
//
// Same thread model as Kernel 1: one block per (v_tile, batch*head),
// one thread per V-column. Each thread owns BK registers of state.
//
// Shared memory holds:
// k_chunk[BT * BK] -- key vectors for current chunk
// kk_dot[BT * BT] -- dot(k[i], k[j]) lower-triangular matrix
// gcum[BT] -- cumulative sum of g within chunk
// beta_s[BT] -- beta values for chunk
// q_buf[BK] -- q vector (loaded one row at a time)
//
// q,k: [BH, S, K] v: [BH, S, V] g,beta: [BH, S]
// state: [BH, K, V] (in/out) output: [BH, S, V]
// ============================================================================
template <int BT, int BK, int BV>
__global__ void
chunked_gated_delta_rule_kernel(const float *__restrict__ q, // [BH, S, K]
const float *__restrict__ k, // [BH, S, K]
const float *__restrict__ v, // [BH, S, V]
const float *__restrict__ g, // [BH, S]
const float *__restrict__ beta, // [BH, S]
float *__restrict__ state, // [BH, K, V]
float *__restrict__ output, // [BH, S, V]
int seq_len, int v_dim) {
const int v_tile = blockIdx.x;
const int bh = blockIdx.y;
const int tid = threadIdx.x;
const int v_idx = v_tile * BV + tid;
if (v_idx >= v_dim)
return;
const int num_chunks = (seq_len + BT - 1) / BT;
// Pointers for this (batch, head)
const float *q_bh = q + bh * seq_len * BK;
const float *k_bh = k + bh * seq_len * BK;
const float *v_bh = v + bh * seq_len * v_dim;
const float *g_bh = g + bh * seq_len;
const float *beta_bh = beta + bh * seq_len;
float *state_bh = state + bh * BK * v_dim;
float *out_bh = output + bh * seq_len * v_dim;
// Dynamic shared memory layout
extern __shared__ float smem[];
float *k_chunk = smem; // [BT * BK]
float *kk_dot = smem + BT * BK; // [BT * BT]
float *gcum = smem + BT * BK + BT * BT; // [BT]
float *beta_s = gcum + BT; // [BT]
float *q_buf = beta_s + BT; // [BK]
// Load state column into registers
float s[BK];
#pragma unroll
for (int j = 0; j < BK; j++) {
s[j] = state_bh[j * v_dim + v_idx];
}
// Per-thread register array for corrected deltas
float delta[BT];
for (int c = 0; c < num_chunks; c++) {
const int chunk_start = c * BT;
const int chunk_len = min(BT, seq_len - chunk_start);
// === Phase 1: Cooperative load of k, beta, g into shared memory ===
for (int t = 0; t < chunk_len; t++) {
for (int j = tid; j < BK; j += BV) {
k_chunk[t * BK + j] = k_bh[(chunk_start + t) * BK + j];
}
}
if (tid < chunk_len) {
beta_s[tid] = beta_bh[chunk_start + tid];
gcum[tid] = g_bh[chunk_start + tid];
}
__syncthreads();
// === Phase 1b: Parallel prefix sum of g (Hillis-Steele) ===
for (int stride = 1; stride < BT; stride <<= 1) {
float prev = 0.0f;
if (tid < chunk_len && (int)tid >= stride)
prev = gcum[tid - stride];
__syncthreads();
if (tid < chunk_len && (int)tid >= stride)
gcum[tid] += prev;
__syncthreads();
}
// === Phase 2: Compute kk_dot[i][j] = dot(k[i], k[j]) for j < i ===
// Only lower-triangular entries needed (strictly lower)
for (int idx = tid; idx < chunk_len * chunk_len; idx += BV) {
int i = idx / chunk_len;
int j = idx % chunk_len;
if (j < i) {
float dot = 0.0f;
for (int d = 0; d < BK; d++) {
dot = __fmaf_rn(k_chunk[i * BK + d], k_chunk[j * BK + d], dot);
}
kk_dot[i * BT + j] = dot;
}
}
__syncthreads();
// === Phase 3: Forward substitution (per V-column, in registers) ===
// Computes corrected delta values via triangular solve
for (int i = 0; i < chunk_len; i++) {
float v_i = v_bh[(chunk_start + i) * v_dim + v_idx];
float decay_i = expf(gcum[i]);
float beta_i = beta_s[i];
// Inter-chunk contribution: state @ k[i] with decay
float kv_mem = 0.0f;
#pragma unroll
for (int d = 0; d < BK; d++) {
kv_mem = __fmaf_rn(s[d] * decay_i, k_chunk[i * BK + d], kv_mem);
}
float rhs = beta_i * (v_i - kv_mem);
// Subtract lower-triangular contributions (intra-chunk)
for (int j = 0; j < i; j++) {
float a_ij = beta_i * kk_dot[i * BT + j] * expf(gcum[i] - gcum[j]);
rhs -= a_ij * delta[j];
}
delta[i] = rhs;
}
// === Phase 4: Output computation (per V-column) ===
for (int i = 0; i < chunk_len; i++) {
// Cooperatively load q[i] into shared
for (int j = tid; j < BK; j += BV) {
q_buf[j] = q_bh[(chunk_start + i) * BK + j];
}
__syncthreads();
float decay_i = expf(gcum[i]);
// Inter-chunk contribution: q[i] @ (state * decay)
float o_val = 0.0f;
#pragma unroll
for (int d = 0; d < BK; d++) {
o_val = __fmaf_rn(q_buf[d], s[d] * decay_i, o_val);
}
// Intra-chunk contribution: sum_{j<=i} dot(q[i], k[j]) * delta[j] *
// exp(gcum[i] - gcum[j])
for (int j = 0; j <= i; j++) {
float qk_dot = 0.0f;
for (int d = 0; d < BK; d++) {
qk_dot = __fmaf_rn(q_buf[d], k_chunk[j * BK + d], qk_dot);
}
o_val += qk_dot * delta[j] * expf(gcum[i] - gcum[j]);
}
out_bh[(chunk_start + i) * v_dim + v_idx] = o_val;
__syncthreads();
}
// === Phase 5: State update for next chunk ===
float g_total = gcum[chunk_len - 1];
#pragma unroll
for (int d = 0; d < BK; d++) {
float s_new = s[d] * expf(g_total);
for (int t = 0; t < chunk_len; t++) {
s_new += k_chunk[t * BK + d] * delta[t] * expf(g_total - gcum[t]);
}
s[d] = s_new;
}
__syncthreads();
}
// Write final state back
#pragma unroll
for (int j = 0; j < BK; j++) {
state_bh[j * v_dim + v_idx] = s[j];
}
}
extern "C" void chunked_gated_delta_rule_recurrence(
const float *q, const float *k, const float *v, const float *g,
const float *beta, float *state, float *output, int bh, int seq_len,
int k_dim, int v_dim, int64_t stream) {
const cudaStream_t custream = (cudaStream_t)stream;
if (k_dim == 128) {
constexpr int BT = 64;
constexpr int BK = 128;
constexpr int BV = 64;
// Shared memory: BT*BK + BT*BT + BT + BT + BK floats
size_t smem = (BT * BK + BT * BT + 2 * BT + BK) * sizeof(float);
// Request extended shared memory
auto kernel = chunked_gated_delta_rule_kernel<BT, BK, BV>;
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
smem);
dim3 grid((v_dim + BV - 1) / BV, bh);
dim3 block(BV);
kernel<<<grid, block, smem, custream>>>(q, k, v, g, beta, state, output,
seq_len, v_dim);
} else if (k_dim == 64) {
constexpr int BT = 64;
constexpr int BK = 64;
constexpr int BV = 64;
size_t smem = (BT * BK + BT * BT + 2 * BT + BK) * sizeof(float);
auto kernel = chunked_gated_delta_rule_kernel<BT, BK, BV>;
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
smem);
dim3 grid((v_dim + BV - 1) / BV, bh);
dim3 block(BV);
kernel<<<grid, block, smem, custream>>>(q, k, v, g, beta, state, output,
seq_len, v_dim);
} else {
// Fallback: use the sequential kernel for unsupported k_dim
gated_delta_rule_recurrence(q, k, v, g, beta, state, output, bh, seq_len,
k_dim, v_dim, stream);
}
}
// ============================================================================
// Kernel 2a: causal_conv1d_update (decode path, single step)
//
// Each thread handles one channel: shift conv_state left by 1,
// insert new value, dot product with weight, apply SiLU.
//
// x: [B, conv_dim, 1] weight: [conv_dim, kernel_size]
// conv_state: [B, conv_dim, kernel_size] (in/out)
// output: [B, conv_dim, 1]
// ============================================================================
template <typename T>
__global__ void causal_conv1d_update_kernel(
const T *__restrict__ x, // [B, conv_dim, 1]
const T *__restrict__ weight, // [conv_dim, kernel_size]
T *__restrict__ conv_state, // [B, conv_dim, kernel_size]
T *__restrict__ output, // [B, conv_dim, 1]
int batch_size, int conv_dim, int kernel_size) {
const int ch = blockIdx.x * blockDim.x + threadIdx.x;
const int b = blockIdx.y;
if (ch >= conv_dim || b >= batch_size)
return;
// Pointer to this batch/channel's conv state
T *cs = conv_state + (b * conv_dim + ch) * kernel_size;
const T *w = weight + ch * kernel_size;
// Shift state left by 1
for (int i = 0; i < kernel_size - 1; i++) {
cs[i] = cs[i + 1];
}
// Insert new value
cs[kernel_size - 1] = x[b * conv_dim + ch];
// Dot product with weight
float acc = 0.0f;
for (int i = 0; i < kernel_size; i++) {
acc += (float)cs[i] * (float)w[i];
}
// SiLU activation: x * sigmoid(x)
float sig = 1.0f / (1.0f + expf(-acc));
float result = acc * sig;
output[b * conv_dim + ch] = (T)result;
}
extern "C" void causal_conv1d_update(const void *x, const void *weight,
void *conv_state, void *output,
int batch_size, int conv_dim,
int kernel_size, int dtype,
int64_t stream) {
const cudaStream_t custream = (cudaStream_t)stream;
dim3 block(256);
dim3 grid((conv_dim + 255) / 256, batch_size);
if (dtype == 0) {
// f16
causal_conv1d_update_kernel<__half><<<grid, block, 0, custream>>>(
(const __half *)x, (const __half *)weight, (__half *)conv_state,
(__half *)output, batch_size, conv_dim, kernel_size);
} else {
// bf16
causal_conv1d_update_kernel<__nv_bfloat16><<<grid, block, 0, custream>>>(
(const __nv_bfloat16 *)x, (const __nv_bfloat16 *)weight,
(__nv_bfloat16 *)conv_state, (__nv_bfloat16 *)output, batch_size,
conv_dim, kernel_size);
}
}
// ============================================================================
// Kernel 2b: causal_conv1d_full (prefill path)
//
// Each thread handles one (channel, position): causal window with
// zero-padding, dot product with weight, SiLU.
// A second pass writes the conv_state from the last kernel_size positions.
//
// x: [B, conv_dim, S] weight: [conv_dim, kernel_size]
// conv_state_out: [B, conv_dim, kernel_size] output: [B, conv_dim, S]
// ============================================================================
template <typename T>
__global__ void causal_conv1d_full_kernel(
const T *__restrict__ x, // [B, conv_dim, S]
const T *__restrict__ weight, // [conv_dim, kernel_size]
T *__restrict__ output, // [B, conv_dim, S]
int batch_size, int conv_dim, int seq_len, int kernel_size) {
const int ch = blockIdx.x * blockDim.x + threadIdx.x;
const int pos = blockIdx.y;
const int b = blockIdx.z;
if (ch >= conv_dim || pos >= seq_len || b >= batch_size)
return;
const T *x_bch = x + (b * conv_dim + ch) * seq_len;
const T *w = weight + ch * kernel_size;
// Causal convolution: sum over kernel_size window ending at pos
float acc = 0.0f;
for (int i = 0; i < kernel_size; i++) {
int src_pos = pos - (kernel_size - 1) + i;
float x_val = (src_pos >= 0) ? (float)x_bch[src_pos] : 0.0f;
acc += x_val * (float)w[i];
}
// SiLU
float sig = 1.0f / (1.0f + expf(-acc));
float result = acc * sig;
output[(b * conv_dim + ch) * seq_len + pos] = (T)result;
}
template <typename T>
__global__ void save_conv_state_kernel(
const T *__restrict__ x, // [B, conv_dim, S]
T *__restrict__ conv_state_out, // [B, conv_dim, kernel_size]
int batch_size, int conv_dim, int seq_len, int kernel_size) {
const int ch = blockIdx.x * blockDim.x + threadIdx.x;
const int b = blockIdx.y;
if (ch >= conv_dim || b >= batch_size)
return;
const T *x_bch = x + (b * conv_dim + ch) * seq_len;
T *cs = conv_state_out + (b * conv_dim + ch) * kernel_size;
// Save last kernel_size positions (zero-pad if seq_len < kernel_size)
int pad = kernel_size - seq_len;
for (int i = 0; i < kernel_size; i++) {
if (i < pad) {
cs[i] = (T)0.0f;
} else {
cs[i] = x_bch[seq_len - kernel_size + i];
}
}
}
extern "C" void causal_conv1d_full(const void *x, const void *weight,
void *conv_state_out, void *output,
int batch_size, int conv_dim, int seq_len,
int kernel_size, int dtype, int64_t stream) {
const cudaStream_t custream = (cudaStream_t)stream;
// Main convolution kernel
dim3 block(256);
dim3 grid((conv_dim + 255) / 256, seq_len, batch_size);
if (dtype == 0) {
causal_conv1d_full_kernel<__half><<<grid, block, 0, custream>>>(
(const __half *)x, (const __half *)weight, (__half *)output, batch_size,
conv_dim, seq_len, kernel_size);
// Save conv state
dim3 grid2((conv_dim + 255) / 256, batch_size);
save_conv_state_kernel<__half><<<grid2, block, 0, custream>>>(
(const __half *)x, (__half *)conv_state_out, batch_size, conv_dim,
seq_len, kernel_size);
} else {
causal_conv1d_full_kernel<__nv_bfloat16><<<grid, block, 0, custream>>>(
(const __nv_bfloat16 *)x, (const __nv_bfloat16 *)weight,
(__nv_bfloat16 *)output, batch_size, conv_dim, seq_len, kernel_size);
dim3 grid2((conv_dim + 255) / 256, batch_size);
save_conv_state_kernel<__nv_bfloat16><<<grid2, block, 0, custream>>>(
(const __nv_bfloat16 *)x, (__nv_bfloat16 *)conv_state_out, batch_size,
conv_dim, seq_len, kernel_size);
}
}
// ============================================================================
// Kernel 3: fused_gdn_gating
//
// Fuses: beta = sigmoid(b), g = -exp(a_log) * softplus(a + dt_bias)
// a_log and dt_bias are per-head (broadcast over batch*seq).
//
// b, a: [total] a_log, dt_bias: [num_heads]
// beta_out, g_out: [total]
// ============================================================================
template <typename T>
__global__ void
fused_gdn_gating_kernel(const T *__restrict__ b, // [total]
const T *__restrict__ a, // [total]
const float *__restrict__ a_log, // [num_heads]
const float *__restrict__ dt_bias, // [num_heads]
T *__restrict__ beta_out, // [total]
T *__restrict__ g_out, // [total]
int total_elements, int num_heads) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= total_elements)
return;
// Head index: elements are laid out as [..., num_heads]
int head_idx = idx % num_heads;
// beta = sigmoid(b)
float b_val = (float)b[idx];
float beta = 1.0f / (1.0f + expf(-b_val));
// g = -exp(a_log) * softplus(a + dt_bias)
float a_val = (float)a[idx];
float a_log_val = a_log[head_idx];
float dt_bias_val = dt_bias[head_idx];
float sp_input = a_val + dt_bias_val;
float softplus_val = logf(1.0f + expf(sp_input));
float g_val = -expf(a_log_val) * softplus_val;
beta_out[idx] = (T)beta;
g_out[idx] = (T)g_val;
}
extern "C" void fused_gdn_gating(const void *b, const void *a,
const float *a_log, const float *dt_bias,
void *beta_out, void *g_out,
int total_elements, int num_heads, int dtype,
int64_t stream) {
const cudaStream_t custream = (cudaStream_t)stream;
dim3 block(256);
dim3 grid((total_elements + 255) / 256);
if (dtype == 0) {
fused_gdn_gating_kernel<__half><<<grid, block, 0, custream>>>(
(const __half *)b, (const __half *)a, a_log, dt_bias,
(__half *)beta_out, (__half *)g_out, total_elements, num_heads);
} else {
fused_gdn_gating_kernel<__nv_bfloat16><<<grid, block, 0, custream>>>(
(const __nv_bfloat16 *)b, (const __nv_bfloat16 *)a, a_log, dt_bias,
(__nv_bfloat16 *)beta_out, (__nv_bfloat16 *)g_out, total_elements,
num_heads);
}
}

View File

@@ -1,486 +0,0 @@
//! Rust wrappers around the Gated DeltaNet CUDA kernels in `gdn.cu`.
//!
//! Ported verbatim from `EricLBuehler/mistral.rs` under MIT terms.
//! Upstream path: `mistralrs-core/src/cuda/gdn.rs`. The only edits in
//! this file are this header comment — the FFI path module name is
//! `crate::cuda::ffi`, identical to upstream's layout.
#![allow(clippy::cast_possible_truncation)]
use candle_core::{Result, Tensor};
#[cfg(feature = "cuda")]
use candle_core::DType;
/// CUDA-accelerated gated delta rule recurrence.
///
/// Inputs (all contiguous, f32):
/// q, k: [BH, S, K] v: [BH, S, V] g, beta: [BH, S]
/// state: [BH, K, V] (mutated in place)
///
/// Returns: output [BH, S, V]
#[cfg(feature = "cuda")]
pub fn gated_delta_rule_recurrence_cuda(
q: &Tensor,
k: &Tensor,
v: &Tensor,
g: &Tensor,
beta: &Tensor,
state: &mut Tensor,
) -> Result<Tensor> {
use candle::cuda_backend::cudarc::driver::DevicePtr;
use candle_core as candle;
let (bh, seq_len, k_dim) = q.dims3()?;
let v_dim = v.dim(2)?;
let dev = q.device().as_cuda_device()?;
let (q_s, q_l) = q.storage_and_layout();
let q_s = match &*q_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("q must be a cuda tensor"),
};
let q_offset = q_l.start_offset();
let (k_s, k_l) = k.storage_and_layout();
let k_s = match &*k_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("k must be a cuda tensor"),
};
let k_offset = k_l.start_offset();
let (v_s, v_l) = v.storage_and_layout();
let v_s = match &*v_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("v must be a cuda tensor"),
};
let v_offset = v_l.start_offset();
let (g_s, g_l) = g.storage_and_layout();
let g_s = match &*g_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("g must be a cuda tensor"),
};
let g_offset = g_l.start_offset();
let (beta_s, beta_l) = beta.storage_and_layout();
let beta_s = match &*beta_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("beta must be a cuda tensor"),
};
let beta_offset = beta_l.start_offset();
let (state_s, state_l) = state.storage_and_layout();
let state_s = match &*state_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("state must be a cuda tensor"),
};
let state_offset = state_l.start_offset();
let output_buf = unsafe { dev.alloc::<f32>(bh * seq_len * v_dim) }?;
let stream = dev.cuda_stream().cu_stream() as i64;
unsafe {
crate::cuda::ffi::gated_delta_rule_recurrence(
q_s.slice(q_offset..).device_ptr(q_s.stream()).0 as *const f32,
k_s.slice(k_offset..).device_ptr(k_s.stream()).0 as *const f32,
v_s.slice(v_offset..).device_ptr(v_s.stream()).0 as *const f32,
g_s.slice(g_offset..).device_ptr(g_s.stream()).0 as *const f32,
beta_s.slice(beta_offset..).device_ptr(beta_s.stream()).0 as *const f32,
state_s.slice(state_offset..).device_ptr(state_s.stream()).0 as *mut f32,
output_buf.device_ptr(output_buf.stream()).0 as *mut f32,
bh as i32,
seq_len as i32,
k_dim as i32,
v_dim as i32,
stream,
);
}
// The kernel wrote state in-place via the raw pointer; rewrap
// (state tensor's underlying CudaSlice was modified directly)
let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone());
Ok(Tensor::from((
candle::Storage::Cuda(output_storage),
(bh, seq_len, v_dim),
)))
}
#[cfg(not(feature = "cuda"))]
#[allow(unused)]
pub fn gated_delta_rule_recurrence_cuda(
_q: &Tensor,
_k: &Tensor,
_v: &Tensor,
_g: &Tensor,
_beta: &Tensor,
_state: &mut Tensor,
) -> Result<Tensor> {
candle_core::bail!("gated_delta_rule_recurrence_cuda requires the cuda feature")
}
/// CUDA-accelerated chunked gated delta rule recurrence (prefill optimization).
///
/// Processes prefill tokens in 64-token chunks instead of one at a time.
/// Same interface as `gated_delta_rule_recurrence_cuda`.
///
/// Inputs (all contiguous, f32):
/// q, k: [BH, S, K] v: [BH, S, V] g, beta: [BH, S]
/// state: [BH, K, V] (mutated in place)
///
/// Returns: output [BH, S, V]
#[cfg(feature = "cuda")]
pub fn chunked_gated_delta_rule_recurrence_cuda(
q: &Tensor,
k: &Tensor,
v: &Tensor,
g: &Tensor,
beta: &Tensor,
state: &mut Tensor,
) -> Result<Tensor> {
use candle::cuda_backend::cudarc::driver::DevicePtr;
use candle_core as candle;
let (bh, seq_len, k_dim) = q.dims3()?;
let v_dim = v.dim(2)?;
let dev = q.device().as_cuda_device()?;
let (q_s, q_l) = q.storage_and_layout();
let q_s = match &*q_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("q must be a cuda tensor"),
};
let q_offset = q_l.start_offset();
let (k_s, k_l) = k.storage_and_layout();
let k_s = match &*k_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("k must be a cuda tensor"),
};
let k_offset = k_l.start_offset();
let (v_s, v_l) = v.storage_and_layout();
let v_s = match &*v_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("v must be a cuda tensor"),
};
let v_offset = v_l.start_offset();
let (g_s, g_l) = g.storage_and_layout();
let g_s = match &*g_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("g must be a cuda tensor"),
};
let g_offset = g_l.start_offset();
let (beta_s, beta_l) = beta.storage_and_layout();
let beta_s = match &*beta_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("beta must be a cuda tensor"),
};
let beta_offset = beta_l.start_offset();
let (state_s, state_l) = state.storage_and_layout();
let state_s = match &*state_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("state must be a cuda tensor"),
};
let state_offset = state_l.start_offset();
let output_buf = unsafe { dev.alloc::<f32>(bh * seq_len * v_dim) }?;
let stream = dev.cuda_stream().cu_stream() as i64;
unsafe {
crate::cuda::ffi::chunked_gated_delta_rule_recurrence(
q_s.slice(q_offset..).device_ptr(q_s.stream()).0 as *const f32,
k_s.slice(k_offset..).device_ptr(k_s.stream()).0 as *const f32,
v_s.slice(v_offset..).device_ptr(v_s.stream()).0 as *const f32,
g_s.slice(g_offset..).device_ptr(g_s.stream()).0 as *const f32,
beta_s.slice(beta_offset..).device_ptr(beta_s.stream()).0 as *const f32,
state_s.slice(state_offset..).device_ptr(state_s.stream()).0 as *mut f32,
output_buf.device_ptr(output_buf.stream()).0 as *mut f32,
bh as i32,
seq_len as i32,
k_dim as i32,
v_dim as i32,
stream,
);
}
let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone());
Ok(Tensor::from((
candle::Storage::Cuda(output_storage),
(bh, seq_len, v_dim),
)))
}
#[cfg(not(feature = "cuda"))]
#[allow(unused)]
pub fn chunked_gated_delta_rule_recurrence_cuda(
_q: &Tensor,
_k: &Tensor,
_v: &Tensor,
_g: &Tensor,
_beta: &Tensor,
_state: &mut Tensor,
) -> Result<Tensor> {
candle_core::bail!("chunked_gated_delta_rule_recurrence_cuda requires the cuda feature")
}
/// CUDA-accelerated causal conv1d (both update and full paths).
///
/// For update (is_update=true):
/// x: [B, conv_dim, 1] weight: [conv_dim, kernel_size]
/// conv_state: [B, conv_dim, kernel_size] (mutated in place for update)
/// Returns: (output [B, conv_dim, 1], updated conv_state)
///
/// For full (is_update=false):
/// x: [B, conv_dim, S] weight: [conv_dim, kernel_size]
/// Returns: (output [B, conv_dim, S], new conv_state [B, conv_dim, kernel_size])
#[cfg(feature = "cuda")]
pub fn causal_conv1d_cuda(
x: &Tensor,
weight: &Tensor,
conv_state: &Tensor,
kernel_size: usize,
is_update: bool,
) -> Result<(Tensor, Tensor)> {
use candle::cuda_backend::cudarc::driver::DevicePtr;
use candle_core as candle;
use core::ffi::c_void;
fn cuda_fwd<
T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
>(
x: &Tensor,
weight: &Tensor,
conv_state: &Tensor,
kernel_size: usize,
is_update: bool,
dtype_code: i32,
) -> Result<(Tensor, Tensor)> {
let dev = x.device().as_cuda_device()?;
let (batch_size, conv_dim, seq_len) = x.dims3()?;
let (x_s, x_l) = x.storage_and_layout();
let x_s = match &*x_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
_ => candle::bail!("x must be a cuda tensor"),
};
let x_offset = x_l.start_offset();
let (w_s, w_l) = weight.storage_and_layout();
let w_s = match &*w_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
_ => candle::bail!("weight must be a cuda tensor"),
};
let w_offset = w_l.start_offset();
let stream = dev.cuda_stream().cu_stream() as i64;
if is_update {
// Clone conv_state so the kernel can mutate it in place
let conv_state_new = conv_state.clone();
let output_buf = unsafe { dev.alloc::<T>(batch_size * conv_dim) }?;
// Scope the borrow of conv_state_new so we can move it later
{
let (cs_s, cs_l) = conv_state_new.storage_and_layout();
let cs_s = match &*cs_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
_ => candle::bail!("conv_state must be a cuda tensor"),
};
let cs_offset = cs_l.start_offset();
unsafe {
crate::cuda::ffi::causal_conv1d_update(
x_s.slice(x_offset..).device_ptr(x_s.stream()).0 as *const c_void,
w_s.slice(w_offset..).device_ptr(w_s.stream()).0 as *const c_void,
cs_s.slice(cs_offset..).device_ptr(cs_s.stream()).0 as *mut c_void,
output_buf.device_ptr(output_buf.stream()).0 as *mut c_void,
batch_size as i32,
conv_dim as i32,
kernel_size as i32,
dtype_code,
stream,
);
}
}
let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone());
let output = Tensor::from((
candle::Storage::Cuda(output_storage),
(batch_size, conv_dim, 1usize),
));
Ok((output, conv_state_new))
} else {
// Full path: allocate new conv_state and output
let output_buf = unsafe { dev.alloc::<T>(batch_size * conv_dim * seq_len) }?;
let cs_buf = unsafe { dev.alloc::<T>(batch_size * conv_dim * kernel_size) }?;
unsafe {
crate::cuda::ffi::causal_conv1d_full(
x_s.slice(x_offset..).device_ptr(x_s.stream()).0 as *const c_void,
w_s.slice(w_offset..).device_ptr(w_s.stream()).0 as *const c_void,
cs_buf.device_ptr(cs_buf.stream()).0 as *mut c_void,
output_buf.device_ptr(output_buf.stream()).0 as *mut c_void,
batch_size as i32,
conv_dim as i32,
seq_len as i32,
kernel_size as i32,
dtype_code,
stream,
);
}
let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone());
let output = Tensor::from((
candle::Storage::Cuda(output_storage),
(batch_size, conv_dim, seq_len),
));
let cs_storage = candle::CudaStorage::wrap_cuda_slice(cs_buf, dev.clone());
let new_conv_state = Tensor::from((
candle::Storage::Cuda(cs_storage),
(batch_size, conv_dim, kernel_size),
));
Ok((output, new_conv_state))
}
}
match x.dtype() {
DType::F16 => cuda_fwd::<half::f16>(x, weight, conv_state, kernel_size, is_update, 0),
DType::BF16 => cuda_fwd::<half::bf16>(x, weight, conv_state, kernel_size, is_update, 1),
other => candle_core::bail!("causal_conv1d_cuda only supports f16/bf16, got {:?}", other),
}
}
#[cfg(not(feature = "cuda"))]
#[allow(unused)]
pub fn causal_conv1d_cuda(
_x: &Tensor,
_weight: &Tensor,
_conv_state: &Tensor,
_kernel_size: usize,
_is_update: bool,
) -> Result<(Tensor, Tensor)> {
candle_core::bail!("causal_conv1d_cuda requires the cuda feature")
}
/// CUDA-accelerated fused GDN gating computation.
///
/// Computes: beta = sigmoid(b), g = -exp(a_log) * softplus(a + dt_bias)
///
/// b, a: [total_elements] in f16/bf16
/// a_log, dt_bias: [num_heads] in f32
///
/// Returns: (beta, g) in original dtype
#[cfg(feature = "cuda")]
pub fn fused_gdn_gating_cuda(
b: &Tensor,
a: &Tensor,
a_log: &Tensor,
dt_bias: &Tensor,
) -> Result<(Tensor, Tensor)> {
use candle::cuda_backend::cudarc::driver::DevicePtr;
use candle_core as candle;
use core::ffi::c_void;
fn cuda_fwd<
T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
>(
b: &Tensor,
a: &Tensor,
a_log: &Tensor,
dt_bias: &Tensor,
dtype_code: i32,
) -> Result<(Tensor, Tensor)> {
let total_elements = b.elem_count();
let num_heads = a_log.elem_count();
let shape = b.shape().clone();
let dev = b.device().as_cuda_device()?;
let (b_s, b_l) = b.storage_and_layout();
let b_s = match &*b_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
_ => candle::bail!("b must be a cuda tensor"),
};
let b_offset = b_l.start_offset();
let (a_s, a_l) = a.storage_and_layout();
let a_s = match &*a_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
_ => candle::bail!("a must be a cuda tensor"),
};
let a_offset = a_l.start_offset();
let (alog_s, alog_l) = a_log.storage_and_layout();
let alog_s = match &*alog_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("a_log must be a cuda tensor"),
};
let alog_offset = alog_l.start_offset();
let (dtb_s, dtb_l) = dt_bias.storage_and_layout();
let dtb_s = match &*dtb_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("dt_bias must be a cuda tensor"),
};
let dtb_offset = dtb_l.start_offset();
let beta_buf = unsafe { dev.alloc::<T>(total_elements) }?;
let g_buf = unsafe { dev.alloc::<T>(total_elements) }?;
let stream = dev.cuda_stream().cu_stream() as i64;
unsafe {
crate::cuda::ffi::fused_gdn_gating(
b_s.slice(b_offset..).device_ptr(b_s.stream()).0 as *const c_void,
a_s.slice(a_offset..).device_ptr(a_s.stream()).0 as *const c_void,
alog_s.slice(alog_offset..).device_ptr(alog_s.stream()).0 as *const f32,
dtb_s.slice(dtb_offset..).device_ptr(dtb_s.stream()).0 as *const f32,
beta_buf.device_ptr(beta_buf.stream()).0 as *mut c_void,
g_buf.device_ptr(g_buf.stream()).0 as *mut c_void,
total_elements as i32,
num_heads as i32,
dtype_code,
stream,
);
}
let beta_storage = candle::CudaStorage::wrap_cuda_slice(beta_buf, dev.clone());
let beta = Tensor::from((candle::Storage::Cuda(beta_storage), shape.clone()));
let g_storage = candle::CudaStorage::wrap_cuda_slice(g_buf, dev.clone());
let g = Tensor::from((candle::Storage::Cuda(g_storage), shape));
Ok((beta, g))
}
match b.dtype() {
DType::F16 => cuda_fwd::<half::f16>(b, a, a_log, dt_bias, 0),
DType::BF16 => cuda_fwd::<half::bf16>(b, a, a_log, dt_bias, 1),
other => candle_core::bail!(
"fused_gdn_gating_cuda only supports f16/bf16, got {:?}",
other
),
}
}
#[cfg(not(feature = "cuda"))]
#[allow(unused)]
pub fn fused_gdn_gating_cuda(
_b: &Tensor,
_a: &Tensor,
_a_log: &Tensor,
_dt_bias: &Tensor,
) -> Result<(Tensor, Tensor)> {
candle_core::bail!("fused_gdn_gating_cuda requires the cuda feature")
}

View File

@@ -1,15 +0,0 @@
//! CUDA kernels and their Rust wrappers.
//!
//! Currently scoped to what we need for Qwen3-Next (`qwen3_5`)
//! inference performance — the Gated DeltaNet kernels ported from
//! `EricLBuehler/mistral.rs` (MIT). Each kernel lives in a `.cu`
//! file alongside this module; `build.rs` compiles them all into a
//! static lib via `cudaforge` and links it under the `cuda` feature.
//!
//! When we absorb more upstream kernels (MoE GEMM, top-k, Mamba SSM,
//! etc.) they land here in their own `.cu` + `.rs` pairs.
#[cfg(feature = "cuda")]
pub mod ffi;
#[cfg(feature = "cuda")]
pub mod gdn;

View File

@@ -1,23 +0,0 @@
//! Custom architecture implementations.
//!
//! When candle-transformers ships a model family unchanged
//! (`models::llama`, `models::qwen3`, `models::qwen3_moe`, etc.), the
//! handler in `harness/candle.rs` just wraps the upstream type in a
//! `ModelArch` variant.
//!
//! When candle has nothing for the architecture and we have to write
//! it from scratch — Qwen3-Next / Qwen3.6 (`qwen3_5`) being the
//! motivating example — the implementation lands here, one file per
//! architecture.
//!
//! Each architecture module is expected to expose:
//! - A `Config` type deserialised from the model's `config.json`
//! (some architectures nest the real hyperparams under `text_config`,
//! in which case the module owns the unwrapping).
//! - A `ForCausalLM` struct with `new`, `forward(&mut self, x, offset)
//! -> Result<Tensor>`, and `clear_kv_cache(&mut self)`.
//!
//! TP-aware analogues live in `harness/tp/tp_<family>.rs` and follow
//! the pattern set by `tp_qwen3.rs`.
pub mod qwen3_5;

View File

@@ -1,117 +0,0 @@
//! Qwen3-Next decoder layer.
//!
//! Standard pre-norm transformer block (LN → attention → residual →
//! LN → MLP → residual) where the attention slot dispatches on the
//! per-layer `layer_types[i]` value in the config:
//!
//! - `"full_attention"` → [`Qwen3_5Attention`] (GQA causal + output
//! gate + RoPE + KV cache).
//! - `"linear_attention"` → [`GatedDeltaNet`] (recurrent delta rule +
//! causal conv + per-head state).
//!
//! In Qwen3.6-27B every 4th layer is full_attention; the rest are
//! linear_attention. `full_attention_interval` in the config is a
//! hint; `layer_types` is authoritative.
use anyhow::Result;
use candle_core::{Module, Tensor};
use candle_nn::var_builder::ShardedVarBuilder;
use std::sync::Arc;
use super::TextConfig;
use super::full_attn::Qwen3_5Attention;
use super::linear_attn::GatedDeltaNet;
use super::mlp::Qwen3_5MLP;
use super::rmsnorm::Qwen3_5RmsNorm;
use super::rope::RotaryEmbedding;
/// One of the two attention flavours sitting in a decoder layer's
/// attention slot. Full-attention layers need the rotary table and
/// take an attention mask; linear-attention layers carry their own
/// recurrent state and ignore the mask.
enum AttentionKind {
Full(Qwen3_5Attention),
Linear(GatedDeltaNet),
}
pub struct Qwen3_5DecoderLayer {
input_layernorm: Qwen3_5RmsNorm,
post_attention_layernorm: Qwen3_5RmsNorm,
mlp: Qwen3_5MLP,
attention: AttentionKind,
}
impl Qwen3_5DecoderLayer {
pub fn load(
cfg: &TextConfig,
rotary: Arc<RotaryEmbedding>,
layer_idx: usize,
vb: &ShardedVarBuilder,
) -> Result<Self> {
let layer_type = cfg
.layer_types
.get(layer_idx)
.map(String::as_str)
.ok_or_else(|| {
anyhow::anyhow!(
"layer_types[{layer_idx}] missing (have {} entries)",
cfg.layer_types.len()
)
})?;
let attention = match layer_type {
"full_attention" => {
AttentionKind::Full(Qwen3_5Attention::load(cfg, rotary, &vb.pp("self_attn"))?)
}
"linear_attention" => {
AttentionKind::Linear(GatedDeltaNet::load(cfg, &vb.pp("linear_attn"))?)
}
other => anyhow::bail!(
"unknown layer_type '{other}' for layer {layer_idx} (expected \
'full_attention' or 'linear_attention')"
),
};
let mlp = Qwen3_5MLP::load(cfg, &vb.pp("mlp"))?;
let input_layernorm =
Qwen3_5RmsNorm::load(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?;
let post_attention_layernorm = Qwen3_5RmsNorm::load(
&vb.pp("post_attention_layernorm"),
cfg.hidden_size,
cfg.rms_norm_eps,
)?;
Ok(Self {
input_layernorm,
post_attention_layernorm,
mlp,
attention,
})
}
pub fn forward(
&mut self,
x: &Tensor,
attn_mask: Option<&Tensor>,
offset: usize,
) -> candle_core::Result<Tensor> {
let h = self.input_layernorm.forward(x)?;
let attn_out = match &mut self.attention {
AttentionKind::Full(attn) => attn.forward(&h, attn_mask, offset)?,
// Linear attention ignores attn_mask + offset; its causal
// structure is baked into the recurrent state lifecycle.
AttentionKind::Linear(net) => net.forward(&h)?,
};
let x = (x + attn_out)?;
let h2 = self.post_attention_layernorm.forward(&x)?;
let h2 = self.mlp.forward(&h2)?;
x + h2
}
pub fn clear_kv_cache(&mut self) {
match &mut self.attention {
AttentionKind::Full(attn) => attn.clear_kv_cache(),
AttentionKind::Linear(net) => net.clear_kv_cache(),
}
}
}

View File

@@ -1,179 +0,0 @@
//! Qwen3-Next's `full_attention` layer.
//!
//! Standard GQA causal attention with two Qwen3-Next-specific quirks:
//!
//! 1. **Output gate (`attn_output_gate=True`).** `q_proj` is widened
//! to `num_heads * head_dim * 2`. The second half is reshaped to
//! `(B, L, num_heads * head_dim)` and fed through a sigmoid; the
//! attention output is pointwise-multiplied by this gate before
//! `o_proj`. Effectively a per-head per-position attenuation on
//! the attention output.
//!
//! 2. **`(1 + w) * x` RmsNorm** on q and k (see `rmsnorm::Qwen3_5RmsNorm`).
//! candle_nn's RmsNorm applies `w * x`; the upstream Qwen3-Next
//! checkpoints expect the `(1 + w)` form.
//!
//! Otherwise: GQA with `num_attention_heads / num_key_value_heads`
//! repeat, q_norm + k_norm on the head dim, GLM-style rotary (see
//! `rope::RotaryEmbedding`), and the usual causal mask.
use anyhow::{Context, Result};
use candle_core::{Module, Tensor};
use candle_nn::Linear;
use candle_nn::kv_cache::ConcatKvCache;
use candle_nn::var_builder::ShardedVarBuilder;
use candle_transformers::utils::repeat_kv;
use std::sync::Arc;
use super::TextConfig;
use super::rmsnorm::Qwen3_5RmsNorm;
use super::rope::RotaryEmbedding;
pub struct Qwen3_5Attention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
q_norm: Qwen3_5RmsNorm,
k_norm: Qwen3_5RmsNorm,
num_heads: usize,
num_kv_heads: usize,
num_kv_groups: usize,
head_dim: usize,
hidden_size: usize,
rotary: Arc<RotaryEmbedding>,
kv_cache: ConcatKvCache,
}
impl Qwen3_5Attention {
pub fn load(
cfg: &TextConfig,
rotary: Arc<RotaryEmbedding>,
vb: &ShardedVarBuilder,
) -> Result<Self> {
let head_dim = cfg.head_dim;
let num_heads = cfg.num_attention_heads;
let num_kv_heads = cfg.num_key_value_heads;
if num_kv_heads == 0 || !num_heads.is_multiple_of(num_kv_heads) {
anyhow::bail!(
"num_attention_heads ({num_heads}) must be a positive multiple of \
num_key_value_heads ({num_kv_heads})"
);
}
let num_kv_groups = num_heads / num_kv_heads;
// q_proj is 2x wide: the extra `num_heads * head_dim` slice is
// the gate (see attn_output_gate notes above).
let q_proj = load_linear_no_bias(vb, "q_proj", cfg.hidden_size, num_heads * head_dim * 2)?;
let k_proj = load_linear_no_bias(vb, "k_proj", cfg.hidden_size, num_kv_heads * head_dim)?;
let v_proj = load_linear_no_bias(vb, "v_proj", cfg.hidden_size, num_kv_heads * head_dim)?;
let o_proj = load_linear_no_bias(vb, "o_proj", num_heads * head_dim, cfg.hidden_size)?;
let q_norm = Qwen3_5RmsNorm::load(&vb.pp("q_norm"), head_dim, cfg.rms_norm_eps)?;
let k_norm = Qwen3_5RmsNorm::load(&vb.pp("k_norm"), head_dim, cfg.rms_norm_eps)?;
let hidden_size = head_dim * num_heads;
let kv_cache = ConcatKvCache::new(2);
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
q_norm,
k_norm,
num_heads,
num_kv_heads,
num_kv_groups,
head_dim,
hidden_size,
rotary,
kv_cache,
})
}
pub fn forward(
&mut self,
x: &Tensor,
attn_mask: Option<&Tensor>,
offset: usize,
) -> candle_core::Result<Tensor> {
let (b, l, _) = x.dims3()?;
// 1. q_proj — widened output, split into (query, gate).
let q_raw = self
.q_proj
.forward(x)?
.reshape((b, l, self.num_heads, self.head_dim * 2))?;
let q = q_raw.narrow(3, 0, self.head_dim)?;
let gate = q_raw.narrow(3, self.head_dim, self.head_dim)?;
// Flatten the gate's head dim back into hidden_size for the
// post-attention pointwise multiply.
let gate = gate
.contiguous()?
.reshape((b, l, self.num_heads * self.head_dim))?;
// 2. q_norm + k_norm + reshape to (B, H, L, D).
let q = self.q_norm.forward(&q.contiguous()?)?;
let q = q.transpose(1, 2)?.contiguous()?; // (B, H, L, D)
let k = self
.k_proj
.forward(x)?
.reshape((b, l, self.num_kv_heads, self.head_dim))?;
let k = self.k_norm.forward(&k.contiguous()?)?;
let k = k.transpose(1, 2)?.contiguous()?;
let v = self
.v_proj
.forward(x)?
.reshape((b, l, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
// 3. RoPE on q, k.
let (q, k) = self.rotary.apply(&q, &k, offset)?;
// 4. KV cache.
let (k, v) = self.kv_cache.append(&k, &v)?;
// 5. GQA repeat (cheap shape op).
let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;
// 6. Scaled dot-product + causal mask.
let scale = 1.0_f64 / (self.head_dim as f64).sqrt();
let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
if let Some(m) = attn_mask {
scores = scores.broadcast_add(m)?;
}
let probs = candle_nn::ops::softmax_last_dim(&scores)?;
let ctx = probs.matmul(&v)?; // (B, H, L, D)
// 7. Reshape back, apply the output gate, project.
let ctx = ctx
.transpose(1, 2)?
.contiguous()?
.reshape((b, l, self.hidden_size))?;
let gate_sig = candle_nn::ops::sigmoid(&gate)?;
let gated = (ctx * gate_sig)?;
self.o_proj.forward(&gated)
}
pub fn clear_kv_cache(&mut self) {
self.kv_cache.reset();
}
}
fn load_linear_no_bias(
vb: &ShardedVarBuilder,
name: &str,
in_dim: usize,
out_dim: usize,
) -> Result<Linear> {
let weight = vb
.pp(name)
.get((out_dim, in_dim), "weight")
.with_context(|| format!("load '{}/{name}/weight'", vb.prefix()))?;
Ok(Linear::new(weight, None))
}

View File

@@ -1,793 +0,0 @@
//! Qwen3-Next's `linear_attention` layer: Gated DeltaNet.
//!
//! The recurrent linear-attention block that occupies 3 out of every 4
//! decoder layers in Qwen3.6 (`layer_types[i] == "linear_attention"`).
//! Implemented against the reference Python in
//! `huggingface/transformers/src/transformers/models/qwen3_5/modeling_qwen3_5.py`
//! (class `Qwen3_5GatedDeltaNet`).
//!
//! ## Block structure
//!
//! ```text
//! x ── in_proj_qkv ── transpose ─► (B, conv_dim, L)
//! │
//! ┌──────────────── conv_state ──┤ prepend cached state (decode)
//! ▼
//! depthwise causal Conv1d (k=4) → SiLU
//! │
//! └─ split → q (k_dim), k (k_dim), v (v_dim) ─► per-head reshape
//!
//! x ── in_proj_z ────────────────► z (gate for the output RMSNorm)
//! x ── in_proj_b ── sigmoid ─────► beta (per-head per-token update rate)
//! x ── in_proj_a ── softplus ────► g (decay; see eqn below)
//!
//! g = -exp(A_log) * softplus(a + dt_bias) # discretisation
//! beta = sigmoid(b)
//!
//! (q, k) ─── L2norm ─── delta rule loop ──── core_attn_out
//! (per-token, per-head):
//! state *= exp(g_t)
//! mem = state^T · k_t
//! delta = (v_t - mem) * beta_t
//! state += outer(k_t, delta)
//! out_t = state^T · q_t
//!
//! core_attn_out ── RMSNormGated(z) ── reshape ── out_proj ── y
//! ```
//!
//! ## State
//!
//! Two tensors persist across decode steps:
//! - `conv_state`: `(B, conv_dim, conv_kernel_size)` — left-padded
//! tail of the input to the depthwise conv, so the next causal
//! window has the right left-context.
//! - `recurrent_state`: `(B, num_v_heads, head_k_dim, head_v_dim)` —
//! the delta-rule outer-product memory.
//!
//! Both are cleared via [`GatedDeltaNet::clear_kv_cache`] at the start
//! of every new request.
//!
//! ## Performance note
//!
//! This impl is the **recurrent** delta-rule for both prefill and
//! decode — i.e. the algorithm in `torch_recurrent_gated_delta_rule`.
//! Correctness-first. The chunked algorithm (chunk_size=64) in
//! `torch_chunk_gated_delta_rule` is a perf optimisation for long
//! prefill; can be added later without changing the surface.
use anyhow::{Context, Result};
use candle_core::{Module, Tensor};
use candle_nn::Linear;
use candle_nn::var_builder::ShardedVarBuilder;
#[cfg(test)]
use super::RopeParameters;
use super::TextConfig;
use super::rmsnorm::{Qwen3_5RmsNormGated, l2norm};
/// Per-rank, per-layer state for the linear-attention block.
///
/// `conv_state` is left-padded with zeros on first use; `recurrent_state`
/// is initialised lazily to zeros once we know the batch size.
#[derive(Default)]
pub struct GatedDeltaNetState {
pub conv_state: Option<Tensor>,
pub recurrent_state: Option<Tensor>,
}
pub struct GatedDeltaNet {
// Projections.
in_proj_qkv: Linear,
in_proj_z: Linear,
in_proj_b: Linear,
in_proj_a: Linear,
out_proj: Linear,
// Depthwise causal Conv1d weight; shape (conv_dim, 1, kernel_size).
// No bias (Python sets bias=False).
conv1d_weight: Tensor,
// Per-head discretisation params.
dt_bias: Tensor,
a_log: Tensor,
// Output norm + gate.
norm: Qwen3_5RmsNormGated,
// Shape hyperparams (cached for forward).
num_v_heads: usize,
num_k_heads: usize,
head_k_dim: usize,
head_v_dim: usize,
key_dim: usize,
value_dim: usize,
conv_dim: usize,
conv_kernel_size: usize,
// Recurrent state held inline. Each request resets via
// `clear_kv_cache`; otherwise the state persists across forwards
// and the per-token offset advances naturally.
state: GatedDeltaNetState,
}
impl GatedDeltaNet {
pub fn load(cfg: &TextConfig, vb: &ShardedVarBuilder) -> Result<Self> {
let num_v_heads = cfg.linear_num_value_heads;
let num_k_heads = cfg.linear_num_key_heads;
let head_k_dim = cfg.linear_key_head_dim;
let head_v_dim = cfg.linear_value_head_dim;
let conv_kernel_size = cfg.linear_conv_kernel_dim;
if num_v_heads == 0 || num_k_heads == 0 {
anyhow::bail!(
"Qwen3-Next linear_num_*_heads must be set; got v={num_v_heads}, k={num_k_heads}"
);
}
if !num_v_heads.is_multiple_of(num_k_heads) {
anyhow::bail!(
"linear_num_value_heads ({num_v_heads}) must be a multiple of \
linear_num_key_heads ({num_k_heads}) for GQA-style head expansion"
);
}
let key_dim = head_k_dim * num_k_heads;
let value_dim = head_v_dim * num_v_heads;
let conv_dim = key_dim * 2 + value_dim;
// ----- Linear projections (all `bias=False` in the reference). -----
let in_proj_qkv = load_linear_no_bias(vb, "in_proj_qkv", cfg.hidden_size, conv_dim)?;
let in_proj_z = load_linear_no_bias(vb, "in_proj_z", cfg.hidden_size, value_dim)?;
let in_proj_b = load_linear_no_bias(vb, "in_proj_b", cfg.hidden_size, num_v_heads)?;
let in_proj_a = load_linear_no_bias(vb, "in_proj_a", cfg.hidden_size, num_v_heads)?;
let out_proj = load_linear_no_bias(vb, "out_proj", value_dim, cfg.hidden_size)?;
// ----- Conv1d weight (depthwise, bias=False). -----
let conv1d_weight = vb
.pp("conv1d")
.get((conv_dim, 1, conv_kernel_size), "weight")
.with_context(|| format!("load '{}/conv1d/weight'", vb.prefix()))?;
// ----- dt_bias + A_log: per-head 1D params. -----
let dt_bias = vb
.get(num_v_heads, "dt_bias")
.with_context(|| format!("load '{}/dt_bias'", vb.prefix()))?;
let a_log = vb
.get(num_v_heads, "A_log")
.with_context(|| format!("load '{}/A_log'", vb.prefix()))?;
// ----- Output gated RMSNorm (per-head_v_dim). -----
let norm = Qwen3_5RmsNormGated::load(&vb.pp("norm"), head_v_dim, cfg.rms_norm_eps)?;
Ok(Self {
in_proj_qkv,
in_proj_z,
in_proj_b,
in_proj_a,
out_proj,
conv1d_weight,
dt_bias,
a_log,
norm,
num_v_heads,
num_k_heads,
head_k_dim,
head_v_dim,
key_dim,
value_dim,
conv_dim,
conv_kernel_size,
state: GatedDeltaNetState::default(),
})
}
pub fn clear_kv_cache(&mut self) {
self.state = GatedDeltaNetState::default();
}
/// `x` shape: `(B, L, hidden_size)`. Returns the same shape.
pub fn forward(&mut self, x: &Tensor) -> candle_core::Result<Tensor> {
let (batch_size, seq_len, _) = x.dims3()?;
let dtype = x.dtype();
let device = x.device().clone();
// ----- Projections. -----
// mixed_qkv: (B, L, conv_dim)
let mixed_qkv = self.in_proj_qkv.forward(x)?;
// (B, conv_dim, L) for the conv1d.
let mixed_qkv_chw = mixed_qkv.transpose(1, 2)?.contiguous()?;
// z: (B, L, value_dim) → (B, L, num_v_heads, head_v_dim)
let z = self.in_proj_z.forward(x)?.reshape((
batch_size,
seq_len,
self.num_v_heads,
self.head_v_dim,
))?;
// b, a: (B, L, num_v_heads)
let b = self.in_proj_b.forward(x)?;
let a = self.in_proj_a.forward(x)?;
// ----- Depthwise causal Conv1d + SiLU (with state continuation). -----
// Dispatches to a cuda kernel that fuses conv1d + silu when
// available; falls back to candle's `conv1d` + `silu` on cpu.
let (conv_out, new_state) = run_causal_conv1d(
&mixed_qkv_chw,
&self.conv1d_weight,
self.state.conv_state.take(),
batch_size,
self.conv_dim,
seq_len,
self.conv_kernel_size,
)?;
self.state.conv_state = Some(new_state);
// Back to (B, L, conv_dim).
let mixed_qkv = conv_out.transpose(1, 2)?.contiguous()?;
// ----- Split into q, k, v. -----
let q = mixed_qkv.narrow(2, 0, self.key_dim)?;
let k = mixed_qkv.narrow(2, self.key_dim, self.key_dim)?;
let v = mixed_qkv.narrow(2, 2 * self.key_dim, self.value_dim)?;
let q = q.reshape((batch_size, seq_len, self.num_k_heads, self.head_k_dim))?;
let k = k.reshape((batch_size, seq_len, self.num_k_heads, self.head_k_dim))?;
let v = v.reshape((batch_size, seq_len, self.num_v_heads, self.head_v_dim))?;
// ----- beta + g (per-head, per-token gates). -----
// Fused on cuda; per-op Rust on cpu. Both paths produce:
// beta = sigmoid(b)
// g = -exp(A_log) * softplus(a + dt_bias)
let (beta, g) = run_fused_gating(&b, &a, &self.a_log, &self.dt_bias)?;
// ----- GQA-style key expansion if num_v_heads > num_k_heads. -----
let (q, k) = if self.num_v_heads > self.num_k_heads {
let rep = self.num_v_heads / self.num_k_heads;
(
repeat_interleave(&q, rep, 2)?,
repeat_interleave(&k, rep, 2)?,
)
} else {
(q, k)
};
// ----- L2-norm on q, k (use_qk_l2norm_in_kernel=True in ref). -----
let q = l2norm(&q, 1e-6)?;
let k = l2norm(&k, 1e-6)?;
// ----- Recurrent delta rule. -----
// Inputs: q, k (B, L, H, D_k); v (B, L, H, D_v); g (B, L, H); beta (B, L, H).
// The reference transposes to (B, H, L, D) before the loop. We
// do the same — it makes per-token indexing trivial.
let q = q.transpose(1, 2)?.contiguous()?; // (B, H, L, D_k)
let k = k.transpose(1, 2)?.contiguous()?;
let v = v.transpose(1, 2)?.contiguous()?; // (B, H, L, D_v)
let g = g.transpose(1, 2)?.contiguous()?; // (B, H, L)
let beta = beta.transpose(1, 2)?.contiguous()?; // (B, H, L)
// Pre-scale q by 1/sqrt(D_k) once. Everything goes to f32 here
// since the delta rule mixes broadcast_mul ops that candle won't
// accept across mixed dtypes. On the cuda gating path both beta
// and g come back in model dtype; on the cpu path g is already
// f32 — both casts are cheap idempotent ops.
let scale = 1.0_f64 / (self.head_k_dim as f64).sqrt();
let q = (q.to_dtype(candle_core::DType::F32)? * scale)?;
let k = k.to_dtype(candle_core::DType::F32)?;
let v = v.to_dtype(candle_core::DType::F32)?;
let g = g.to_dtype(candle_core::DType::F32)?;
let beta = beta.to_dtype(candle_core::DType::F32)?;
// Initialise the recurrent state from cache or zeros.
let state_init = match self.state.recurrent_state.take() {
Some(s) => s.to_dtype(candle_core::DType::F32)?,
None => Tensor::zeros(
(
batch_size,
self.num_v_heads,
self.head_k_dim,
self.head_v_dim,
),
candle_core::DType::F32,
&device,
)?,
};
// The delta-rule body: cuda-accelerated `gated_delta_rule_recurrence`
// kernel when we have a cuda device + the kernels are linked in,
// pure-Rust per-token fallback otherwise.
let (core_attn_out, new_state) = run_delta_rule(
&q,
&k,
&v,
&g,
&beta,
state_init,
batch_size,
self.num_v_heads,
seq_len,
self.head_k_dim,
self.head_v_dim,
)?;
// Stash the updated recurrent state for the next call.
self.state.recurrent_state = Some(new_state.to_dtype(dtype)?);
// core_attn_out: (B, H, L, D_v) → (B, L, H, D_v) → (B*L*H, D_v).
let core_attn_out = core_attn_out.transpose(1, 2)?.contiguous()?; // (B, L, H, D_v)
let core_attn_out = core_attn_out.to_dtype(dtype)?;
let core_attn_flat =
core_attn_out.reshape((batch_size * seq_len * self.num_v_heads, self.head_v_dim))?;
let z_flat = z.reshape((batch_size * seq_len * self.num_v_heads, self.head_v_dim))?;
// RMSNormGated: (out * silu(z) * weight) with the norm.
let normed = self.norm.forward(&core_attn_flat, &z_flat)?;
let normed = normed.reshape((batch_size, seq_len, self.num_v_heads * self.head_v_dim))?;
// Output projection: (B, L, value_dim) → (B, L, hidden_size).
self.out_proj.forward(&normed)
}
}
/// Run the per-token delta-rule recurrence.
///
/// `q`, `k`: `(B, H, L, D_k)` (F32). `v`: `(B, H, L, D_v)`. `g`,
/// `beta`: `(B, H, L)`. `state`: `(B, H, D_k, D_v)`.
///
/// Returns `(core_attn_out: (B, H, L, D_v), state: (B, H, D_k, D_v))`,
/// both F32. Caller is responsible for cast back to model dtype.
///
/// Cuda path: dispatches to the `gated_delta_rule_recurrence` kernel
/// ported from `EricLBuehler/mistral.rs::mistralrs-core/src/cuda/gdn.cu`.
/// All five inputs must be cuda f32 tensors. The kernel is V-tiled
/// with compile-time BK; one block per (V-tile, batch*head) and one
/// thread per V-column. Each thread holds BK state floats in
/// registers — eliminates the launch-overhead floor we hit with
/// candle's per-op dispatch (was ~12s/token on Qwen3.6-27B).
///
/// CPU path: pure-Rust per-token loop. Correct, slow.
#[allow(clippy::too_many_arguments)]
pub(crate) fn run_delta_rule(
q: &Tensor,
k: &Tensor,
v: &Tensor,
g: &Tensor,
beta: &Tensor,
state: Tensor,
batch_size: usize,
num_heads: usize,
seq_len: usize,
head_k_dim: usize,
head_v_dim: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
#[cfg(feature = "cuda")]
{
// Only dispatch to the kernel if the inputs are on a CUDA
// device — CPU tests fall back to the Rust loop below.
if q.device().is_cuda() {
return run_delta_rule_cuda(
q, k, v, g, beta, state, batch_size, num_heads, seq_len, head_k_dim, head_v_dim,
);
}
}
let _ = (batch_size, num_heads, head_k_dim, head_v_dim);
run_delta_rule_rust(q, k, v, g, beta, state, seq_len)
}
/// CUDA path. Flattens (B, H, ...) → (BH, ...) at the kernel boundary
/// (the kernel uses BH = batch*heads as its outer batch axis) and
/// reshapes the kernel's outputs back to (B, H, ...) for the caller.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
fn run_delta_rule_cuda(
q: &Tensor,
k: &Tensor,
v: &Tensor,
g: &Tensor,
beta: &Tensor,
state: Tensor,
batch_size: usize,
num_heads: usize,
seq_len: usize,
head_k_dim: usize,
head_v_dim: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
let q_bh = q.flatten(0, 1)?.contiguous()?;
let k_bh = k.flatten(0, 1)?.contiguous()?;
let v_bh = v.flatten(0, 1)?.contiguous()?;
let g_bh = g.flatten(0, 1)?.contiguous()?;
let beta_bh = beta.flatten(0, 1)?.contiguous()?;
let mut state_bh = state.flatten(0, 1)?.contiguous()?;
// For long prefills, the chunked kernel (BT=64) processes a chunk
// of tokens at a time instead of one-by-one — same delta-rule math,
// far fewer block launches. Threshold matches mistralrs.
const CHUNK_THRESHOLD: usize = 64;
let output_bh = if seq_len >= CHUNK_THRESHOLD {
crate::cuda::gdn::chunked_gated_delta_rule_recurrence_cuda(
&q_bh,
&k_bh,
&v_bh,
&g_bh,
&beta_bh,
&mut state_bh,
)?
} else {
crate::cuda::gdn::gated_delta_rule_recurrence_cuda(
&q_bh,
&k_bh,
&v_bh,
&g_bh,
&beta_bh,
&mut state_bh,
)?
};
let core_attn_out = output_bh.reshape((batch_size, num_heads, seq_len, head_v_dim))?;
let new_state = state_bh.reshape((batch_size, num_heads, head_k_dim, head_v_dim))?;
Ok((core_attn_out, new_state))
}
#[allow(clippy::too_many_arguments)]
fn run_delta_rule_rust(
q: &Tensor,
k: &Tensor,
v: &Tensor,
g: &Tensor,
beta: &Tensor,
mut state: Tensor,
seq_len: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
use candle_core::IndexOp;
let mut outputs: Vec<Tensor> = Vec::with_capacity(seq_len);
for t in 0..seq_len {
let q_t = q.i((.., .., t, ..))?;
let k_t = k.i((.., .., t, ..))?;
let v_t = v.i((.., .., t, ..))?;
let g_t = g.i((.., .., t))?;
let beta_t = beta.i((.., .., t))?;
let decay = g_t
.exp()?
.unsqueeze(candle_core::D::Minus1)?
.unsqueeze(candle_core::D::Minus1)?;
state = state.broadcast_mul(&decay)?;
let k_col = k_t.unsqueeze(candle_core::D::Minus1)?;
let kv_mem = state.broadcast_mul(&k_col)?.sum(2)?;
let beta_col = beta_t.unsqueeze(candle_core::D::Minus1)?;
let delta = (v_t - kv_mem)?.broadcast_mul(&beta_col)?;
let delta_row = delta.unsqueeze(2)?;
let outer = k_col.broadcast_mul(&delta_row)?;
state = (state + outer)?;
let q_col = q_t.unsqueeze(candle_core::D::Minus1)?;
let out_t = state.broadcast_mul(&q_col)?.sum(2)?;
outputs.push(out_t.unsqueeze(2)?);
}
let core_attn_out = Tensor::cat(&outputs, 2)?; // (B, H, L, D_v)
Ok((core_attn_out, state))
}
/// Depthwise causal conv1d + SiLU, with rolling `conv_state`.
///
/// `x`: `(B, conv_dim, L)` model dtype (f16/bf16 on cuda, anything on cpu).
/// `weight`: `(conv_dim, 1, kernel_size)` model dtype.
/// `conv_state`: `Some((B, conv_dim, kernel_size))` for decode continuation,
/// or `None` for fresh prefill.
///
/// Returns `(conv_out: (B, conv_dim, L), new_conv_state: (B, conv_dim, kernel_size))`.
/// SiLU is baked in.
///
/// Cuda path: dispatches to `causal_conv1d_update` (decode, seq_len=1 with
/// existing state) or `causal_conv1d_full` (prefill / first call), both
/// ported from mistralrs `gdn.cu`. Each kernel fuses the depthwise conv
/// and SiLU activation in one launch — that's ~4× fewer cuda launches per
/// linear-attention layer than the candle `conv1d` + `silu` combo.
///
/// CPU path: the original prepend-narrow-conv1d-silu sequence.
pub(crate) fn run_causal_conv1d(
x: &Tensor,
weight: &Tensor,
conv_state: Option<Tensor>,
batch_size: usize,
conv_dim: usize,
seq_len: usize,
conv_kernel_size: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
#[cfg(feature = "cuda")]
{
if x.device().is_cuda() {
return run_causal_conv1d_cuda(
x,
weight,
conv_state,
batch_size,
conv_dim,
seq_len,
conv_kernel_size,
);
}
}
run_causal_conv1d_rust(
x,
weight,
conv_state,
batch_size,
conv_dim,
seq_len,
conv_kernel_size,
)
}
#[cfg(feature = "cuda")]
fn run_causal_conv1d_cuda(
x: &Tensor,
weight: &Tensor,
conv_state: Option<Tensor>,
batch_size: usize,
conv_dim: usize,
seq_len: usize,
conv_kernel_size: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
// Kernel expects weight as (conv_dim, kernel_size) — squeeze the
// depthwise channel-multiplier dim.
let w = weight.squeeze(1)?.to_dtype(x.dtype())?.contiguous()?;
// Decode path: seq_len == 1 AND we have an existing conv_state.
// Otherwise (prefill or fresh-start decode), use the full path which
// zero-pads on the left internally.
if let Some(cs) = conv_state
&& seq_len == 1
{
let cs = cs.contiguous()?;
let (output, new_conv_state) =
crate::cuda::gdn::causal_conv1d_cuda(x, &w, &cs, conv_kernel_size, true)?;
return Ok((output, new_conv_state));
}
// Prefill / fresh-start: the kernel ignores any prior conv_state and
// zero-pads. If we had a non-zero prior state and >1 input tokens
// (multi-turn continuation), we'd need to fall back to Rust. Match
// mistralrs's behaviour: fresh prefill always.
let device = x.device().clone();
let zeros_cs = Tensor::zeros((batch_size, conv_dim, conv_kernel_size), x.dtype(), &device)?;
let (output, new_conv_state) =
crate::cuda::gdn::causal_conv1d_cuda(x, &w, &zeros_cs, conv_kernel_size, false)?;
Ok((output, new_conv_state))
}
/// Fused GDN gating: computes `beta = sigmoid(b)` and
/// `g = -exp(a_log) * softplus(a + dt_bias)` together.
///
/// `b`, `a`: `(B, L, num_heads)` model dtype.
/// `a_log`, `dt_bias`: `(num_heads,)` model dtype (cast to f32 internally).
///
/// Returns `(beta, g)` both in model dtype on the cuda path, both in f32
/// on the cpu fallback. The caller casts to f32 before the delta rule.
///
/// Cuda path: dispatches to `fused_gdn_gating_cuda` — one kernel
/// replaces sigmoid + neg(exp) + softplus + broadcast_mul (≈10 candle
/// launches per layer).
pub(crate) fn run_fused_gating(
b: &Tensor,
a: &Tensor,
a_log: &Tensor,
dt_bias: &Tensor,
) -> candle_core::Result<(Tensor, Tensor)> {
#[cfg(feature = "cuda")]
{
if b.device().is_cuda() {
let a_log_f32 = a_log.to_dtype(candle_core::DType::F32)?.contiguous()?;
let dt_bias_f32 = dt_bias.to_dtype(candle_core::DType::F32)?.contiguous()?;
return crate::cuda::gdn::fused_gdn_gating_cuda(b, a, &a_log_f32, &dt_bias_f32);
}
}
run_fused_gating_rust(b, a, a_log, dt_bias)
}
fn run_fused_gating_rust(
b: &Tensor,
a: &Tensor,
a_log: &Tensor,
dt_bias: &Tensor,
) -> candle_core::Result<(Tensor, Tensor)> {
let beta = candle_nn::ops::sigmoid(b)?;
let a_log_f32 = a_log.to_dtype(candle_core::DType::F32)?;
let neg_a_exp = a_log_f32.exp()?.neg()?;
let dt_b_f32 = dt_bias.to_dtype(candle_core::DType::F32)?;
let a_f32 = a.to_dtype(candle_core::DType::F32)?;
let a_plus_dt = a_f32.broadcast_add(&dt_b_f32)?;
let softplus_val = softplus(&a_plus_dt)?;
let neg_a_exp_b = neg_a_exp.unsqueeze(0)?.unsqueeze(0)?;
let g = neg_a_exp_b.broadcast_mul(&softplus_val)?;
Ok((beta, g))
}
fn run_causal_conv1d_rust(
x: &Tensor,
weight: &Tensor,
conv_state: Option<Tensor>,
batch_size: usize,
conv_dim: usize,
seq_len: usize,
conv_kernel_size: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
let dtype = x.dtype();
let device = x.device().clone();
let prepended = match &conv_state {
Some(prev) => Tensor::cat(&[prev, x], 2)?,
None => x.clone(),
};
let prep_len = prepended.dims()[2];
let new_state = if prep_len >= conv_kernel_size {
prepended.narrow(2, prep_len - conv_kernel_size, conv_kernel_size)?
} else {
let pad = Tensor::zeros(
(batch_size, conv_dim, conv_kernel_size - prep_len),
dtype,
&device,
)?;
Tensor::cat(&[&pad, &prepended], 2)?
};
let conv_out = prepended.conv1d(weight, conv_kernel_size - 1, 1, 1, conv_dim)?;
let conv_out = conv_out.narrow(2, 0, prep_len)?;
let conv_out = candle_nn::ops::silu(&conv_out)?;
let conv_out = conv_out.narrow(2, prep_len - seq_len, seq_len)?;
Ok((conv_out, new_state))
}
/// Load a no-bias linear from the ShardedVarBuilder. Weight shape is
/// the standard `[out, in]` order.
fn load_linear_no_bias(
vb: &ShardedVarBuilder,
name: &str,
in_dim: usize,
out_dim: usize,
) -> Result<Linear> {
let weight = vb
.pp(name)
.get((out_dim, in_dim), "weight")
.with_context(|| format!("load '{}/{name}/weight'", vb.prefix()))?;
Ok(Linear::new(weight, None))
}
/// Numerically-stable `softplus(x) = ln(1 + exp(x))`. Matches PyTorch's
/// `F.softplus` default (beta=1, threshold=20: for large positive x,
/// returns x as-is to avoid overflow in the exp).
pub(crate) fn softplus(x: &Tensor) -> candle_core::Result<Tensor> {
let threshold = 20.0_f64;
let big = x.ge(threshold)?; // Tensor<u8> mask
let safe = x.minimum(&x.affine(0.0, 0.0)?.affine(0.0, threshold)?)?; // min(x, threshold)
let small = ((safe.exp()? + 1.0_f64)?).log()?;
// Select x where big, else small.
big.where_cond(x, &small)
}
/// `repeat_interleave` along a single dim. Candle has no built-in for
/// this; emulate with unsqueeze + expand + reshape.
pub(crate) fn repeat_interleave(
x: &Tensor,
repeats: usize,
dim: usize,
) -> candle_core::Result<Tensor> {
if repeats == 1 {
return Ok(x.clone());
}
let mut shape = x.dims().to_vec();
let orig = shape[dim];
shape.insert(dim + 1, repeats);
let mut expanded_shape = shape.clone();
expanded_shape[dim + 1] = repeats;
let x = x.unsqueeze(dim + 1)?;
let x = x.expand(expanded_shape)?;
let mut out_shape = x.dims().to_vec();
out_shape.remove(dim + 1);
out_shape[dim] = orig * repeats;
x.contiguous()?.reshape(out_shape)
}
#[cfg(test)]
mod tests {
use super::*;
use candle_core::{DType, Device};
#[test]
fn softplus_small_x() {
// softplus(0) = ln(2) ≈ 0.6931
let x = Tensor::new(&[0.0_f32], &Device::Cpu).unwrap();
let out: Vec<f32> = softplus(&x).unwrap().to_vec1().unwrap();
assert!((out[0] - 2.0_f32.ln()).abs() < 1e-4);
}
#[test]
fn softplus_large_x_returns_x() {
// For x = 30, softplus(x) ≈ x (the threshold branch).
let x = Tensor::new(&[30.0_f32], &Device::Cpu).unwrap();
let out: Vec<f32> = softplus(&x).unwrap().to_vec1().unwrap();
assert!((out[0] - 30.0).abs() < 1e-4);
}
#[test]
fn repeat_interleave_doubles_dim() {
let x = Tensor::new(&[[1.0_f32, 2.0], [3.0, 4.0]], &Device::Cpu).unwrap(); // shape (2, 2)
let out = repeat_interleave(&x, 2, 1).unwrap(); // each col duplicated
let v: Vec<Vec<f32>> = out.to_vec2().unwrap();
// Row 0: 1, 1, 2, 2
// Row 1: 3, 3, 4, 4
assert_eq!(v[0], vec![1.0, 1.0, 2.0, 2.0]);
assert_eq!(v[1], vec![3.0, 3.0, 4.0, 4.0]);
}
/// Sanity: the recurrent path produces a finite tensor of the right
/// shape on tiny dimensions. Doesn't validate numerical correctness
/// against the Python reference — that would need a fixed-weight
/// fixture to compare against. Catches structural mistakes
/// (broadcasting shapes, off-by-one slices) early.
#[test]
fn forward_smoke_with_tiny_dimensions() {
let dev = Device::Cpu;
let dtype = DType::F32;
let (b, l) = (1, 3);
let cfg = TextConfig {
vocab_size: 100,
hidden_size: 16,
intermediate_size: 32,
num_hidden_layers: 1,
num_attention_heads: 4,
num_key_value_heads: 1,
head_dim: 4,
max_position_embeddings: 32,
rope_parameters: RopeParameters {
rope_theta: 10000.0,
partial_rotary_factor: 1.0,
rope_type: None,
},
rms_norm_eps: 1e-6,
tie_word_embeddings: false,
attn_output_gate: true,
layer_types: vec!["linear_attention".into()],
full_attention_interval: Some(4),
hidden_act: "silu".into(),
linear_num_value_heads: 4,
linear_num_key_heads: 2,
linear_key_head_dim: 4,
linear_value_head_dim: 4,
linear_conv_kernel_dim: 4,
};
// Build a synthetic VarBuilder with all-zeros weights.
// Easier path: skip the load and construct GatedDeltaNet
// manually by hand-rolling the Linear/Tensor inputs.
let zeros = |shape: &[usize]| Tensor::zeros(shape, dtype, &dev).unwrap();
let key_dim = cfg.linear_key_head_dim * cfg.linear_num_key_heads;
let value_dim = cfg.linear_value_head_dim * cfg.linear_num_value_heads;
let conv_dim = key_dim * 2 + value_dim;
let mut net = GatedDeltaNet {
in_proj_qkv: Linear::new(zeros(&[conv_dim, cfg.hidden_size]), None),
in_proj_z: Linear::new(zeros(&[value_dim, cfg.hidden_size]), None),
in_proj_b: Linear::new(zeros(&[cfg.linear_num_value_heads, cfg.hidden_size]), None),
in_proj_a: Linear::new(zeros(&[cfg.linear_num_value_heads, cfg.hidden_size]), None),
out_proj: Linear::new(zeros(&[cfg.hidden_size, value_dim]), None),
conv1d_weight: zeros(&[conv_dim, 1, cfg.linear_conv_kernel_dim]),
dt_bias: zeros(&[cfg.linear_num_value_heads]),
a_log: zeros(&[cfg.linear_num_value_heads]),
norm: {
let weight = Tensor::ones(&[cfg.linear_value_head_dim], dtype, &dev).unwrap();
Qwen3_5RmsNormGated::from_weight(weight, cfg.rms_norm_eps)
},
num_v_heads: cfg.linear_num_value_heads,
num_k_heads: cfg.linear_num_key_heads,
head_k_dim: cfg.linear_key_head_dim,
head_v_dim: cfg.linear_value_head_dim,
key_dim,
value_dim,
conv_dim,
conv_kernel_size: cfg.linear_conv_kernel_dim,
state: GatedDeltaNetState::default(),
};
let x = Tensor::ones(&[b, l, cfg.hidden_size], dtype, &dev).unwrap();
let y = net.forward(&x).unwrap();
assert_eq!(y.dims(), &[b, l, cfg.hidden_size]);
// All zero weights → output should be zero. Confirms no NaN/Inf
// poisoning from the f32 promotions.
let v: Vec<f32> = y.flatten_all().unwrap().to_vec1().unwrap();
assert!(v.iter().all(|x| x.is_finite()));
}
}

View File

@@ -1,53 +0,0 @@
//! SwiGLU MLP block for Qwen3-Next.
//!
//! Identical to plain Qwen3's MLP: `down(silu(gate(x)) * up(x))` with
//! no bias on any of the three projections.
use anyhow::{Context, Result};
use candle_core::{Module, Tensor};
use candle_nn::Linear;
use candle_nn::var_builder::ShardedVarBuilder;
use super::TextConfig;
pub struct Qwen3_5MLP {
gate_proj: Linear,
up_proj: Linear,
down_proj: Linear,
}
impl Qwen3_5MLP {
pub fn load(cfg: &TextConfig, vb: &ShardedVarBuilder) -> Result<Self> {
let h = cfg.hidden_size;
let i = cfg.intermediate_size;
let gate_proj = load_linear_no_bias(vb, "gate_proj", h, i)?;
let up_proj = load_linear_no_bias(vb, "up_proj", h, i)?;
let down_proj = load_linear_no_bias(vb, "down_proj", i, h)?;
Ok(Self {
gate_proj,
up_proj,
down_proj,
})
}
}
impl Module for Qwen3_5MLP {
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
let lhs = candle_nn::ops::silu(&self.gate_proj.forward(x)?)?;
let rhs = self.up_proj.forward(x)?;
self.down_proj.forward(&(lhs * rhs)?)
}
}
fn load_linear_no_bias(
vb: &ShardedVarBuilder,
name: &str,
in_dim: usize,
out_dim: usize,
) -> Result<Linear> {
let weight = vb
.pp(name)
.get((out_dim, in_dim), "weight")
.with_context(|| format!("load '{}/{name}/weight'", vb.prefix()))?;
Ok(Linear::new(weight, None))
}

View File

@@ -1,397 +0,0 @@
//! Qwen3-Next (`model_type = "qwen3_5"`) architecture — Qwen3.6's
//! upstream architecture revision.
//!
//! ## Naming
//!
//! The model release this targets is `Qwen/Qwen3.6-*` but the
//! architecture name in HuggingFace's `config.json` is `qwen3_5`.
//! mistralrs calls the same architecture `qwen3_next`; that label
//! ages poorly the next time Qwen ship a new arch, so we key on the
//! canonical `qwen3_5` from the model's own config.
//!
//! ## Status
//!
//! **Single-GPU dense path is real**. Both attention flavours
//! (`full_attention` with the output-gated GQA causal attention and
//! `linear_attention` with the Gated DeltaNet recurrent block) are
//! implemented. The model loads from upstream safetensors via the
//! existing `load_arch_dense` dispatch and runs forward end to end.
//!
//! Numerical correctness vs the reference Python is **not yet
//! validated** — the structural code path is right, weight tensor
//! names match the upstream layout, shapes flow through cleanly, but
//! the Tbilisi probe (and any other downstream test) is the next
//! step. Likely places a bug would surface:
//! - Per-rank vs per-token-position offsets in the recurrent delta
//! rule (`linear_attn.rs`).
//! - Off-by-one in the conv state continuation across decode steps.
//! - RoPE phase mismatch from MRoPE simplification (we treat the
//! three position grids as collapsed, which is correct only for
//! text-only inference).
//!
//! ## Submodules
//!
//! - [`rmsnorm`] — `Qwen3_5RmsNorm` (`(1+w)*x` variant), the
//! `Qwen3_5RmsNormGated` used after the delta rule, and the
//! `l2norm` helper.
//! - [`rope`] — text-side rotary embedding (mrope simplified, GLM
//! rotate-half).
//! - [`mlp`] — SwiGLU MLP (gate/up/down, no bias).
//! - [`full_attn`] — `Qwen3_5Attention` with the output-gate
//! widening on `q_proj`.
//! - [`linear_attn`] — `GatedDeltaNet` recurrent delta-rule block
//! (causal depthwise Conv1d → silu → split → L2norm → per-token
//! delta rule → RMSNormGated → out_proj).
//! - [`decoder`] — `Qwen3_5DecoderLayer` dispatching to one of the
//! two attention flavours per layer index.
//!
//! ## Open work
//!
//! - **TP variant.** `harness/tp/tp_qwen3_5.rs` is the next step.
//! Sharding strategy diverges by layer type:
//! - Full-attention layers: column-parallel q/k/v (including the
//! gate half of `q_proj`) + row-parallel `o_proj`, mirroring
//! `tp_qwen3.rs`.
//! - Linear-attention layers: the recurrent state is per-V-head, so
//! V-head-dimension sharding works cleanly — split `num_v_heads`
//! across ranks (`num_v_heads / world_size` per rank), shard
//! `in_proj_qkv` / `in_proj_z` / `in_proj_b` / `in_proj_a` along
//! the V-head dim, and row-parallel `out_proj`. The `A_log` /
//! `dt_bias` per-head params shard with the heads.
//!
//! - **Chunked delta-rule prefill.** `linear_attn.rs` runs the
//! per-token recurrent path for prefill too — correct but O(L).
//! Porting `torch_chunk_gated_delta_rule` (chunk_size=64) speeds
//! prefill substantially with no surface change.
use anyhow::{Context, Result};
use candle_core::{DType, Device, IndexOp, Module, Tensor};
use candle_nn::Embedding;
use candle_nn::Linear;
use candle_nn::var_builder::ShardedVarBuilder;
use serde::Deserialize;
use std::sync::Arc;
pub mod decoder;
pub mod full_attn;
pub mod linear_attn;
pub mod mlp;
pub mod rmsnorm;
pub mod rope;
use decoder::Qwen3_5DecoderLayer;
use rmsnorm::Qwen3_5RmsNorm;
use rope::RotaryEmbedding;
/// `model_type` we deserialise from `config.json`. Const so the
/// dispatch in `candle.rs::load_arch_dense` can pattern-match without
/// magic strings.
pub const MODEL_TYPE: &str = "qwen3_5";
/// Top-level shape of Qwen3-Next's `config.json`. The real
/// hyperparameters live in `text_config`; the rest is multimodal /
/// tokeniser glue we don't need for the language-model forward.
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
/// Always `"qwen3_5"` for this architecture. Kept on the struct
/// so the (eventual) dispatch / logging code can show it without
/// re-parsing the JSON.
pub model_type: String,
/// The text-side hyperparameters. Everything we actually need.
pub text_config: TextConfig,
}
/// Inner config (the `text_config` block). Mirrors the Qwen3 layout
/// but with the extras Qwen3-Next adds (`attn_output_gate`,
/// `layer_types`, `full_attention_interval`, larger `head_dim`).
#[derive(Debug, Clone, Deserialize)]
pub struct TextConfig {
pub vocab_size: usize,
pub hidden_size: usize,
pub intermediate_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub num_key_value_heads: usize,
pub head_dim: usize,
pub max_position_embeddings: usize,
/// Nested RoPE settings. Qwen3-Next puts `rope_theta` and
/// `partial_rotary_factor` inside this block rather than at the
/// top level — important because the partial rotary means only
/// `head_dim * partial_rotary_factor` dims get RoPE applied (the
/// rest pass through unchanged).
pub rope_parameters: RopeParameters,
pub rms_norm_eps: f64,
#[serde(default)]
pub tie_word_embeddings: bool,
/// New in Qwen3-Next: a sigmoid gate multiplied into the attention
/// output before the o_proj. The Python reference applies it
/// pointwise after softmax+matmul.
#[serde(default)]
pub attn_output_gate: bool,
/// One entry per decoder layer; values are `"full_attention"` or
/// `"linear_attention"`. Length must equal `num_hidden_layers`.
/// `full_attention_interval` is a derived hint (every 4th layer
/// by default) — `layer_types` is authoritative.
#[serde(default)]
pub layer_types: Vec<String>,
/// Hint for the layer-type pattern (defaults to 4). Kept for
/// logging / validation; the forward dispatches on `layer_types`.
#[serde(default)]
pub full_attention_interval: Option<usize>,
/// Hidden activation (`"silu"` for Qwen3-Next). Used by the MLP
/// and the linear-attention conv1d.
#[serde(default = "default_hidden_act")]
pub hidden_act: String,
// --- Gated DeltaNet (linear-attention) hyperparams -----------------
/// Per-layer linear-attention V-head count (Qwen3.6-27B: 48).
/// More V-heads than K-heads is fine — query/key get
/// `repeat_interleave`'d to match before the delta rule.
#[serde(default)]
pub linear_num_value_heads: usize,
/// Per-layer linear-attention K-head count (Qwen3.6-27B: 16).
#[serde(default)]
pub linear_num_key_heads: usize,
/// Per-head key dimension for the linear-attention path
/// (Qwen3.6-27B: 128). Separate from `head_dim` which the
/// full-attention layers use.
#[serde(default)]
pub linear_key_head_dim: usize,
/// Per-head value dimension for the linear-attention path
/// (Qwen3.6-27B: 128).
#[serde(default)]
pub linear_value_head_dim: usize,
/// Causal Conv1d kernel size used before the delta rule
/// (Qwen3.6-27B: 4).
#[serde(default)]
pub linear_conv_kernel_dim: usize,
}
fn default_hidden_act() -> String {
"silu".into()
}
/// Nested `rope_parameters` block from a Qwen3-Next `config.json`.
/// `mrope_section` and `mrope_interleaved` are accepted via the
/// `#[serde(default)]` flatten-tolerance below but ignored — we treat
/// MRoPE as plain RoPE for text-only inference (the three position
/// grids carry identical ids when there's no vision input, so the
/// interleaving is a no-op).
#[derive(Debug, Clone, Deserialize)]
pub struct RopeParameters {
/// Base for the inverse-frequency computation. Qwen3.6: 10_000_000.
#[serde(default = "default_rope_theta")]
pub rope_theta: f64,
/// Fraction of `head_dim` that gets the rotation applied. The
/// remaining `head_dim * (1 - partial_rotary_factor)` dims pass
/// through unchanged. Qwen3.6 / Qwen3.5: 0.25.
#[serde(default = "default_partial_rotary_factor")]
pub partial_rotary_factor: f32,
/// `"default"` for the standard inv_freq RoPE; other values (e.g.
/// `"linear"`, `"dynamic"`) are upstream-supported but not yet
/// implemented here.
#[serde(default)]
pub rope_type: Option<String>,
}
fn default_rope_theta() -> f64 {
10_000.0
}
fn default_partial_rotary_factor() -> f32 {
1.0
}
/// Qwen3-Next base transformer (embedding + decoder stack + final
/// norm). Public so a TP variant in `harness/tp/tp_qwen3_5.rs` can
/// also build on it later — for now only `Qwen3_5ForCausalLM` is the
/// loaded handle.
pub struct Qwen3_5Model {
embed_tokens: Embedding,
layers: Vec<Qwen3_5DecoderLayer>,
norm: Qwen3_5RmsNorm,
device: Device,
dtype: DType,
}
impl Qwen3_5Model {
pub fn load(cfg: &TextConfig, vb: &ShardedVarBuilder) -> Result<Self> {
let dtype = vb.dtype();
let device = vb.device().clone();
// Qwen3-Next is a multimodal architecture whose text core lives
// under `model.language_model.*` — sibling to `model.visual.*`
// (the vision tower) and to top-level `lm_head` / `mtp.*`.
// Every text-side tensor in the safetensors files is under
// this prefix; we ignore the vision and MTP weights for
// language-model inference.
let text_vb = vb.pp("model.language_model");
let embed_vb = text_vb.pp("embed_tokens");
let embed_weight = embed_vb
.get((cfg.vocab_size, cfg.hidden_size), "weight")
.with_context(|| format!("load '{}/weight'", embed_vb.prefix()))?;
let embed_tokens = Embedding::new(embed_weight, cfg.hidden_size);
let rotary = Arc::new(RotaryEmbedding::new(dtype, cfg, &device)?);
if cfg.layer_types.len() != cfg.num_hidden_layers {
anyhow::bail!(
"config.text_config.layer_types must have num_hidden_layers ({}) entries; \
got {}",
cfg.num_hidden_layers,
cfg.layer_types.len()
);
}
let vb_l = text_vb.pp("layers");
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
for i in 0..cfg.num_hidden_layers {
layers.push(Qwen3_5DecoderLayer::load(
cfg,
rotary.clone(),
i,
&vb_l.pp(i),
)?);
}
let norm = Qwen3_5RmsNorm::load(&text_vb.pp("norm"), cfg.hidden_size, cfg.rms_norm_eps)?;
Ok(Self {
embed_tokens,
layers,
norm,
device,
dtype,
})
}
pub fn embed_weight(&self) -> &Tensor {
self.embed_tokens.embeddings()
}
pub fn clear_kv_cache(&mut self) {
for l in &mut self.layers {
l.clear_kv_cache();
}
}
fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> candle_core::Result<Tensor> {
let minf = f32::NEG_INFINITY;
let mask: Vec<_> = (0..tgt)
.flat_map(|i| (0..(tgt + offset)).map(move |j| if j <= i + offset { 0. } else { minf }))
.collect();
Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype)
}
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
let (b, l) = input.dims2()?;
let mut h = self.embed_tokens.forward(input)?;
// Causal mask only needed for L > 1 prefill; full-attention
// layers consume it via broadcast_add. Linear-attention layers
// ignore the mask.
let causal = if l == 1 {
None
} else {
Some(self.causal_mask(b, l, offset)?)
};
for layer in &mut self.layers {
h = layer.forward(&h, causal.as_ref(), offset)?;
}
self.norm.forward(&h)
}
}
pub struct Qwen3_5ForCausalLM {
base: Qwen3_5Model,
lm_head: Linear,
}
impl Qwen3_5ForCausalLM {
pub fn new(config: Config, vb: ShardedVarBuilder) -> Result<Self> {
let cfg = &config.text_config;
let base = Qwen3_5Model::load(cfg, &vb)?;
let lm_head = if cfg.tie_word_embeddings {
Linear::new(base.embed_weight().clone(), None)
} else {
let weight = vb
.pp("lm_head")
.get((cfg.vocab_size, cfg.hidden_size), "weight")
.with_context(|| format!("load '{}/lm_head/weight'", vb.prefix()))?;
Linear::new(weight, None)
};
Ok(Self { base, lm_head })
}
/// `input`: token-id tensor of shape `(B, L)`. Returns logits at
/// the last position, shape `(B, 1, vocab_size)` — same contract
/// as `qwen3::ModelForCausalLM::forward` so the harness's
/// `squeeze_to_vocab` helper handles both uniformly.
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
let (_, l) = input.dims2()?;
let hidden = self.base.forward(input, offset)?;
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
}
pub fn clear_kv_cache(&mut self) {
self.base.clear_kv_cache();
}
}
#[cfg(test)]
mod tests {
use super::*;
/// Confirms we can deserialise the real upstream config shape.
/// Sample taken from `Qwen/Qwen3.6-27B/config.json`, trimmed to
/// the fields the architecture cares about. Note `rope_theta` and
/// `partial_rotary_factor` are nested under `rope_parameters` —
/// Qwen3-Next does NOT have a top-level `rope_theta`.
#[test]
fn config_deserialises_the_real_qwen3_6_shape() {
let raw = r#"{
"architectures": ["Qwen3_5ForConditionalGeneration"],
"model_type": "qwen3_5",
"image_token_id": 248056,
"language_model_only": false,
"text_config": {
"vocab_size": 248064,
"hidden_size": 5120,
"intermediate_size": 17408,
"num_hidden_layers": 64,
"num_attention_heads": 64,
"num_key_value_heads": 8,
"head_dim": 256,
"max_position_embeddings": 32768,
"rope_parameters": {
"mrope_interleaved": true,
"mrope_section": [11, 11, 10],
"partial_rotary_factor": 0.25,
"rope_theta": 10000000,
"rope_type": "default"
},
"rms_norm_eps": 1e-6,
"tie_word_embeddings": false,
"attn_output_gate": true,
"full_attention_interval": 4,
"layer_types": [
"linear_attention", "linear_attention",
"linear_attention", "full_attention"
]
}
}"#;
let cfg: Config = serde_json::from_str(raw).expect("parse Qwen3.6 config");
assert_eq!(cfg.model_type, "qwen3_5");
assert_eq!(cfg.text_config.hidden_size, 5120);
assert_eq!(cfg.text_config.head_dim, 256);
assert!(cfg.text_config.attn_output_gate);
assert_eq!(cfg.text_config.full_attention_interval, Some(4));
assert_eq!(cfg.text_config.layer_types.len(), 4);
assert_eq!(cfg.text_config.rope_parameters.rope_theta, 10_000_000.0);
assert!((cfg.text_config.rope_parameters.partial_rotary_factor - 0.25).abs() < 1e-6);
}
}

View File

@@ -1,161 +0,0 @@
//! Norm primitives for Qwen3-Next.
//!
//! Two reasons we can't reuse `candle_nn::RmsNorm` directly:
//!
//! 1. **`(1.0 + weight)` scaling.** Qwen3-Next's `Qwen3_5RMSNorm`
//! initialises `weight` to zeros and applies `(1.0 + weight)` to
//! the normalised vector. `candle_nn::RmsNorm` applies `weight`
//! directly. The two are equivalent only when the operator has
//! pre-shifted the weights — the upstream checkpoints have not. See
//! `huggingface/transformers#29402` for the upstream PR that
//! introduced the `(1 + w)` form to recover from the zero-init.
//!
//! 2. **Gated variant.** The linear-attention layer post-normalises
//! its output by an RMSNorm *gated* with a per-element SiLU on
//! a sibling `z` projection — fused for numerical reasons (the
//! norm's float32 promotion has to happen before the SiLU
//! multiply). Not a single existing candle op.
//!
//! Both ops accept inputs in any compute dtype; promotion to f32 for
//! the variance calculation matches the Python reference.
use anyhow::{Context, Result};
use candle_core::{D, Module, Tensor};
use candle_nn::var_builder::ShardedVarBuilder;
/// L2-normalise along the last dim with a small epsilon. Matches the
/// `l2norm` helper in `transformers/models/qwen3_5/modeling_qwen3_5.py`
/// — `x * rsqrt(sum(x*x) + eps)`. The linear-attention path uses this
/// on Q and K before the delta rule when
/// `use_qk_l2norm_in_kernel=True` (which Qwen3-Next always sets).
pub fn l2norm(x: &Tensor, eps: f32) -> candle_core::Result<Tensor> {
let dtype = x.dtype();
let x_f32 = x.to_dtype(candle_core::DType::F32)?;
let sq = x_f32.sqr()?;
let sum = sq.sum_keepdim(D::Minus1)?;
let inv = (sum + eps as f64)?.sqrt()?.recip()?;
x_f32.broadcast_mul(&inv)?.to_dtype(dtype)
}
/// Qwen3-Next's RMSNorm. Stores the raw weight tensor; forward applies
/// `(1.0 + weight) * x_normed`.
pub struct Qwen3_5RmsNorm {
weight: Tensor,
eps: f32,
size: usize,
}
impl Qwen3_5RmsNorm {
/// Load `weight` from the ShardedVarBuilder. `vb` should already be
/// `.pp(...)`-ed to the norm's tensor prefix.
pub fn load(vb: &ShardedVarBuilder, size: usize, eps: f64) -> Result<Self> {
let weight = vb
.get(size, "weight")
.with_context(|| format!("load '{}/weight'", vb.prefix()))?;
Ok(Self {
weight,
eps: eps as f32,
size,
})
}
pub fn size(&self) -> usize {
self.size
}
}
impl Module for Qwen3_5RmsNorm {
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
let dtype = x.dtype();
let x_f32 = x.to_dtype(candle_core::DType::F32)?;
let var = x_f32.sqr()?.mean_keepdim(D::Minus1)?;
let normed = x_f32.broadcast_mul(&(var + self.eps as f64)?.sqrt()?.recip()?)?;
// Promote weight to f32 and shift by 1.0 *before* multiplying.
// Doing the (1 + w) operation in fp16 lands at -inf for the
// bottom-of-range weights at load time.
let w_f32 = self.weight.to_dtype(candle_core::DType::F32)?;
let scale = (w_f32 + 1.0_f64)?;
normed.broadcast_mul(&scale)?.to_dtype(dtype)
}
}
/// Gated RMSNorm used at the tail of `Qwen3_5GatedDeltaNet`. Equivalent
/// to `x_normed * weight * silu(gate)` but with both the norm and the
/// gate evaluated in float32 to avoid mid-pipeline underflow.
///
/// Note: unlike `Qwen3_5RmsNorm`, this variant matches the Python
/// reference's `Qwen3_5RMSNormGated` which uses `weight` directly (not
/// `1.0 + weight`).
pub struct Qwen3_5RmsNormGated {
weight: Tensor,
eps: f32,
size: usize,
}
impl Qwen3_5RmsNormGated {
pub fn load(vb: &ShardedVarBuilder, size: usize, eps: f64) -> Result<Self> {
let weight = vb
.get(size, "weight")
.with_context(|| format!("load '{}/weight'", vb.prefix()))?;
Ok(Self {
weight,
eps: eps as f32,
size,
})
}
/// Direct constructor — used by unit tests that build a layer
/// without going through a VarBuilder.
#[cfg(test)]
pub(crate) fn from_weight(weight: Tensor, eps: f64) -> Self {
let size = weight.dims()[0];
Self {
weight,
eps: eps as f32,
size,
}
}
pub fn size(&self) -> usize {
self.size
}
/// `x` and `gate` share the same last-dim shape (`size`).
pub fn forward(&self, x: &Tensor, gate: &Tensor) -> candle_core::Result<Tensor> {
let dtype = x.dtype();
let x_f32 = x.to_dtype(candle_core::DType::F32)?;
let var = x_f32.sqr()?.mean_keepdim(D::Minus1)?;
let normed = x_f32.broadcast_mul(&(var + self.eps as f64)?.sqrt()?.recip()?)?;
let w = self.weight.to_dtype(candle_core::DType::F32)?;
let out = normed.broadcast_mul(&w)?;
// SiLU on the float32 gate, multiply back into the normed
// tensor, then cast to the model dtype.
let g = gate.to_dtype(candle_core::DType::F32)?;
let silu_gate = candle_nn::ops::silu(&g)?;
(out * silu_gate)?.to_dtype(dtype)
}
}
#[cfg(test)]
mod tests {
use super::*;
use candle_core::Device;
#[test]
fn l2norm_matches_hand_calc() {
let x = Tensor::new(&[3.0_f32, 4.0_f32], &Device::Cpu).unwrap();
let out = l2norm(&x, 1e-6).unwrap();
let v: Vec<f32> = out.to_vec1().unwrap();
// |x| = 5, so x/|x| = [0.6, 0.8] (eps is tiny).
assert!((v[0] - 0.6).abs() < 1e-4);
assert!((v[1] - 0.8).abs() < 1e-4);
}
#[test]
fn l2norm_zero_vector_is_safe_via_epsilon() {
let x = Tensor::new(&[0.0_f32, 0.0_f32], &Device::Cpu).unwrap();
let out = l2norm(&x, 1e-6).unwrap();
let v: Vec<f32> = out.to_vec1().unwrap();
assert!(v.iter().all(|x| x.is_finite()));
}
}

View File

@@ -1,114 +0,0 @@
//! Rotary position embedding for Qwen3-Next's full-attention layers.
//!
//! Qwen3.6 ships with MRoPE (multimodal RoPE) machinery in the
//! reference Python — three position grids interleaved per
//! `mrope_section`. For text-only inference all three grids carry the
//! same position ids and the interleave is a no-op, so this module
//! implements the plain (non-mrope) flavour: the standard inv_freq
//! cosine/sine tables driven by `rope_theta` and `head_dim`.
//!
//! Rotation flavour: **GLM-style** rotate-half (the second half of the
//! head dim is negated and swapped into the first). The reference
//! Python uses `apply_rotary_pos_emb` with `rotate_half`; candle's
//! `rope_slow` is the matching helper.
use anyhow::Result;
use candle_core::{DType, Device, Tensor};
use super::TextConfig;
#[derive(Debug, Clone)]
pub struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
/// Number of dims at the head's leading edge that the rotation
/// covers. The remaining `head_dim - rotary_dim` dims pass through
/// unchanged. Qwen3-Next uses `partial_rotary_factor = 0.25`, so
/// for `head_dim = 256` only 64 dims rotate.
rotary_dim: usize,
head_dim: usize,
}
impl RotaryEmbedding {
pub fn new(dtype: DType, cfg: &TextConfig, dev: &Device) -> Result<Self> {
let head_dim = cfg.head_dim;
let rope = &cfg.rope_parameters;
let rotary_dim = (head_dim as f32 * rope.partial_rotary_factor) as usize;
if !rotary_dim.is_multiple_of(2) {
anyhow::bail!(
"rotary_dim = head_dim * partial_rotary_factor = {head_dim} * {} = {rotary_dim} \
must be even (cos/sin are paired)",
rope.partial_rotary_factor
);
}
if rotary_dim == 0 {
anyhow::bail!(
"rotary_dim = 0 (partial_rotary_factor = {} too small)",
rope.partial_rotary_factor
);
}
let max_seq_len = cfg.max_position_embeddings;
let inv_freq: Vec<f32> = (0..rotary_dim)
.step_by(2)
.map(|i| 1f32 / rope.rope_theta.powf(i as f64 / rotary_dim as f64) as f32)
.collect();
let n = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, n), dev)?.to_dtype(DType::F32)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(DType::F32)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
Ok(Self {
sin: freqs.sin()?.to_dtype(dtype)?,
cos: freqs.cos()?.to_dtype(dtype)?,
rotary_dim,
head_dim,
})
}
/// Apply RoPE to q, k.
///
/// `q`, `k` shape: `(B, H, L, head_dim)`. `offset` is the index
/// into the cached cos/sin table — the position of the first token
/// in the current step.
///
/// When `rotary_dim < head_dim` the rotation is applied only to the
/// first `rotary_dim` dims of each head; the tail passes through
/// unchanged (matches the reference Python's
/// `apply_rotary_pos_emb` with non-trivial `partial_rotary_factor`).
pub fn apply(
&self,
q: &Tensor,
k: &Tensor,
offset: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
let (_, _, seq_len, head_dim_in) = q.dims4()?;
debug_assert_eq!(head_dim_in, self.head_dim, "q head_dim mismatch");
let cos = self.cos.narrow(0, offset, seq_len)?;
let sin = self.sin.narrow(0, offset, seq_len)?;
if self.rotary_dim == self.head_dim {
// Full rotation.
let q_embed = candle_nn::rotary_emb::rope_slow(&q.contiguous()?, &cos, &sin)?;
let k_embed = candle_nn::rotary_emb::rope_slow(&k.contiguous()?, &cos, &sin)?;
Ok((q_embed, k_embed))
} else {
// Partial rotation: narrow → rotate → cat the untouched tail.
let tail = self.head_dim - self.rotary_dim;
let q_rot = q
.narrow(candle_core::D::Minus1, 0, self.rotary_dim)?
.contiguous()?;
let q_pass = q.narrow(candle_core::D::Minus1, self.rotary_dim, tail)?;
let k_rot = k
.narrow(candle_core::D::Minus1, 0, self.rotary_dim)?
.contiguous()?;
let k_pass = k.narrow(candle_core::D::Minus1, self.rotary_dim, tail)?;
let q_rotated = candle_nn::rotary_emb::rope_slow(&q_rot, &cos, &sin)?;
let k_rotated = candle_nn::rotary_emb::rope_slow(&k_rot, &cos, &sin)?;
let q_embed =
Tensor::cat(&[&q_rotated, &q_pass.contiguous()?], candle_core::D::Minus1)?;
let k_embed =
Tensor::cat(&[&k_rotated, &k_pass.contiguous()?], candle_core::D::Minus1)?;
Ok((q_embed, k_embed))
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1 @@
// llama.cpp harness implementation — Phase 11.

View File

@@ -0,0 +1,163 @@
//! mistral.rs harness implementation.
//!
//! Wraps the mistral.rs HTTP API for model lifecycle management
//! and optionally manages the process via systemd.
use anyhow::Result;
use async_trait::async_trait;
use cortex_core::harness::{Harness, HarnessConfig, HarnessHealth, ModelInfo, ModelSpec};
use reqwest::Client;
use serde::Deserialize;
pub struct MistralRsHarness {
endpoint: String,
systemd_unit: Option<String>,
client: Client,
}
impl MistralRsHarness {
pub fn new(endpoint: String, systemd_unit: Option<String>) -> Self {
Self {
endpoint,
systemd_unit,
client: Client::builder()
.timeout(std::time::Duration::from_secs(30))
.build()
.expect("failed to build HTTP client"),
}
}
}
/// Response from mistral.rs `GET /v1/models`.
#[derive(Debug, Deserialize)]
struct ModelsResponse {
data: Vec<ModelEntry>,
}
#[derive(Debug, Deserialize)]
struct ModelEntry {
id: String,
#[serde(default)]
status: Option<String>,
}
#[async_trait]
impl Harness for MistralRsHarness {
fn name(&self) -> &str {
"mistralrs"
}
async fn start(&self, _config: &HarnessConfig) -> Result<()> {
let Some(unit) = &self.systemd_unit else {
anyhow::bail!("no systemd unit configured for mistralrs harness");
};
let output = tokio::process::Command::new("systemctl")
.args(["start", unit])
.output()
.await?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
anyhow::bail!("systemctl start {unit} failed: {stderr}");
}
// Wait for the health endpoint to respond (up to 30s).
let url = format!("{}/health", self.endpoint);
for _ in 0..30 {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
if self.client.get(&url).send().await.is_ok() {
tracing::info!(unit, "mistralrs started and healthy");
return Ok(());
}
}
anyhow::bail!("mistralrs started but health endpoint did not respond within 30s");
}
async fn stop(&self) -> Result<()> {
let Some(unit) = &self.systemd_unit else {
anyhow::bail!("no systemd unit configured for mistralrs harness");
};
let output = tokio::process::Command::new("systemctl")
.args(["stop", unit])
.output()
.await?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
anyhow::bail!("systemctl stop {unit} failed: {stderr}");
}
Ok(())
}
async fn health(&self) -> HarnessHealth {
let url = format!("{}/health", self.endpoint);
let running = self.client.get(&url).send().await.is_ok();
HarnessHealth {
name: "mistralrs".into(),
running,
uptime_secs: None,
}
}
async fn list_models(&self) -> Result<Vec<ModelInfo>> {
let url = format!("{}/v1/models", self.endpoint);
let resp = self.client.get(&url).send().await?;
if !resp.status().is_success() {
anyhow::bail!("GET /v1/models returned {}", resp.status());
}
let models_resp: ModelsResponse = resp.json().await?;
Ok(models_resp
.data
.into_iter()
.map(|m| ModelInfo {
id: m.id,
harness: "mistralrs".into(),
status: m.status.unwrap_or_else(|| "loaded".into()),
devices: vec![],
vram_used_mb: None,
})
.collect())
}
async fn load_model(&self, spec: &ModelSpec) -> Result<()> {
let url = format!("{}/v1/models/reload", self.endpoint);
let resp = self
.client
.post(&url)
.json(&serde_json::json!({ "model_id": spec.model_id }))
.send()
.await?;
if !resp.status().is_success() {
let body = resp.text().await.unwrap_or_default();
anyhow::bail!("POST /v1/models/reload failed: {body}");
}
Ok(())
}
async fn unload_model(&self, model_id: &str) -> Result<()> {
let url = format!("{}/v1/models/unload", self.endpoint);
let resp = self
.client
.post(&url)
.json(&serde_json::json!({ "model_id": model_id }))
.send()
.await?;
if !resp.status().is_success() {
let body = resp.text().await.unwrap_or_default();
anyhow::bail!("POST /v1/models/unload failed: {body}");
}
Ok(())
}
async fn inference_endpoint(&self, _model_id: &str) -> Option<String> {
// mistral.rs routes internally by model name in the request body,
// so the inference endpoint is always the base URL.
Some(self.endpoint.clone())
}
}

View File

@@ -1,24 +1,15 @@
//! Harness registry — maps harness names to trait implementations.
pub mod arch;
pub mod candle;
pub mod tp;
pub mod llamacpp;
pub mod mistralrs;
use anyhow::Result;
use cortex_core::harness::{Harness, HarnessConfig, ModelInfo, ModelSpec};
use std::collections::HashMap;
use std::sync::Arc;
/// Registry of available harness implementations.
///
/// Holds an `Arc<dyn Harness>` per harness for generic lifecycle dispatch
/// (load/unload/list_models). When a candle harness is registered, a typed
/// `Arc<CandleHarness>` is also cached so inference routes can bypass the
/// dyn-Trait dispatch and reach harness-specific methods (chat completion,
/// streaming, etc.).
pub struct HarnessRegistry {
harnesses: HashMap<String, Arc<dyn Harness>>,
candle: Option<Arc<candle::CandleHarness>>,
harnesses: HashMap<String, Box<dyn Harness>>,
}
impl Default for HarnessRegistry {
@@ -31,11 +22,10 @@ impl HarnessRegistry {
pub fn new() -> Self {
Self {
harnesses: HashMap::new(),
candle: None,
}
}
pub fn register(&mut self, harness: Arc<dyn Harness>) {
pub fn register(&mut self, harness: Box<dyn Harness>) {
self.harnesses.insert(harness.name().to_string(), harness);
}
@@ -44,12 +34,6 @@ impl HarnessRegistry {
self.harnesses.keys().cloned().collect()
}
/// Typed handle to the candle harness, if registered. Used by inference
/// routes that need methods beyond the `Harness` trait surface.
pub fn candle(&self) -> Option<Arc<candle::CandleHarness>> {
self.candle.clone()
}
/// List models from all registered harnesses.
pub async fn list_all_models(&self) -> Result<Vec<ModelInfo>> {
let mut all = Vec::new();
@@ -97,25 +81,19 @@ impl HarnessRegistry {
}
/// Build a registry from harness configs.
///
/// `bind_url` is the URL where this neuron serves inference (its own
/// listen address). In-process harnesses (currently the only kind)
/// return this URL from `inference_endpoint`.
pub fn from_configs(
configs: &[HarnessConfig],
bind_url: &str,
settings: &crate::config::HarnessSettings,
) -> Self {
pub fn from_configs(configs: &[HarnessConfig]) -> Self {
let mut registry = Self::new();
for config in configs {
match config.name.as_str() {
"candle" => {
let harness = Arc::new(candle::CandleHarness::new(
bind_url.to_string(),
settings.candle.hf_cache.clone(),
));
registry.candle = Some(Arc::clone(&harness));
registry.harnesses.insert("candle".into(), harness);
"mistralrs" => {
if let Some(endpoint) = &config.endpoint {
registry.register(Box::new(mistralrs::MistralRsHarness::new(
endpoint.clone(),
config.systemd_unit.clone(),
)));
} else {
tracing::warn!("mistralrs harness missing endpoint, skipping");
}
}
other => {
tracing::warn!(harness = other, "unknown harness type, skipping");

View File

@@ -1,119 +0,0 @@
//! `AllReduce` as a candle `CustomOp1` — the bridge between candle's
//! `Tensor` graph and `cudarc::nccl::Comm::all_reduce`.
//!
//! Ported from the canonical
//! `candle-examples/examples/llama_multiprocess/model.rs` pattern.
//! Row-parallel layers apply this op after their local matmul to sum
//! partial outputs across NCCL ranks.
//!
//! Available only under `--features cuda`; on CPU builds this module
//! is empty and row-parallel layers degenerate to local matmul only
//! (useful for compile-checking the model code; correctness requires
//! cuda).
//!
//! Thread-safety caveat: NCCL communicators are technically only
//! safe to use from a single thread at a time
//! (https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html).
//! We hold the `AllReduce` behind an `Arc<Comm>` and only issue ops
//! against it from the dedicated `spawn_blocking` thread the inference
//! pipeline already uses for candle's forward passes.
#![cfg(feature = "cuda")]
use candle_core::backend::BackendStorage;
use candle_core::{CpuStorage, CudaStorage, CustomOp1, DType, Layout, Result, Shape};
use cudarc::nccl::{Comm, ReduceOp};
use half::{bf16, f16};
use std::sync::Arc;
/// Wraps an NCCL `Comm` so it can be plugged into a candle forward
/// graph as a custom op. Each row-parallel layer holds one of these.
pub struct AllReduce {
comm: Arc<Comm>,
}
// SAFETY: `Comm` contains a raw `ncclComm_t` pointer; NCCL's docs note
// that issuing ops against one comm from multiple threads concurrently
// is unsafe. We serialise via the single spawn_blocking thread that
// drives the model's forward pass. The Send/Sync impl is necessary
// because candle's CustomOp1 trait bounds require it; the correctness
// invariant is enforced at the call site, not the type level.
unsafe impl Send for AllReduce {}
unsafe impl Sync for AllReduce {}
impl AllReduce {
pub fn new(comm: Arc<Comm>) -> Self {
Self { comm }
}
pub fn comm(&self) -> &Arc<Comm> {
&self.comm
}
}
impl CustomOp1 for AllReduce {
fn name(&self) -> &'static str {
"neuron.tp.all_reduce"
}
fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
candle_core::bail!("AllReduce custom-op invoked on CPU storage; TP requires CUDA")
}
fn cuda_fwd(&self, s: &CudaStorage, l: &Layout) -> Result<(CudaStorage, Shape)> {
// Reject non-contiguous inputs explicitly — copying them
// server-side would mask shape bugs (a TP layer feeding a
// strided activation into all_reduce is almost certainly a
// model construction error).
fn require_contiguous<T: cudarc::driver::DeviceRepr>(
slice: &cudarc::driver::CudaSlice<T>,
l: &Layout,
) -> Result<()> {
match l.contiguous_offsets() {
Some((0, n)) if n == slice.len() => Ok(()),
_ => candle_core::bail!(
"AllReduce input is non-contiguous: layout={:?}, slice_len={}",
l,
slice.len()
),
}
}
let elem_count = l.shape().elem_count();
let dev = s.device().clone();
let out = match s.dtype() {
DType::BF16 => {
let src = s.as_cuda_slice::<bf16>()?;
require_contiguous(src, l)?;
let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }?;
self.comm
.all_reduce(src, &mut dst, &ReduceOp::Sum)
.map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce bf16: {e:?}")))?;
CudaStorage::wrap_cuda_slice(dst, dev)
}
DType::F16 => {
let src = s.as_cuda_slice::<f16>()?;
require_contiguous(src, l)?;
let mut dst = unsafe { dev.alloc::<f16>(elem_count) }?;
self.comm
.all_reduce(src, &mut dst, &ReduceOp::Sum)
.map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce f16: {e:?}")))?;
CudaStorage::wrap_cuda_slice(dst, dev)
}
DType::F32 => {
let src = s.as_cuda_slice::<f32>()?;
require_contiguous(src, l)?;
let mut dst = unsafe { dev.alloc::<f32>(elem_count) }?;
self.comm
.all_reduce(src, &mut dst, &ReduceOp::Sum)
.map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce f32: {e:?}")))?;
CudaStorage::wrap_cuda_slice(dst, dev)
}
dtype => candle_core::bail!(
"AllReduce: unsupported dtype {dtype:?}; TP path expects bf16/f16/f32"
),
};
Ok((out, l.shape().clone()))
}
}

View File

@@ -1,213 +0,0 @@
//! Direct safetensors readers for fused-region weight tensors.
//!
//! Qwen3-Next's `in_proj_qkv` and `conv1d` weights are *fused* —
//! three regions stored sequentially along dim 0 (`[key_q, key_k,
//! value]`). The per-rank shard for each region has unequal size
//! (`key_dim/ws` vs `value_dim/ws`), so candle's `ShardedSafeTensors`
//! built-in `Shard { dim, rank, world_size }` (uniform split) doesn't
//! map to the right slices.
//!
//! The previous approach loaded the full fused tensor onto the device,
//! `narrow`ed the three regions, and `Tensor::cat(...).contiguous()`'d
//! the per-rank slice. That left ~100 MB of transient device memory
//! per linear-attention layer — 48 layers × 100 MB = ~4.8 GB of
//! allocator pressure during load, enough to trigger fragmentation
//! OOM on tight-VRAM consumer GPUs.
//!
//! This module reads the three per-rank byte ranges *directly from
//! the safetensors mmap* (host-side), concatenates them into a single
//! contiguous byte buffer, and uploads as one device allocation. No
//! full-tensor device materialisation.
use anyhow::{Context, Result, bail};
use candle_core::safetensors::MmapedSafetensors;
use candle_core::{DType, Device, Tensor};
/// Read a 2D fused-QKV tensor `[conv_dim, hidden_size]` and return
/// this rank's per-region slice as a `[per_rank_conv_dim, hidden_size]`
/// device tensor.
///
/// `tensor_name` must be the fully-qualified safetensors key (e.g.
/// `"model.language_model.layers.5.linear_attn.in_proj_qkv.weight"`).
#[allow(clippy::too_many_arguments)]
pub fn load_fused_qkv_2d(
mmap: &MmapedSafetensors,
tensor_name: &str,
hidden_size: usize,
key_dim: usize,
value_dim: usize,
rank: u32,
world_size: u32,
target_dtype: DType,
device: &Device,
) -> Result<Tensor> {
let ws = world_size as usize;
let r = rank as usize;
if !key_dim.is_multiple_of(ws) || !value_dim.is_multiple_of(ws) {
bail!(
"fused qkv shard: key_dim ({key_dim}) and value_dim ({value_dim}) \
must each be divisible by world_size ({ws})"
);
}
let per_rank_key = key_dim / ws;
let per_rank_value = value_dim / ws;
let per_rank_conv_dim = per_rank_key * 2 + per_rank_value;
let view = mmap
.get(tensor_name)
.with_context(|| format!("mmap.get('{tensor_name}') for fused qkv 2D"))?;
let view_dtype: DType = view
.dtype()
.try_into()
.with_context(|| format!("safetensors dtype unsupported for '{tensor_name}'"))?;
let shape = view.shape();
if shape.len() != 2 {
bail!(
"fused qkv tensor '{tensor_name}' has shape {shape:?}, expected 2D \
[conv_dim, hidden_size]"
);
}
let conv_dim = key_dim * 2 + value_dim;
if shape[0] != conv_dim || shape[1] != hidden_size {
bail!(
"fused qkv tensor '{tensor_name}' shape {shape:?} \
doesn't match expected [{conv_dim}, {hidden_size}]"
);
}
let q_bytes = slice_dim0_bytes(&view, r * per_rank_key, per_rank_key, tensor_name, "q")?;
let k_bytes = slice_dim0_bytes(
&view,
key_dim + r * per_rank_key,
per_rank_key,
tensor_name,
"k",
)?;
let v_bytes = slice_dim0_bytes(
&view,
2 * key_dim + r * per_rank_value,
per_rank_value,
tensor_name,
"v",
)?;
let mut bytes = Vec::with_capacity(q_bytes.len() + k_bytes.len() + v_bytes.len());
bytes.extend_from_slice(&q_bytes);
bytes.extend_from_slice(&k_bytes);
bytes.extend_from_slice(&v_bytes);
let tensor = Tensor::from_raw_buffer(
&bytes,
view_dtype,
&[per_rank_conv_dim, hidden_size],
device,
)
.with_context(|| format!("Tensor::from_raw_buffer for per-rank fused qkv '{tensor_name}'"))?;
tensor
.to_dtype(target_dtype)
.with_context(|| format!("cast '{tensor_name}' to {target_dtype:?}"))
}
/// Read a 3D fused-QKV tensor `[conv_dim, 1, kernel_size]` (the
/// depthwise conv1d weight) and return this rank's per-region slice
/// as a `[per_rank_conv_dim, 1, kernel_size]` device tensor.
#[allow(clippy::too_many_arguments)]
pub fn load_fused_qkv_3d(
mmap: &MmapedSafetensors,
tensor_name: &str,
mid: usize,
kernel_size: usize,
key_dim: usize,
value_dim: usize,
rank: u32,
world_size: u32,
target_dtype: DType,
device: &Device,
) -> Result<Tensor> {
let ws = world_size as usize;
let r = rank as usize;
if !key_dim.is_multiple_of(ws) || !value_dim.is_multiple_of(ws) {
bail!(
"fused conv shard: key_dim ({key_dim}) and value_dim ({value_dim}) \
must each be divisible by world_size ({ws})"
);
}
let per_rank_key = key_dim / ws;
let per_rank_value = value_dim / ws;
let per_rank_conv_dim = per_rank_key * 2 + per_rank_value;
let view = mmap
.get(tensor_name)
.with_context(|| format!("mmap.get('{tensor_name}') for fused qkv 3D"))?;
let view_dtype: DType = view
.dtype()
.try_into()
.with_context(|| format!("safetensors dtype unsupported for '{tensor_name}'"))?;
let shape = view.shape();
if shape.len() != 3 {
bail!(
"fused conv tensor '{tensor_name}' has shape {shape:?}, expected 3D \
[conv_dim, mid, kernel_size]"
);
}
let conv_dim = key_dim * 2 + value_dim;
if shape[0] != conv_dim || shape[1] != mid || shape[2] != kernel_size {
bail!(
"fused conv tensor '{tensor_name}' shape {shape:?} \
doesn't match expected [{conv_dim}, {mid}, {kernel_size}]"
);
}
let q_bytes = slice_dim0_bytes(&view, r * per_rank_key, per_rank_key, tensor_name, "q")?;
let k_bytes = slice_dim0_bytes(
&view,
key_dim + r * per_rank_key,
per_rank_key,
tensor_name,
"k",
)?;
let v_bytes = slice_dim0_bytes(
&view,
2 * key_dim + r * per_rank_value,
per_rank_value,
tensor_name,
"v",
)?;
let mut bytes = Vec::with_capacity(q_bytes.len() + k_bytes.len() + v_bytes.len());
bytes.extend_from_slice(&q_bytes);
bytes.extend_from_slice(&k_bytes);
bytes.extend_from_slice(&v_bytes);
let tensor = Tensor::from_raw_buffer(
&bytes,
view_dtype,
&[per_rank_conv_dim, mid, kernel_size],
device,
)
.with_context(|| format!("Tensor::from_raw_buffer for per-rank fused conv '{tensor_name}'"))?;
tensor
.to_dtype(target_dtype)
.with_context(|| format!("cast '{tensor_name}' to {target_dtype:?}"))
}
/// Read `len` consecutive rows along dim 0 starting at `start` from
/// the safetensors view, returning the raw bytes. Wraps the same
/// `view.slice(start..stop)` machinery that candle's
/// `ShardedSafeTensors::get` uses internally.
fn slice_dim0_bytes(
view: &safetensors::tensor::TensorView<'_>,
start: usize,
len: usize,
tensor_name: &str,
region: &str,
) -> Result<Vec<u8>> {
use safetensors::slice::IndexOp;
let stop = start + len;
let iter = view.slice(start..stop).map_err(|e| {
anyhow::anyhow!("slice '{tensor_name}' region {region} ({start}..{stop}): {e:?}")
})?;
Ok(iter.into_iter().flatten().copied().collect())
}

View File

@@ -1,791 +0,0 @@
//! Tensor-parallel inference plumbing.
//!
//! The leader process (the neuron daemon proper) drives one
//! subprocess per non-zero NCCL rank — `tokio::process::Command` on
//! `/proc/self/exe --worker --rank N --tp-size N --cuda-device N` —
//! and talks to each over a newline-delimited JSON RPC channel on
//! the worker's stdin/stdout (see `rpc.rs`).
//!
//! Sub-staging:
//!
//! - **7a-i (this commit):** process lifecycle. `WorkerPool::spawn`
//! forks N workers; `ping` round-trips every worker to confirm
//! they're alive; `shutdown` cleanly drains and reaps. `Init` /
//! `NcclSanityCheck` are stubbed.
//! - **7a-ii:** real NCCL `Comm` setup via `Init`, sanity check via
//! `NcclSanityCheck`. CUDA-gated.
//! - **7b:** TP-aware Qwen3 inference dispatched through the pool.
//! - **7c:** crash detection, streaming SSE, graceful unload.
pub mod all_reduce;
pub mod fused_load;
pub mod nccl_state;
pub mod rpc;
pub mod tp_linear;
pub mod tp_qwen3;
pub mod tp_qwen3_5;
pub mod worker;
use anyhow::{Context, Result};
use std::path::{Path, PathBuf};
use std::process::Stdio;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, Lines};
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
use rpc::{WorkerRequest, WorkerResponse};
/// Leader-side handle for any TP-loaded model. The pool's
/// `load_dense_shard` dispatches on `config.json#/model_type` to build
/// the right variant; downstream callers (the harness's
/// `chat_completion_tp` path, `generate_step`, `clear_kv_cache`,
/// `unload_model`) all hold this enum and let the variant dispatch
/// determine the concrete forward.
///
/// Variants gated on `cuda` because the underlying TP models hold
/// `Arc<cudarc::nccl::Comm>` references — irrelevant on CPU builds.
#[cfg(feature = "cuda")]
pub enum TpLeaderModel {
Qwen3(tp_qwen3::TpQwen3ForCausalLM),
Qwen3_5(tp_qwen3_5::TpQwen3_5ForCausalLM),
}
#[cfg(feature = "cuda")]
impl TpLeaderModel {
pub fn forward(
&mut self,
input: &candle_core::Tensor,
offset: usize,
) -> candle_core::Result<candle_core::Tensor> {
match self {
TpLeaderModel::Qwen3(m) => m.forward(input, offset),
TpLeaderModel::Qwen3_5(m) => m.forward(input, offset),
}
}
pub fn clear_kv_cache(&mut self) {
match self {
TpLeaderModel::Qwen3(m) => m.clear_kv_cache(),
TpLeaderModel::Qwen3_5(m) => m.clear_kv_cache(),
}
}
pub fn device(&self) -> &candle_core::Device {
match self {
TpLeaderModel::Qwen3(m) => m.device(),
TpLeaderModel::Qwen3_5(m) => m.device(),
}
}
}
/// One worker subprocess plus its bidirectional stdio handles.
struct Worker {
rank: u32,
/// Captured so the leader can log "spawned rank N on device M" and
/// future stages can re-issue Init after a CUDA reset. Unused in
/// the Stage 7a-i RPC paths themselves.
#[allow(dead_code)]
cuda_device: u32,
child: Child,
stdin: ChildStdin,
stdout: Lines<BufReader<ChildStdout>>,
}
impl Worker {
/// Send a request and wait for the response. Used for sequenced
/// ops like `Ping` / `Shutdown` where the caller doesn't need to
/// overlap the worker's execution with the leader's.
async fn request(&mut self, req: &WorkerRequest) -> Result<WorkerResponse> {
self.send_only(req).await?;
self.recv_only().await
}
/// Write a request without awaiting its response. Pair with
/// `recv_only` from the caller when leader and worker need to do
/// work concurrently — e.g. during `Init`, where the leader
/// itself calls `Comm::from_rank` on rank 0 in parallel with the
/// workers, then collects `InitOk` after NCCL completes.
async fn send_only(&mut self, req: &WorkerRequest) -> Result<()> {
let mut line = serde_json::to_string(req).context("serialise WorkerRequest")?;
line.push('\n');
self.stdin
.write_all(line.as_bytes())
.await
.with_context(|| format!("write request to rank {}", self.rank))?;
self.stdin
.flush()
.await
.with_context(|| format!("flush stdin to rank {}", self.rank))?;
Ok(())
}
async fn recv_only(&mut self) -> Result<WorkerResponse> {
let reply = self
.stdout
.next_line()
.await
.with_context(|| format!("read reply from rank {}", self.rank))?
.ok_or_else(|| anyhow::anyhow!("rank {} stdout closed before reply", self.rank))?;
serde_json::from_str(&reply)
.with_context(|| format!("parse reply from rank {}: {reply:?}", self.rank))
}
}
/// Drain one response from every worker, classifying each via the
/// supplied checker. Always reads from every worker — even if some
/// fail — so the next call's recv doesn't pick up stale responses
/// from this one (pipe-poisoning was the cause of the
/// "ClearKvCache: expected KvCacheCleared, got GenerateStepOk" class
/// of bugs).
///
/// Returns a vector of `rank N: detail` strings for any worker that
/// errored, expected-mismatched, or failed to respond. Caller decides
/// how to combine these with the leader's outcome.
async fn drain_workers(
workers: &mut [Worker],
mut check: impl FnMut(WorkerResponse) -> std::result::Result<(), String>,
) -> Vec<String> {
let mut errs = Vec::new();
for w in workers {
match w.recv_only().await {
Ok(resp) => {
if let Err(detail) = check(resp) {
errs.push(format!("rank {} {detail}", w.rank));
}
}
Err(e) => errs.push(format!("rank {} recv: {e:#}", w.rank)),
}
}
errs
}
/// Combine a leader's `Result<Result<T>>` (the typical
/// `spawn_blocking → JoinHandle<Result<T>>` shape) with the worker
/// drain results into a single `Result<T>`. Leader failures take
/// precedence in the error message but worker errors get appended so
/// the operator sees both halves.
#[cfg(feature = "cuda")]
fn combine_leader_workers<T>(
leader: Result<Result<T>>,
worker_errors: Vec<String>,
op: &str,
) -> Result<T> {
match leader {
Ok(Ok(value)) => {
if worker_errors.is_empty() {
Ok(value)
} else {
anyhow::bail!(
"{op}: leader succeeded but workers failed: {}",
worker_errors.join("; ")
)
}
}
Ok(Err(e)) => {
if worker_errors.is_empty() {
Err(e.context(format!("{op}: leader forward failed")))
} else {
Err(e.context(format!(
"{op}: leader forward failed and workers also failed: {}",
worker_errors.join("; ")
)))
}
}
Err(panic_err) => {
if worker_errors.is_empty() {
Err(panic_err)
} else {
Err(panic_err.context(format!(
"{op}: leader task panicked and workers failed: {}",
worker_errors.join("; ")
)))
}
}
}
}
/// A live pool of worker subprocesses. Owns the `Child` handles so
/// dropping the pool kills the children; explicit `shutdown()` is
/// the graceful path.
pub struct WorkerPool {
world_size: u32,
workers: Vec<Worker>,
/// Path to the neuron binary used to launch workers.
#[allow(dead_code)]
exe: PathBuf,
/// Leader's own NCCL rank-0 state. Defaults to empty; populated by
/// `init_nccl()`. Held here so the leader can participate in
/// collectives (rank 0) without spawning a fourth subprocess.
leader_nccl: nccl_state::NcclState,
}
impl WorkerPool {
/// Spawn `world_size - 1` worker subprocesses. Rank 0 is the
/// leader (in-process) and is *not* spawned here — the leader
/// holds rank 0's NCCL Comm and shard in its own address space.
///
/// `binary` is the path to the neuron executable to run for each
/// worker (production passes `/proc/self/exe`; tests pass the
/// sibling-binary path from `env!("CARGO_BIN_EXE_neuron")`).
/// `cuda_devices` is one entry per rank including rank 0. Worker
/// `i` (rank `i`) gets `cuda_devices[i]` as its `--cuda-device`.
pub async fn spawn(binary: &Path, world_size: u32, cuda_devices: &[u32]) -> Result<Self> {
if world_size < 2 {
anyhow::bail!(
"WorkerPool::spawn called with world_size={world_size}; \
use the single-process path for world_size < 2"
);
}
if cuda_devices.len() as u32 != world_size {
anyhow::bail!(
"expected {world_size} cuda_devices entries, got {}",
cuda_devices.len()
);
}
let exe = binary.to_path_buf();
let mut workers = Vec::with_capacity(world_size as usize - 1);
// Rank 0 stays in-process. Spawn ranks 1..world_size.
for rank in 1..world_size {
let cuda_device = cuda_devices[rank as usize];
let mut cmd = Command::new(&exe);
cmd.arg("--worker")
.arg("--rank")
.arg(rank.to_string())
.arg("--tp-size")
.arg(world_size.to_string())
.arg("--cuda-device")
.arg(cuda_device.to_string())
.stdin(Stdio::piped())
.stdout(Stdio::piped())
// Inherit stderr so worker tracing surfaces alongside
// the leader's journalctl stream.
.stderr(Stdio::inherit())
.kill_on_drop(true);
let mut child = cmd
.spawn()
.with_context(|| format!("spawn worker rank {rank}"))?;
let stdin = child
.stdin
.take()
.ok_or_else(|| anyhow::anyhow!("rank {rank}: no stdin handle"))?;
let stdout = child
.stdout
.take()
.ok_or_else(|| anyhow::anyhow!("rank {rank}: no stdout handle"))?;
let stdout = BufReader::new(stdout).lines();
workers.push(Worker {
rank,
cuda_device,
child,
stdin,
stdout,
});
tracing::info!(rank, cuda_device, "spawned tp worker");
}
Ok(Self {
world_size,
workers,
exe,
leader_nccl: nccl_state::NcclState::new(),
})
}
/// Establish the NCCL communicator across the leader (rank 0) and
/// every worker subprocess. Rendezvous is via a freshly-generated
/// `Id` broadcast over the RPC stream; the actual handshake blocks
/// inside `Comm::from_rank` until all `world_size` ranks check in.
///
/// `leader_cuda_device` is the CUDA device the leader binds rank 0
/// to — typically the first entry of the `cuda_devices` slice
/// originally passed to `spawn()`.
///
/// On the non-cuda build this immediately fails because the leader
/// can't generate an `Id` without libnccl. The same call works in
/// the worker path (returning a no-cuda error response) so the
/// failure surface is uniform.
pub async fn init_nccl(&mut self, leader_cuda_device: u32) -> Result<()> {
let comm_id = nccl_state::generate_comm_id_hex()
.map_err(|m| anyhow::anyhow!("generate NCCL id: {m}"))?;
// 1. Write Init to every worker's stdin without awaiting the
// response. Workers will parse and call Comm::from_rank
// concurrently with the leader below.
for w in &mut self.workers {
let req = WorkerRequest::Init {
comm_id: comm_id.clone(),
};
w.send_only(&req).await?;
}
// 2. Leader rank 0 calls Comm::from_rank on its own device.
// Runs on spawn_blocking because NCCL's init blocks until
// every rank has called in — that's exactly the workers
// above. The leader's NcclState is moved through the
// blocking task and returned to the pool.
let leader_cfg = worker::WorkerConfig {
rank: 0,
world_size: self.world_size,
cuda_device: leader_cuda_device,
};
let comm_id_for_leader = comm_id.clone();
// Swap out the leader's NcclState into a fresh empty one so we
// can move it into spawn_blocking; restore after the task
// returns. (NcclState isn't Clone — it owns a real NCCL Comm.)
let mut leader_state = std::mem::take(&mut self.leader_nccl);
let (returned_state, leader_resp) = tokio::task::spawn_blocking(move || {
let resp = leader_state.init(leader_cfg, &comm_id_for_leader);
(leader_state, resp)
})
.await
.context("leader NCCL init task panicked")?;
self.leader_nccl = returned_state;
match leader_resp {
rpc::WorkerResponse::InitOk => {}
rpc::WorkerResponse::Error { kind, message } => {
anyhow::bail!("leader rank 0 init failed [{kind}]: {message}");
}
other => anyhow::bail!("leader rank 0 init: unexpected {other:?}"),
}
// 3. Read InitOk from each worker. By now every worker has
// completed its Comm::from_rank call (NCCL released them
// when the leader joined the handshake) and is writing its
// response.
for w in &mut self.workers {
let resp = w.recv_only().await?;
match &resp {
rpc::WorkerResponse::InitOk => {}
rpc::WorkerResponse::Error { kind, message } => {
anyhow::bail!("worker rank {} init failed [{kind}]: {message}", w.rank);
}
other => anyhow::bail!(
"worker rank {} init: expected InitOk, got {other:?}",
w.rank
),
}
}
tracing::info!(
world_size = self.world_size,
"NCCL communicator established across all ranks"
);
Ok(())
}
/// Validate the NCCL communicator: every rank `all_reduce`s a
/// sentinel `1u32` with `ReduceOp::Sum`; the expected total is
/// `world_size`. Confirms the handshake is live, not just
/// configured.
///
/// Must be called after `init_nccl()`; before that the leader has
/// no Comm and the workers reply with `nccl_not_initialised`.
pub async fn nccl_sanity_check(&mut self) -> Result<()> {
// 1. Trigger the all_reduce on every worker (write-only).
for w in &mut self.workers {
w.send_only(&WorkerRequest::NcclSanityCheck).await?;
}
// 2. Leader's own all_reduce, in spawn_blocking. NCCL operations
// block until every rank participates.
let mut leader_state = std::mem::take(&mut self.leader_nccl);
let (returned_state, leader_resp) = tokio::task::spawn_blocking(move || {
let resp = leader_state.sanity_check();
(leader_state, resp)
})
.await
.context("leader NCCL sanity task panicked")?;
self.leader_nccl = returned_state;
let expected = self.world_size;
let leader_sum = match leader_resp {
rpc::WorkerResponse::NcclSanityResult { observed_sum } => observed_sum,
rpc::WorkerResponse::Error { kind, message } => {
anyhow::bail!("leader rank 0 sanity failed [{kind}]: {message}");
}
other => anyhow::bail!("leader rank 0 sanity: unexpected {other:?}"),
};
if leader_sum != expected {
anyhow::bail!("leader observed_sum={leader_sum}, expected {expected}");
}
// 3. Read sanity result from each worker. All must match
// world_size — anything else means the collective didn't
// complete consistently across ranks.
for w in &mut self.workers {
let resp = w.recv_only().await?;
match resp {
rpc::WorkerResponse::NcclSanityResult { observed_sum }
if observed_sum == expected => {}
rpc::WorkerResponse::NcclSanityResult { observed_sum } => {
anyhow::bail!(
"worker rank {} observed_sum={observed_sum}, expected {expected}",
w.rank
);
}
rpc::WorkerResponse::Error { kind, message } => {
anyhow::bail!("worker rank {} sanity failed [{kind}]: {message}", w.rank);
}
other => anyhow::bail!("worker rank {} sanity: unexpected {other:?}", w.rank),
}
}
tracing::info!(
world_size = expected,
"NCCL sanity check OK across all ranks"
);
Ok(())
}
/// Ping every worker and return their Pong payloads in rank order.
/// Useful right after `spawn` to confirm the lifecycle plumbing is
/// intact before kicking off any heavier work.
pub async fn ping_all(&mut self) -> Result<Vec<WorkerResponse>> {
let mut out = Vec::with_capacity(self.workers.len());
for w in &mut self.workers {
let resp = w.request(&WorkerRequest::Ping).await?;
match &resp {
WorkerResponse::Pong { rank, .. } if *rank == w.rank => {}
WorkerResponse::Pong { rank, .. } => {
anyhow::bail!("rank mismatch: expected {}, got {rank}", w.rank);
}
other => anyhow::bail!("expected Pong from rank {}, got {other:?}", w.rank),
}
out.push(resp);
}
Ok(out)
}
/// Load this rank's shard of a dense Qwen3 model on every rank.
///
/// The leader builds rank 0's `TpQwen3ForCausalLM` directly into
/// the returned `Arc<Mutex<_>>` — workers build their rank-local
/// shards in their own address spaces and confirm via
/// `LoadDenseShardOk`. All ranks see the same `safetensors_paths`;
/// `ShardedVarBuilder` slices each tensor by rank at materialisation
/// time, so the per-rank VRAM footprint is roughly `1/world_size`
/// of the full model (plus the replicated embedding/norm/lm_head).
///
/// `leader_device` is the candle `Device` the leader's shard lives
/// on — typically `Device::new_cuda(leader_cuda_device)` matching
/// the same index passed to `init_nccl`. `dtype` is the on-device
/// element type; bf16 is the canonical Qwen3 distribution dtype.
///
/// `init_nccl` must have completed first. Bails if the leader's
/// NCCL comm isn't set up yet.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub async fn load_dense_shard(
&mut self,
model_id: &str,
config_json: &str,
safetensors_paths: &[std::path::PathBuf],
leader_device: &candle_core::Device,
dtype: candle_core::DType,
quant: Option<String>,
) -> Result<std::sync::Arc<tokio::sync::Mutex<TpLeaderModel>>> {
use candle_nn::var_builder::ShardedSafeTensors;
use std::sync::Arc;
use tokio::sync::Mutex;
// Wrap the comm in SendComm immediately so it stays Send across
// the await points in this method — bare Arc<Comm> would
// poison the async fn's Send bound (Comm's raw NCCL pointer is
// !Send). The wrapper's safety contract is satisfied by the
// pool's outer Mutex serialising callers + the spawn_blocking
// thread being the only place ops are issued.
let leader_comm =
nccl_state::SendComm(self.leader_nccl.comm().ok_or_else(|| {
anyhow::anyhow!("leader NCCL not initialised; call init_nccl first")
})?);
let world_size = self.world_size;
let safetensors_str: Vec<String> = safetensors_paths
.iter()
.map(|p| p.to_string_lossy().into_owned())
.collect();
// 1. Fan out the LoadDenseShard request to every worker without
// awaiting their replies — they'll build their shards in
// parallel with the leader below.
for w in &mut self.workers {
w.send_only(&WorkerRequest::LoadDenseShard {
model_id: model_id.to_string(),
config_json: config_json.to_string(),
safetensors_paths: safetensors_str.clone(),
quant: quant.clone(),
})
.await?;
}
// 2. Build rank 0's shard on the leader. Dispatch on model_type
// — for `qwen3` we build a `TpQwen3ForCausalLM`, for
// `qwen3_5` (Qwen3-Next, Qwen3.6's architecture) we build
// `TpQwen3_5ForCausalLM`. Both end up wrapped in the
// `TpLeaderModel` enum so downstream callers don't care.
let model_type = serde_json::from_str::<serde_json::Value>(config_json)
.ok()
.as_ref()
.and_then(|v| v.get("model_type"))
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let paths_for_leader: Vec<std::path::PathBuf> = safetensors_paths.to_vec();
let device_for_leader = leader_device.clone();
let comm_for_leader = leader_comm;
let model_id_for_log = model_id.to_string();
let config_json_for_leader = config_json.to_string();
let quant_for_leader = quant.clone();
let leader_model = tokio::task::spawn_blocking(move || -> Result<TpLeaderModel> {
// SAFETY: same invariant as the single-GPU dense path —
// the HF cache files are treated as immutable while the
// mmap is held.
let vb = unsafe {
ShardedSafeTensors::var_builder(&paths_for_leader, dtype, &device_for_leader)
.context("build ShardedVarBuilder over safetensors")?
};
// SAFETY: as above — the HF cache files are immutable.
let mmap = unsafe {
candle_core::safetensors::MmapedSafetensors::multi(&paths_for_leader)
.context("build MmapedSafetensors for leader load")?
};
let comm = comm_for_leader.into_inner();
let loaded = match model_type.as_str() {
"qwen3" => {
let cfg: super::tp::tp_qwen3::Config = serde_json::from_str(&config_json_for_leader)
.context("parse Qwen3 Config JSON for leader load")?;
TpLeaderModel::Qwen3(super::tp::tp_qwen3::TpQwen3ForCausalLM::load(
&cfg, &vb, 0, world_size, comm,
)?)
}
"qwen3_5" => {
let cfg: super::tp::tp_qwen3_5::Config =
serde_json::from_str(&config_json_for_leader)
.context("parse Qwen3-Next Config JSON for leader load")?;
let quant_dtype =
super::tp::worker::parse_quant_string(quant_for_leader.as_deref())?;
TpLeaderModel::Qwen3_5(super::tp::tp_qwen3_5::TpQwen3_5ForCausalLM::load(
cfg,
&vb,
&mmap,
0,
world_size,
comm,
quant_dtype,
)?)
}
other => anyhow::bail!(
"TP dispatch: unsupported model_type '{other}' on leader (supported: qwen3, qwen3_5)"
),
};
tracing::info!(rank = 0, model = %model_id_for_log, model_type = %model_type, "loaded TP shard (leader)");
Ok(loaded)
})
.await
.context("leader load task panicked")??;
// 3. Collect worker confirmations. Anything other than
// LoadDenseShardOk aborts the whole load — the leader's
// already-loaded shard drops when this fn returns Err.
for w in &mut self.workers {
let resp = w.recv_only().await?;
match resp {
WorkerResponse::LoadDenseShardOk => {}
WorkerResponse::Error { kind, message } => {
anyhow::bail!("worker rank {} LoadDenseShard [{kind}]: {message}", w.rank)
}
other => anyhow::bail!(
"worker rank {} LoadDenseShard: expected LoadDenseShardOk, got {other:?}",
w.rank
),
}
}
Ok(Arc::new(Mutex::new(leader_model)))
}
/// Run one forward step across every rank. The leader's forward
/// returns the last-position logits as a candle Tensor on the
/// leader's device; the caller does sampling out-of-band. Workers
/// run their own forwards (the AllReduce inside row-parallel layers
/// is what lets the leader's collective complete) and reply with
/// `GenerateStepOk` — they do not ship logits over the wire.
///
/// `tokens` is the input for this step (prompt for prefill, the
/// previously-sampled token for decode). `offset` is the KV-cache
/// position before this step.
#[cfg(feature = "cuda")]
pub async fn generate_step(
&mut self,
model_id: &str,
leader_model: std::sync::Arc<tokio::sync::Mutex<TpLeaderModel>>,
tokens: Vec<u32>,
offset: usize,
) -> Result<candle_core::Tensor> {
let step_start = std::time::Instant::now();
let tokens_len = tokens.len();
tracing::debug!(
model = %model_id,
tokens = tokens_len,
offset,
"WorkerPool::generate_step: fan-out"
);
// 1. Fan-out to workers.
for w in &mut self.workers {
w.send_only(&WorkerRequest::GenerateStep {
model_id: model_id.to_string(),
tokens: tokens.clone(),
offset,
})
.await?;
}
// 2. Leader's forward in spawn_blocking. The AllReduce CustomOps
// inside the row-parallel layers block until every worker's
// forward issues the matching collective.
let leader_start = std::time::Instant::now();
let leader_result = tokio::task::spawn_blocking(move || -> Result<candle_core::Tensor> {
let mut model = leader_model.blocking_lock();
let device = model.device().clone();
let input = candle_core::Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
// ForCausalLM::forward returns [B, 1, V] — squeeze both
// leading dims to the rank-1 vocab logits the sampler wants.
let logits = model.forward(&input, offset)?.squeeze(0)?.squeeze(0)?;
Ok(logits)
})
.await
.context("leader forward task panicked");
let leader_ok = matches!(leader_result, Ok(Ok(_)));
tracing::debug!(
model = %model_id,
tokens = tokens_len,
leader_ms = leader_start.elapsed().as_millis(),
leader_ok,
"WorkerPool::generate_step: leader forward returned"
);
// 3. ALWAYS drain worker responses, regardless of whether the
// leader succeeded. Skipping this on the leader's error
// path leaves stale GenerateStepOk replies in the worker
// pipes that poison the NEXT request's recv (was seeing
// "ClearKvCache: expected KvCacheCleared, got
// GenerateStepOk" the call after any forward-time failure).
let drain_start = std::time::Instant::now();
let worker_errors = drain_workers(&mut self.workers, |r| match r {
WorkerResponse::GenerateStepOk => Ok(()),
WorkerResponse::Error { kind, message } => Err(format!("[{kind}]: {message}")),
other => Err(format!("expected GenerateStepOk, got {other:?}")),
})
.await;
tracing::debug!(
model = %model_id,
drain_ms = drain_start.elapsed().as_millis(),
errors = worker_errors.len(),
total_ms = step_start.elapsed().as_millis(),
"WorkerPool::generate_step: workers drained"
);
combine_leader_workers(leader_result, worker_errors, "GenerateStep")
}
/// Reset the KV cache for `model_id` on every rank. Called at the
/// start of every inference so a fresh request doesn't attend over
/// the previous one's tokens.
pub async fn clear_kv_cache(
&mut self,
model_id: &str,
#[cfg(feature = "cuda")] leader_model: std::sync::Arc<tokio::sync::Mutex<TpLeaderModel>>,
) -> Result<()> {
let start = std::time::Instant::now();
tracing::debug!(model = %model_id, "WorkerPool::clear_kv_cache: fan-out");
for w in &mut self.workers {
w.send_only(&WorkerRequest::ClearKvCache {
model_id: model_id.to_string(),
})
.await?;
}
#[cfg(feature = "cuda")]
{
let mut m = leader_model.lock().await;
m.clear_kv_cache();
}
// Drain workers — same rationale as `generate_step`. The
// leader's clear_kv_cache is in-process and infallible, but we
// still always drain so an error on one worker doesn't leave
// pending responses for the others.
let worker_errors = drain_workers(&mut self.workers, |r| match r {
WorkerResponse::KvCacheCleared => Ok(()),
WorkerResponse::Error { kind, message } => Err(format!("[{kind}]: {message}")),
other => Err(format!("expected KvCacheCleared, got {other:?}")),
})
.await;
tracing::debug!(
model = %model_id,
elapsed_ms = start.elapsed().as_millis(),
errors = worker_errors.len(),
"WorkerPool::clear_kv_cache: workers drained"
);
if !worker_errors.is_empty() {
anyhow::bail!("ClearKvCache: {}", worker_errors.join("; "));
}
Ok(())
}
/// Drop this model's shards on every rank. The leader's shard is
/// expected to have been dropped by the caller (its `Arc` was held
/// in the TpLoadedModel and goes away when that's removed).
pub async fn unload_model(&mut self, model_id: &str) -> Result<()> {
for w in &mut self.workers {
w.send_only(&WorkerRequest::UnloadModel {
model_id: model_id.to_string(),
})
.await?;
}
for w in &mut self.workers {
let resp = w.recv_only().await?;
match resp {
WorkerResponse::Unloaded => {}
WorkerResponse::Error { kind, message } => {
anyhow::bail!("worker rank {} UnloadModel [{kind}]: {message}", w.rank)
}
other => anyhow::bail!(
"worker rank {} UnloadModel: expected Unloaded, got {other:?}",
w.rank
),
}
}
Ok(())
}
/// Send `Shutdown` to every worker, await each `Bye`, and reap the
/// children. Best-effort — individual worker failures are logged
/// but don't abort the rest of the sweep.
pub async fn shutdown(mut self) -> Result<()> {
for w in &mut self.workers {
match w.request(&WorkerRequest::Shutdown).await {
Ok(WorkerResponse::Bye) => {}
Ok(other) => tracing::warn!(
rank = w.rank,
response = ?other,
"expected Bye on shutdown"
),
Err(e) => tracing::warn!(rank = w.rank, error = %e, "shutdown request failed"),
}
}
for w in &mut self.workers {
match w.child.wait().await {
Ok(status) => tracing::info!(rank = w.rank, %status, "worker exited"),
Err(e) => tracing::warn!(rank = w.rank, error = %e, "wait on worker failed"),
}
}
Ok(())
}
pub fn world_size(&self) -> u32 {
self.world_size
}
pub fn binary_path(&self) -> &PathBuf {
&self.exe
}
}

View File

@@ -1,293 +0,0 @@
//! NCCL state held by both the worker process and the leader's pool.
//!
//! Split into its own module so the worker (`tp/worker.rs`) and the
//! leader (`tp/mod.rs`) share the same hex-encoding/decoding code and
//! the same shape of `Option<Comm>` state machine.
//!
//! When the `cuda` feature is off, `NcclState` is a zero-sized
//! placeholder that returns `Error{kind="cuda_feature_not_enabled"}`
//! from every operation. When it's on, the same struct holds the
//! actual `cudarc::nccl::Comm`.
use super::rpc::WorkerResponse;
use super::worker::WorkerConfig;
/// Encode bytes as lowercase hex. Used for ferrying NCCL `Id::internal()`
/// across the leader→worker RPC boundary inside a JSON string.
pub fn encode_hex(bytes: &[u8]) -> String {
let mut out = String::with_capacity(bytes.len() * 2);
for b in bytes {
use std::fmt::Write;
let _ = write!(out, "{b:02x}");
}
out
}
/// Decode lowercase-or-uppercase hex into bytes. Errors on odd length
/// or non-hex characters; the caller bubbles those up via the RPC's
/// `Error{kind="bad_request"}` variant.
pub fn decode_hex(s: &str) -> Result<Vec<u8>, String> {
if !s.len().is_multiple_of(2) {
return Err(format!("hex string has odd length {}", s.len()));
}
(0..s.len())
.step_by(2)
.map(|i| {
u8::from_str_radix(&s[i..i + 2], 16).map_err(|e| format!("bad hex byte at {i}: {e}"))
})
.collect()
}
#[cfg(not(feature = "cuda"))]
pub struct NcclState;
#[cfg(not(feature = "cuda"))]
impl Default for NcclState {
fn default() -> Self {
Self::new()
}
}
#[cfg(not(feature = "cuda"))]
impl NcclState {
pub fn new() -> Self {
Self
}
pub fn init(&mut self, _cfg: WorkerConfig, _comm_id_hex: &str) -> WorkerResponse {
WorkerResponse::Error {
kind: "cuda_feature_not_enabled".into(),
message: "this neuron binary was built without --features cuda; \
NCCL Init requires CUDA"
.into(),
}
}
pub fn sanity_check(&mut self) -> WorkerResponse {
WorkerResponse::Error {
kind: "cuda_feature_not_enabled".into(),
message: "NCCL sanity check requires --features cuda".into(),
}
}
}
#[cfg(feature = "cuda")]
mod cuda_impl {
use super::*;
use cudarc::driver::CudaContext;
use cudarc::nccl::{Comm, Id, ReduceOp};
use std::sync::Arc;
/// Number of bytes in NCCL's unique-id type; matches `Id::internal()`'s
/// `[c_char; 128]`. Wire-encoded as 256 lowercase hex chars.
const NCCL_ID_BYTES: usize = 128;
pub struct NcclState {
/// Wrapped in `Arc` so we can hand a clone to `TpQwen3ForCausalLM`
/// at load time (every row-parallel layer needs a reference to
/// run its trailing `AllReduce`). The `Arc` is the single source
/// of truth for the comm's lifetime — when the pool drops and
/// every layer that captured a clone drops, NCCL releases the
/// underlying `ncclComm_t`.
comm: Option<Arc<Comm>>,
/// Held alongside the Comm so the device isn't dropped
/// underneath the NCCL handle.
#[allow(dead_code)]
ctx: Option<Arc<CudaContext>>,
}
impl Default for NcclState {
fn default() -> Self {
Self::new()
}
}
impl NcclState {
pub fn new() -> Self {
Self {
comm: None,
ctx: None,
}
}
/// Clone the comm out as an `Arc` so callers (the leader-side
/// `TpQwen3ForCausalLM::load`, or the worker's own model load)
/// can hold a reference for the lifetime of the model. Returns
/// `None` before `init` has run.
pub fn comm(&self) -> Option<Arc<Comm>> {
self.comm.clone()
}
}
/// `Arc<Comm>` doesn't impl `Send` because `Comm` wraps a raw
/// `ncclComm_t` pointer. The NCCL contract is "operations against a
/// given comm must be serialised", not "the handle must stay on the
/// thread that created it" — so it's safe to move an `Arc<Comm>`
/// across threads as long as no concurrent ops are issued. The
/// pool's outer Mutex serialises us into `spawn_blocking`, so this
/// wrapper at the move boundary is the only thing missing.
///
/// `Sync` is also marked safe because the `Arc<Comm>` clones held
/// by the row-parallel layers are only used from the
/// `spawn_blocking` thread driving the forward pass; concurrent
/// access from another thread would still be a bug.
pub struct SendComm(pub Arc<Comm>);
// SAFETY: see the doc-comment above; the invariant is enforced at
// the call site (pool Mutex + single spawn_blocking thread), not at
// the type level.
unsafe impl Send for SendComm {}
unsafe impl Sync for SendComm {}
impl SendComm {
pub fn into_inner(self) -> Arc<Comm> {
self.0
}
}
// SAFETY: `cudarc::nccl::Comm` contains a raw `ncclComm_t` pointer
// (libnccl-allocated state). NCCL requires that operations against
// one Comm be issued one at a time; we serialise access by storing
// NcclState behind a Mutex in `WorkerPool`. The Comm itself is
// move-safe — NCCL doesn't track the calling OS thread, only the
// stream the operations are dispatched against.
unsafe impl Send for NcclState {}
unsafe impl Sync for NcclState {}
/// Generate a fresh NCCL `Id` and return it hex-encoded. Used by
/// the leader to mint the shared communicator id which is then
/// broadcast to every worker via the RPC `Init` message.
pub fn generate_comm_id_hex() -> Result<String, String> {
// NcclError lacks a Display impl in cudarc 0.19.x — surface
// via Debug throughout this module.
let id = Id::new().map_err(|e| format!("Id::new(): {e:?}"))?;
let bytes_u8: [u8; NCCL_ID_BYTES] = std::array::from_fn(|i| id.internal()[i] as u8);
Ok(encode_hex(&bytes_u8))
}
impl NcclState {
pub fn init(&mut self, cfg: WorkerConfig, comm_id_hex: &str) -> WorkerResponse {
match try_init(self, cfg, comm_id_hex) {
Ok(()) => WorkerResponse::InitOk,
Err(msg) => WorkerResponse::Error {
kind: "nccl_init_failed".into(),
message: msg,
},
}
}
pub fn sanity_check(&mut self) -> WorkerResponse {
let Some(comm) = self.comm.as_ref() else {
return WorkerResponse::Error {
kind: "nccl_not_initialised".into(),
message: "sanity_check requires Init to have completed first".into(),
};
};
match try_sanity_check(comm.as_ref()) {
Ok(sum) => WorkerResponse::NcclSanityResult { observed_sum: sum },
Err(msg) => WorkerResponse::Error {
kind: "nccl_sanity_failed".into(),
message: msg,
},
}
}
}
fn try_init(state: &mut NcclState, cfg: WorkerConfig, comm_id_hex: &str) -> Result<(), String> {
let bytes = decode_hex(comm_id_hex)?;
if bytes.len() != NCCL_ID_BYTES {
return Err(format!(
"comm_id is {} bytes, expected {NCCL_ID_BYTES}",
bytes.len()
));
}
let id_bytes: [std::ffi::c_char; NCCL_ID_BYTES] =
std::array::from_fn(|i| bytes[i] as std::ffi::c_char);
let id = Id::uninit(id_bytes);
let ctx = CudaContext::new(cfg.cuda_device as usize)
.map_err(|e| format!("CudaContext::new({}) failed: {e}", cfg.cuda_device))?;
let stream = ctx.default_stream();
let comm = Comm::from_rank(stream, cfg.rank as usize, cfg.world_size as usize, id)
.map_err(|e| {
format!(
"Comm::from_rank(rank={}, world={}) failed: {e:?}",
cfg.rank, cfg.world_size
)
})?;
state.ctx = Some(ctx);
// `Comm` is !Send + !Sync at the type level because it wraps a
// raw `ncclComm_t`. The `Arc` is fine in practice — we
// serialise operations through the pool's outer Mutex and the
// SendComm wrapper at thread-crossing boundaries enforces this
// at every move site. clippy's `arc_with_non_send_sync` lint
// can't see that invariant; allow once at the canonical
// construction site.
#[allow(clippy::arc_with_non_send_sync)]
{
state.comm = Some(Arc::new(comm));
}
Ok(())
}
fn try_sanity_check(comm: &Comm) -> Result<u32, String> {
let stream = comm.stream().clone();
let input = stream
.clone_htod(&[1u32])
.map_err(|e| format!("htod sentinel: {e}"))?;
let mut output = stream
.alloc_zeros::<u32>(1)
.map_err(|e| format!("alloc output: {e}"))?;
// cudarc::nccl::NcclError doesn't impl Display in 0.19.x —
// surface via Debug so we still see the variant + ncclResult
// code instead of a generic "{e}" failure.
comm.all_reduce(&input, &mut output, &ReduceOp::Sum)
.map_err(|e| format!("all_reduce: {e:?}"))?;
let result = stream
.clone_dtoh(&output)
.map_err(|e| format!("dtoh result: {e}"))?;
Ok(result[0])
}
}
#[cfg(feature = "cuda")]
pub use cuda_impl::{NcclState, SendComm, generate_comm_id_hex};
/// Non-cuda stub for the leader: returns a clear marker error rather
/// than letting `init_nccl` succeed vacuously.
#[cfg(not(feature = "cuda"))]
pub fn generate_comm_id_hex() -> Result<String, String> {
Err("cuda_feature_not_enabled: build with --features cuda".into())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn hex_roundtrip() {
let original: Vec<u8> = (0u8..=255).collect();
let encoded = encode_hex(&original);
assert_eq!(encoded.len(), 512);
let decoded = decode_hex(&encoded).expect("decode");
assert_eq!(decoded, original);
}
#[test]
fn hex_decode_rejects_odd_length() {
assert!(decode_hex("a").is_err());
assert!(decode_hex("abc").is_err());
}
#[test]
fn hex_decode_rejects_non_hex() {
assert!(decode_hex("zz").is_err());
assert!(decode_hex("ab_d").is_err());
}
#[test]
fn hex_encode_is_lowercase_padded() {
assert_eq!(encode_hex(&[0x0a, 0xff]), "0aff");
}
}

View File

@@ -1,257 +0,0 @@
//! Wire protocol between the neuron leader process and its
//! `--worker` subprocesses.
//!
//! Every frame is one newline-delimited JSON object on the worker's
//! stdin (request) or stdout (response). Both directions are tagged
//! sum types from the start so new ops in Stage 7b/7c slot in without
//! breaking compatibility — no "14 message types and a version field"
//! drift later. Adding a new variant is the canonical way to evolve
//! the protocol; existing peers that don't recognise an op return
//! `WorkerResponse::Error { kind: "unknown_op", .. }`.
//!
//! The serialised shape uses `tag = "op"` so a request looks like:
//! {"op":"ping"}
//! {"op":"init","comm_id":"a1b2..."}
//! and a response:
//! {"op":"pong","rank":0,"world_size":2,"cuda_device":0}
//! {"op":"error","kind":"nccl_init_failed","message":"..."}
use serde::{Deserialize, Serialize};
/// Leader → worker. Worker handles one at a time; replies with exactly
/// one `WorkerResponse` per request.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "op", rename_all = "snake_case")]
pub enum WorkerRequest {
/// Liveness probe. Worker replies with `Pong` containing its own
/// identity. Used by the leader to confirm the subprocess is up
/// and ready before kicking off any heavier work.
Ping,
/// One-shot NCCL communicator setup. The leader generates the
/// `comm_id` once (rank 0 of NCCL), broadcasts it to every worker
/// via this message, then every rank (leader included) calls
/// `Comm::from_rank` with the same id — NCCL blocks until all
/// `world_size` ranks check in. The hex-encoded bytes are the
/// canonical `cudarc::nccl::Id::internal()` content.
Init {
/// Hex-encoded NCCL id bytes (128 bytes → 256 hex chars).
comm_id: String,
},
/// Sanity check: after Init, every rank runs an `all_reduce` over
/// a sentinel value (`1u32`). The expected sum is `world_size`.
/// Worker replies with the observed value so the leader can verify
/// the NCCL handshake is genuinely live, not just configured.
NcclSanityCheck,
/// Load this rank's shard of a dense Qwen3 model from mmaped
/// safetensors. The same `safetensors_paths` list is sent to every
/// rank — the ShardedVarBuilder reads only the rank-local slice of
/// each tensor at materialisation time, so the worker's VRAM
/// footprint is `1 / world_size` of the full model (plus replicated
/// embedding/norm/lm_head).
LoadDenseShard {
/// Caller-supplied id for later `GenerateStep` / `UnloadModel`
/// lookups. Typically the HF model id verbatim.
model_id: String,
/// JSON-serialised `candle_transformers::models::qwen3::Config`
/// — the same blob the leader parsed from the HF cache's
/// `config.json`. Threaded through verbatim so the worker uses
/// identical hyperparameters.
config_json: String,
/// Absolute paths the worker should mmap. The same set on every
/// rank; ShardedVarBuilder slices into them per rank.
safetensors_paths: Vec<String>,
/// Optional in-situ quantization dtype (e.g. "q5k", "q8_0",
/// "q6k"). When set, each linear-layer weight is quantized
/// at load time to the named ggml format — saves ~3-5x vs
/// bf16/f16 at the cost of some accuracy. `None` keeps the
/// weights in the on-disk dtype (typically bf16).
#[serde(default)]
quant: Option<String>,
},
/// Run one forward step on this rank's loaded model. The worker
/// reaches into its NCCL Comm for the row-parallel `AllReduce`s
/// inside the model — and so blocks on every other rank issuing the
/// same op. The leader does *not* receive logits back over RPC; it
/// runs its own rank-0 forward in parallel and uses its own logits
/// for sampling.
GenerateStep {
model_id: String,
/// Input token ids for this step. For prefill, the whole prompt;
/// for decode, a single token. Identical on every rank.
tokens: Vec<u32>,
/// KV cache offset (count of tokens already in the cache before
/// this step).
offset: usize,
},
/// Reset the KV cache for this model on this rank. Sent at the
/// start of every inference so a fresh request doesn't accidentally
/// attend over the previous one's tokens.
ClearKvCache { model_id: String },
/// Drop this rank's shard for the given model. Releases the VRAM
/// the shard's weights occupied; subsequent `GenerateStep` calls
/// against the same `model_id` return an `Error`.
UnloadModel { model_id: String },
/// Worker should release resources and exit. Worker replies `Bye`
/// and then closes stdout / exits zero. The leader reaps the
/// child via the `tokio::process::Child` it kept.
Shutdown,
}
/// Worker → leader. Always exactly one of these per `WorkerRequest`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "op", rename_all = "snake_case")]
pub enum WorkerResponse {
/// Reply to `Ping`. Carries enough identity for the leader to log
/// what it actually got back.
Pong {
rank: u32,
world_size: u32,
cuda_device: u32,
},
/// Reply to `Init`. Empty payload — success is the absence of
/// `Error`. NCCL's internal blocking handshake means by the time
/// this comes back, every other rank has also reached
/// `Comm::from_rank`.
InitOk,
/// Reply to `NcclSanityCheck`. The observed sum after a single
/// `all_reduce(SUM, 1u32)` across all ranks. The leader checks
/// this matches `world_size`.
NcclSanityResult { observed_sum: u32 },
/// Reply to `LoadDenseShard`. Empty payload — success is the
/// absence of `Error`. By the time this comes back, the rank's
/// `TpQwen3ForCausalLM` is constructed in memory and ready for
/// `GenerateStep`.
LoadDenseShardOk,
/// Reply to `GenerateStep`. Empty payload — workers don't ship
/// logits over the wire. The leader uses its own rank-0 logits;
/// workers only need to confirm the collective completed.
GenerateStepOk,
/// Reply to `ClearKvCache`. Empty payload.
KvCacheCleared,
/// Reply to `UnloadModel`. Empty payload. The named model is no
/// longer present on this rank.
Unloaded,
/// Reply to `Shutdown`. Worker exits immediately after writing this.
Bye,
/// Any request can produce this instead of its dedicated success
/// variant. `kind` is a machine-readable category so the leader
/// can branch on failure mode without string-matching `message`.
Error {
/// Short tag — `nccl_init_failed`, `unknown_op`, etc.
kind: String,
/// Human-readable detail for logs.
message: String,
},
}
#[cfg(test)]
mod tests {
use super::*;
fn roundtrip<T>(value: &T) -> T
where
T: Serialize + for<'de> Deserialize<'de>,
{
serde_json::from_str(&serde_json::to_string(value).expect("serialise"))
.expect("deserialise")
}
#[test]
fn request_ping_round_trip() {
let req = WorkerRequest::Ping;
let wire = serde_json::to_string(&req).unwrap();
assert_eq!(wire, r#"{"op":"ping"}"#);
match roundtrip(&req) {
WorkerRequest::Ping => {}
other => panic!("expected Ping, got {other:?}"),
}
}
#[test]
fn request_init_carries_hex_id() {
let req = WorkerRequest::Init {
comm_id: "deadbeef".into(),
};
let wire = serde_json::to_string(&req).unwrap();
assert_eq!(wire, r#"{"op":"init","comm_id":"deadbeef"}"#);
}
#[test]
fn request_shutdown_round_trip() {
assert_eq!(
serde_json::to_string(&WorkerRequest::Shutdown).unwrap(),
r#"{"op":"shutdown"}"#
);
}
#[test]
fn response_pong_round_trip() {
let resp = WorkerResponse::Pong {
rank: 1,
world_size: 4,
cuda_device: 1,
};
let wire = serde_json::to_string(&resp).unwrap();
assert!(wire.contains(r#""op":"pong""#));
assert!(wire.contains(r#""rank":1"#));
assert!(wire.contains(r#""world_size":4"#));
match roundtrip(&resp) {
WorkerResponse::Pong {
rank,
world_size,
cuda_device,
} => {
assert_eq!(rank, 1);
assert_eq!(world_size, 4);
assert_eq!(cuda_device, 1);
}
other => panic!("expected Pong, got {other:?}"),
}
}
#[test]
fn response_error_carries_kind_and_message() {
let resp = WorkerResponse::Error {
kind: "nccl_init_failed".into(),
message: "could not bind device".into(),
};
let wire = serde_json::to_string(&resp).unwrap();
assert!(wire.contains(r#""op":"error""#));
assert!(wire.contains(r#""kind":"nccl_init_failed""#));
}
#[test]
fn response_sanity_result_round_trip() {
let resp = WorkerResponse::NcclSanityResult { observed_sum: 4 };
match roundtrip(&resp) {
WorkerResponse::NcclSanityResult { observed_sum } => {
assert_eq!(observed_sum, 4);
}
other => panic!("expected NcclSanityResult, got {other:?}"),
}
}
/// Unknown ops on the wire deserialise to an error rather than
/// silently mis-matching — confirms our `serde(tag = "op")`
/// configuration rejects unknowns instead of doing fuzzy matching.
#[test]
fn unknown_op_fails_to_parse() {
let result: Result<WorkerRequest, _> = serde_json::from_str(r#"{"op":"explode"}"#);
assert!(result.is_err(), "should reject unknown op, got {result:?}");
}
}

View File

@@ -1,283 +0,0 @@
//! Tensor-parallel linear layers built on candle's `ShardedVarBuilder`
//! and `Shard` sharding hints.
//!
//! candle reads only the rank's slice of each weight tensor from
//! safetensors via `view.slice(start..stop)` — no full-tensor host
//! materialisation. That's a memory-efficiency win over hand-rolled
//! "load full + narrow" sharding (which the earlier
//! `sharded_linear.rs` exploration demonstrated but didn't pay for).
//!
//! Two layer types:
//!
//! - [`ColumnParallelLinear`] — output-sharded; forward is a plain
//! local matmul. The downstream consumer either accepts a sharded
//! activation (next layer is also column-parallel) or all-gathers.
//! - [`RowParallelLinear`] — input-sharded; forward = local matmul
//! then `AllReduce` `CustomOp1` to sum partials across ranks.
//!
//! Both assume **no bias** — every Qwen3-family weight layout we
//! actually target (Qwen3, Qwen3-Coder, Qwen3.6 base, etc.) sets
//! `attention_bias=false` and the MLP layers are no-bias. Adding bias
//! support is mechanical when a future model needs it; the design
//! choice would be: column-parallel shards the bias along dim 0;
//! row-parallel holds the bias only on rank 0 so the post-`AllReduce`
//! sum carries it exactly once.
use anyhow::{Context, Result};
use candle_core::quantized::{GgmlDType, QMatMul, QTensor};
use candle_core::{Module, Tensor};
use candle_nn::Linear;
use candle_nn::var_builder::{Shard, ShardedVarBuilder};
use std::sync::Arc;
#[cfg(feature = "cuda")]
use super::all_reduce::AllReduce;
/// Linear primitive that holds either a plain `Linear` (bf16/f16/f32)
/// or a quantized `QMatMul` (Q4K/Q5K/Q6K/Q8_0/etc.).
///
/// Constructed via [`MaybeQuantLinear::from_weight`] — pass `None` to
/// keep the weight in its loaded dtype (no quantization), or
/// `Some(dtype)` to quantize at load time.
///
/// On the forward path the two arms dispatch identically: `Module::forward`
/// returns an output in the caller's input dtype (f32 fallback for the
/// quantized matmul). Subsequent ops don't need to know whether the
/// layer was quantized.
pub enum MaybeQuantLinear {
Plain(Linear),
Quant(QMatMul),
}
impl MaybeQuantLinear {
/// Build a linear from a loaded weight tensor. If `quant` is set,
/// the weight is quantized in-situ and stored as a `QMatMul`;
/// otherwise it's wrapped in a plain `Linear`.
pub fn from_weight(weight: Tensor, quant: Option<GgmlDType>) -> Result<Self> {
match quant {
Some(dtype) => {
let qt = QTensor::quantize(&weight, dtype).with_context(|| {
format!(
"QTensor::quantize to {dtype:?} for shape {:?}",
weight.shape()
)
})?;
let qmm = QMatMul::from_arc(Arc::new(qt))
.context("QMatMul::from_arc on freshly quantized weight")?;
Ok(Self::Quant(qmm))
}
None => Ok(Self::Plain(Linear::new(weight, None))),
}
}
}
/// Above this M (the product of all input dims except the last)
/// dispatch the quantized matmul through `QMatMul::forward_via_f16`,
/// which dequantizes the weight to f16 once and runs cuBLAS GEMM.
/// At or below this M the GGUF GEMV kernel inside
/// `QMatMul::forward` wins (it operates on quantized blocks directly
/// and accumulates in registers).
///
/// Empirical: at M=30 on Qwen3.6-27B / RTX 5090, forward_via_f16 was
/// slightly *slower* than the GGUF GEMV kernel — the per-call dequant
/// cost (~30 MB f16 written to global memory per linear × ~480 calls
/// per prefill) eats the cuBLAS GEMM speedup at small M. The
/// crossover where the GEMM scaling actually beats the fixed dequant
/// tax sits well above M=8.
///
/// 64 is a conservative crossover that keeps short-prompt prefills
/// on the GGUF kernel (where the per-call cost is comparable to the
/// f16 path but the dequant tax is zero) and only activates the
/// dequant-then-GEMM path for long prefills where the GEMM size
/// makes amortising worth it. A proper fix is either a dequant
/// cache or a fused dequant+gemm cuda kernel — both larger projects.
const QUANT_PREFILL_M_THRESHOLD: usize = 64;
impl Module for MaybeQuantLinear {
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
match self {
Self::Plain(l) => l.forward(x),
Self::Quant(qm) => {
// Decode vs prefill split. `M` is the "rows of x" the
// matmul will iterate over — every dim except the last
// (which is in_features). For decode (`seq_len == 1`
// with batch 1) M is 1; for prefill with L>>1 it's L
// (or B*L).
let dims = x.dims();
let m: usize = dims.iter().take(dims.len() - 1).product();
if m > QUANT_PREFILL_M_THRESHOLD {
// Prefill: dequantize the weight once into f16,
// then run a real cuBLAS-backed GEMM. The cost of
// the dequant is amortised across all M tokens.
// `forward_via_f16` handles the dtype round-trip
// internally (output matches input dtype).
return qm.forward_via_f16(x);
}
// Decode (M <= threshold): use the on-the-fly GGUF
// GEMV kernel via `QMatMul::forward`. That kernel
// requires f32 inputs (it accumulates in f32 from the
// dequantized quant blocks); cast in/out at the
// boundary.
let in_dtype = x.dtype();
let x_f32 = if in_dtype == candle_core::DType::F32 {
x.clone()
} else {
x.to_dtype(candle_core::DType::F32)?
};
let y = qm.forward(&x_f32)?;
if y.dtype() == in_dtype {
Ok(y)
} else {
y.to_dtype(in_dtype)
}
}
}
}
}
/// Helper to build a [`Shard`] hint for a given dimension.
pub(crate) fn shard(dim: usize, rank: u32, world_size: u32) -> Shard {
Shard {
dim,
rank: rank as usize,
world_size: world_size as usize,
}
}
/// Output-dim sharded linear (column-parallel). Holds a
/// [`MaybeQuantLinear`] whose underlying weight is this rank's slice
/// of the full `[out_features, in_features]` tensor along dim 0.
pub struct ColumnParallelLinear {
inner: MaybeQuantLinear,
}
impl ColumnParallelLinear {
/// Load this rank's column-parallel slice from a
/// `ShardedVarBuilder`. The provided `vb` must already be `pp`-ed
/// to the layer's path (e.g. `vb.pp("model.layers.0.self_attn.q_proj")`).
///
/// Backward-compatible variant — no in-situ quantization. For
/// quantized loads, use [`Self::load_with_quant`].
pub fn load(vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
Self::load_with_quant(vb, rank, world_size, None)
}
/// Like [`Self::load`] but quantizes the per-rank weight in-situ
/// when `quant` is `Some(dtype)`. Saves ~3-5x vs bf16/f16.
pub fn load_with_quant(
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
quant: Option<GgmlDType>,
) -> Result<Self> {
let weight = vb
.get_with_hints((), "weight", shard(0, rank, world_size))
.with_context(|| format!("load column-parallel '{}' weight", vb.prefix()))?;
let inner = MaybeQuantLinear::from_weight(weight, quant)
.with_context(|| format!("wrap column-parallel '{}'", vb.prefix()))?;
Ok(Self { inner })
}
}
impl Module for ColumnParallelLinear {
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
self.inner.forward(x)
}
}
/// Input-dim sharded linear (row-parallel).
///
/// Holds a sharded [`MaybeQuantLinear`] plus an `AllReduce` op the
/// forward chains after the local matmul to recover the full activation.
pub struct RowParallelLinear {
inner: MaybeQuantLinear,
#[cfg(feature = "cuda")]
all_reduce: AllReduce,
/// Whether the AllReduce should run. Column-parallel ↔ row-parallel
/// is a pair: the column output is sharded, the row input is
/// sharded, and the AllReduce gives back the full output. For
/// `world_size = 1` the AllReduce is a no-op so we skip it.
needs_reduce: bool,
}
impl RowParallelLinear {
/// Load this rank's row-parallel slice from a `ShardedVarBuilder`.
///
/// Under `cuda`, `comm` is the NCCL communicator the row-parallel
/// `AllReduce` runs against. On CPU builds the parameter is
/// elided — forward returns the partial sum, which is the *wrong*
/// answer for inference but lets us compile-check the model.
///
/// Backward-compatible variant — no in-situ quantization. For
/// quantized loads, use [`Self::load_with_quant`].
#[cfg(feature = "cuda")]
pub fn load(
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
comm: std::sync::Arc<cudarc::nccl::Comm>,
) -> Result<Self> {
Self::load_with_quant(vb, rank, world_size, comm, None)
}
/// Like [`Self::load`] but quantizes the per-rank weight in-situ.
#[cfg(feature = "cuda")]
pub fn load_with_quant(
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
comm: std::sync::Arc<cudarc::nccl::Comm>,
quant: Option<GgmlDType>,
) -> Result<Self> {
let weight = vb
.get_with_hints((), "weight", shard(1, rank, world_size))
.with_context(|| format!("load row-parallel '{}' weight", vb.prefix()))?;
let inner = MaybeQuantLinear::from_weight(weight, quant)
.with_context(|| format!("wrap row-parallel '{}'", vb.prefix()))?;
Ok(Self {
inner,
all_reduce: AllReduce::new(comm),
needs_reduce: world_size > 1,
})
}
#[cfg(not(feature = "cuda"))]
pub fn load(vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
Self::load_with_quant(vb, rank, world_size, None)
}
#[cfg(not(feature = "cuda"))]
pub fn load_with_quant(
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
quant: Option<GgmlDType>,
) -> Result<Self> {
let weight = vb
.get_with_hints((), "weight", shard(1, rank, world_size))
.with_context(|| format!("load row-parallel '{}' weight", vb.prefix()))?;
let inner = MaybeQuantLinear::from_weight(weight, quant)
.with_context(|| format!("wrap row-parallel '{}'", vb.prefix()))?;
Ok(Self {
inner,
needs_reduce: world_size > 1,
})
}
}
impl Module for RowParallelLinear {
/// Local matmul followed by an `AllReduce` (when `cuda` and
/// `world_size > 1`). On CPU or single-rank, returns the partial
/// output directly — which is *only* correct for `world_size == 1`.
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
let local = self.inner.forward(x)?;
#[cfg(feature = "cuda")]
if self.needs_reduce {
return local.apply_op1_no_bwd(&self.all_reduce);
}
let _ = self.needs_reduce;
Ok(local)
}
}

View File

@@ -1,605 +0,0 @@
//! Tensor-parallel Qwen3 dense model.
//!
//! Mirrors `candle_transformers::models::qwen3` structurally, but with:
//!
//! - Attention's `q_proj` / `k_proj` / `v_proj` as
//! [`ColumnParallelLinear`] (output sharded along the head dimension —
//! per-rank `num_heads = total/world_size`, ditto for kv heads).
//! - Attention's `o_proj` as [`RowParallelLinear`] (input sharded; the
//! trailing `AllReduce` recovers the full activation).
//! - MLP's `gate_proj` / `up_proj` as [`ColumnParallelLinear`] (sharded
//! along `intermediate_size`).
//! - MLP's `down_proj` as [`RowParallelLinear`].
//! - `embed_tokens`, all `RmsNorm`s, and `lm_head` **replicated** on
//! every rank. The per-rank duplicate weight is bounded and lets us
//! skip the embedding all-gather and the lm-head column-shard +
//! all-gather; both are pure latency optimisations that don't change
//! correctness.
//! - `kv_cache` holds the per-rank slice of K/V already (because they
//! came out of a column-parallel projection). No cache resharding
//! needed across ranks.
//!
//! Divisibility requirement, checked at load time:
//!
//! - `num_attention_heads % world_size == 0`
//! - `num_key_value_heads % world_size == 0`
//! - `intermediate_size % world_size == 0`
//!
//! Anything else bails — the safetensors slice would lose data otherwise.
//! This is the same divisibility-bail pattern that landed in
//! `EricLBuehler/mistral.rs` PR #2054.
//!
//! Replicated tensors (norms, embedding, lm_head) are loaded by asking
//! the `ShardedVarBuilder` for the full tensor via `vb.get(shape, name)`
//! — which defaults to `Shard { world_size: 1 }` and falls through to
//! the unsharded backend path.
use anyhow::{Context, Result, bail};
use candle_core::{DType, Device, IndexOp, Module, Tensor};
use candle_nn::var_builder::ShardedVarBuilder;
use candle_nn::{Activation, Embedding, Linear, RmsNorm, kv_cache::ConcatKvCache};
use candle_transformers::utils::repeat_kv;
use std::sync::Arc;
#[cfg(feature = "cuda")]
use cudarc::nccl::Comm;
use super::tp_linear::{ColumnParallelLinear, RowParallelLinear};
pub use candle_transformers::models::qwen3::Config;
/// Replicated rotary-embedding lookup. Re-implementation of the
/// `pub(crate)` candle equivalent — we can't reach into the upstream
/// type, so the inv-freq / sin / cos construction lives here.
pub(crate) struct Qwen3RotaryEmbedding {
sin: Tensor,
cos: Tensor,
}
impl Qwen3RotaryEmbedding {
pub(crate) fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
let dim = cfg.head_dim;
let max_seq_len = cfg.max_position_embeddings;
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(DType::F32)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
Ok(Self {
sin: freqs.sin()?.to_dtype(dtype)?,
cos: freqs.cos()?.to_dtype(dtype)?,
})
}
fn apply(
&self,
q: &Tensor,
k: &Tensor,
offset: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
let (_, _, seq_len, _) = q.dims4()?;
let cos = self.cos.narrow(0, offset, seq_len)?;
let sin = self.sin.narrow(0, offset, seq_len)?;
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
Ok((q_embed, k_embed))
}
}
/// Helper: load a replicated tensor by asking the ShardedVarBuilder for
/// the full tensor (world_size=1 hint falls through to SimpleBackend).
fn load_replicated<S: Into<candle_core::Shape>>(
vb: &ShardedVarBuilder,
shape: S,
name: &str,
) -> Result<Tensor> {
vb.get(shape, name)
.with_context(|| format!("load replicated '{}/{name}'", vb.prefix()))
}
fn load_rms_norm(vb: &ShardedVarBuilder, size: usize, eps: f64) -> Result<RmsNorm> {
let weight = load_replicated(vb, size, "weight")?;
Ok(RmsNorm::new(weight, eps))
}
/// TP MLP. SwiGLU = `down(silu(gate(x)) * up(x))`.
pub(crate) struct TpQwen3MLP {
gate_proj: ColumnParallelLinear,
up_proj: ColumnParallelLinear,
down_proj: RowParallelLinear,
act_fn: Activation,
}
impl TpQwen3MLP {
#[cfg(feature = "cuda")]
pub fn load(
cfg: &Config,
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
comm: Arc<Comm>,
) -> Result<Self> {
if !cfg.intermediate_size.is_multiple_of(world_size as usize) {
bail!(
"intermediate_size {} not divisible by world_size {}",
cfg.intermediate_size,
world_size
);
}
Ok(Self {
gate_proj: ColumnParallelLinear::load(&vb.pp("gate_proj"), rank, world_size)?,
up_proj: ColumnParallelLinear::load(&vb.pp("up_proj"), rank, world_size)?,
down_proj: RowParallelLinear::load(&vb.pp("down_proj"), rank, world_size, comm)?,
act_fn: cfg.hidden_act,
})
}
#[cfg(not(feature = "cuda"))]
pub fn load(cfg: &Config, vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
if !cfg.intermediate_size.is_multiple_of(world_size as usize) {
bail!(
"intermediate_size {} not divisible by world_size {}",
cfg.intermediate_size,
world_size
);
}
Ok(Self {
gate_proj: ColumnParallelLinear::load(&vb.pp("gate_proj"), rank, world_size)?,
up_proj: ColumnParallelLinear::load(&vb.pp("up_proj"), rank, world_size)?,
down_proj: RowParallelLinear::load(&vb.pp("down_proj"), rank, world_size)?,
act_fn: cfg.hidden_act,
})
}
}
impl Module for TpQwen3MLP {
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
let lhs = x.apply(&self.gate_proj)?.apply(&self.act_fn)?;
let rhs = x.apply(&self.up_proj)?;
(lhs * rhs)?.apply(&self.down_proj)
}
}
/// TP attention. Carries per-rank head counts and the q/k per-head
/// RmsNorms (which are replicated and operate on a flattened B*H*L
/// axis, so the same code path works irrespective of how H was split).
pub(crate) struct TpQwen3Attention {
q_proj: ColumnParallelLinear,
k_proj: ColumnParallelLinear,
v_proj: ColumnParallelLinear,
o_proj: RowParallelLinear,
q_norm: RmsNorm,
k_norm: RmsNorm,
local_num_heads: usize,
local_num_kv_heads: usize,
num_kv_groups: usize,
head_dim: usize,
local_hidden_size: usize,
rotary_emb: Arc<Qwen3RotaryEmbedding>,
kv_cache: ConcatKvCache,
}
impl TpQwen3Attention {
#[cfg(feature = "cuda")]
pub fn load(
cfg: &Config,
rotary_emb: Arc<Qwen3RotaryEmbedding>,
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
comm: Arc<Comm>,
) -> Result<Self> {
Self::load_inner(
cfg,
rotary_emb,
vb,
rank,
world_size,
#[cfg(feature = "cuda")]
comm,
)
}
#[cfg(not(feature = "cuda"))]
pub fn load(
cfg: &Config,
rotary_emb: Arc<Qwen3RotaryEmbedding>,
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
) -> Result<Self> {
Self::load_inner(cfg, rotary_emb, vb, rank, world_size)
}
fn load_inner(
cfg: &Config,
rotary_emb: Arc<Qwen3RotaryEmbedding>,
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
#[cfg(feature = "cuda")] comm: Arc<Comm>,
) -> Result<Self> {
if cfg.use_sliding_window {
bail!("sliding window is not yet supported in the TP path");
}
if cfg.attention_bias {
bail!("attention_bias=true is not supported by ColumnParallel/RowParallelLinear yet");
}
let ws = world_size as usize;
if !cfg.num_attention_heads.is_multiple_of(ws) {
bail!(
"num_attention_heads {} not divisible by world_size {}",
cfg.num_attention_heads,
world_size
);
}
if !cfg.num_key_value_heads.is_multiple_of(ws) {
bail!(
"num_key_value_heads {} not divisible by world_size {}",
cfg.num_key_value_heads,
world_size
);
}
let head_dim = cfg.head_dim;
let local_num_heads = cfg.num_attention_heads / ws;
let local_num_kv_heads = cfg.num_key_value_heads / ws;
let num_kv_groups = local_num_heads / local_num_kv_heads;
let local_hidden_size = head_dim * local_num_heads;
let q_proj = ColumnParallelLinear::load(&vb.pp("q_proj"), rank, world_size)?;
let k_proj = ColumnParallelLinear::load(&vb.pp("k_proj"), rank, world_size)?;
let v_proj = ColumnParallelLinear::load(&vb.pp("v_proj"), rank, world_size)?;
#[cfg(feature = "cuda")]
let o_proj = RowParallelLinear::load(&vb.pp("o_proj"), rank, world_size, comm)?;
#[cfg(not(feature = "cuda"))]
let o_proj = RowParallelLinear::load(&vb.pp("o_proj"), rank, world_size)?;
let q_norm = load_rms_norm(&vb.pp("q_norm"), head_dim, cfg.rms_norm_eps)?;
let k_norm = load_rms_norm(&vb.pp("k_norm"), head_dim, cfg.rms_norm_eps)?;
// dim=2 because we cat along the seq axis of (B, H, L, D) tensors.
let kv_cache = ConcatKvCache::new(2);
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
q_norm,
k_norm,
local_num_heads,
local_num_kv_heads,
num_kv_groups,
head_dim,
local_hidden_size,
rotary_emb,
kv_cache,
})
}
pub fn forward(
&mut self,
x: &Tensor,
attn_mask: Option<&Tensor>,
offset: usize,
) -> candle_core::Result<Tensor> {
let (b, l, _) = x.dims3()?;
// 1. Projections (column-parallel → output is sharded).
let q = self.q_proj.forward(x)?;
let k = self.k_proj.forward(x)?;
let v = self.v_proj.forward(x)?;
// 2. Reshape: (B, L, H, D) → (B, H, L, D).
let q = q
.reshape((b, l, self.local_num_heads, self.head_dim))?
.transpose(1, 2)?;
let k = k
.reshape((b, l, self.local_num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let v = v
.reshape((b, l, self.local_num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
// 3. Per-head RmsNorm (replicated weight, flat input).
let q_flat = q.flatten(0, 2)?;
let k_flat = k.flatten(0, 2)?;
let q_flat = self.q_norm.forward(&q_flat)?;
let k_flat = self.k_norm.forward(&k_flat)?;
let q = q_flat.reshape((b, self.local_num_heads, l, self.head_dim))?;
let k = k_flat.reshape((b, self.local_num_kv_heads, l, self.head_dim))?;
// 4. Rotary.
let (q, k) = self.rotary_emb.apply(&q, &k, offset)?;
// 5. Accumulate KV.
let (k, v) = self.kv_cache.append(&k, &v)?;
// 6. GQA repeat_kv on the rank-local K/V.
let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;
// 7. Attention scores.
let scale = 1.0 / (self.head_dim as f64).sqrt();
let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
if let Some(m) = attn_mask {
scores = scores.broadcast_add(m)?;
}
let probs = candle_nn::ops::softmax_last_dim(&scores)?;
let ctx = probs.matmul(&v)?;
// 8. Output projection (row-parallel → AllReduce inside).
ctx.transpose(1, 2)?
.reshape((b, l, self.local_hidden_size))?
.apply(&self.o_proj)
}
pub fn clear_kv_cache(&mut self) {
self.kv_cache.reset();
}
}
struct TpDecoderLayer {
self_attn: TpQwen3Attention,
mlp: TpQwen3MLP,
ln1: RmsNorm,
ln2: RmsNorm,
}
impl TpDecoderLayer {
#[cfg(feature = "cuda")]
fn load(
cfg: &Config,
rotary_emb: Arc<Qwen3RotaryEmbedding>,
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
comm: Arc<Comm>,
) -> Result<Self> {
let self_attn = TpQwen3Attention::load(
cfg,
rotary_emb,
&vb.pp("self_attn"),
rank,
world_size,
comm.clone(),
)?;
let mlp = TpQwen3MLP::load(cfg, &vb.pp("mlp"), rank, world_size, comm)?;
let ln1 = load_rms_norm(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?;
let ln2 = load_rms_norm(
&vb.pp("post_attention_layernorm"),
cfg.hidden_size,
cfg.rms_norm_eps,
)?;
Ok(Self {
self_attn,
mlp,
ln1,
ln2,
})
}
#[cfg(not(feature = "cuda"))]
fn load(
cfg: &Config,
rotary_emb: Arc<Qwen3RotaryEmbedding>,
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
) -> Result<Self> {
let self_attn =
TpQwen3Attention::load(cfg, rotary_emb, &vb.pp("self_attn"), rank, world_size)?;
let mlp = TpQwen3MLP::load(cfg, &vb.pp("mlp"), rank, world_size)?;
let ln1 = load_rms_norm(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?;
let ln2 = load_rms_norm(
&vb.pp("post_attention_layernorm"),
cfg.hidden_size,
cfg.rms_norm_eps,
)?;
Ok(Self {
self_attn,
mlp,
ln1,
ln2,
})
}
fn forward(
&mut self,
x: &Tensor,
mask: Option<&Tensor>,
offset: usize,
) -> candle_core::Result<Tensor> {
let h = self.ln1.forward(x)?;
let h = self.self_attn.forward(&h, mask, offset)?;
let x = (x + h)?;
let h2 = self.ln2.forward(&x)?;
let h2 = h2.apply(&self.mlp)?;
x + h2
}
fn clear_kv_cache(&mut self) {
self.self_attn.clear_kv_cache();
}
}
/// Base TP Qwen3 transformer — embedding, decoder stack, final norm.
/// The lm_head sits on top in [`TpQwen3ForCausalLM`].
pub struct TpQwen3Model {
embed_tokens: Embedding,
layers: Vec<TpDecoderLayer>,
norm: RmsNorm,
device: Device,
dtype: DType,
}
impl TpQwen3Model {
#[cfg(feature = "cuda")]
pub fn load(
cfg: &Config,
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
comm: Arc<Comm>,
) -> Result<Self> {
let dtype = vb.dtype();
let device = vb.device().clone();
let rotary = Arc::new(Qwen3RotaryEmbedding::new(dtype, cfg, &device)?);
let embed_vb = vb.pp("model.embed_tokens");
let embed_weight = load_replicated(&embed_vb, (cfg.vocab_size, cfg.hidden_size), "weight")?;
let embed_tokens = Embedding::new(embed_weight, cfg.hidden_size);
let vb_l = vb.pp("model.layers");
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
for i in 0..cfg.num_hidden_layers {
layers.push(TpDecoderLayer::load(
cfg,
rotary.clone(),
&vb_l.pp(i),
rank,
world_size,
comm.clone(),
)?);
}
let norm = load_rms_norm(&vb.pp("model.norm"), cfg.hidden_size, cfg.rms_norm_eps)?;
Ok(Self {
embed_tokens,
layers,
norm,
device,
dtype,
})
}
#[cfg(not(feature = "cuda"))]
pub fn load(cfg: &Config, vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
let dtype = vb.dtype();
let device = vb.device().clone();
let rotary = Arc::new(Qwen3RotaryEmbedding::new(dtype, cfg, &device)?);
let embed_vb = vb.pp("model.embed_tokens");
let embed_weight = load_replicated(&embed_vb, (cfg.vocab_size, cfg.hidden_size), "weight")?;
let embed_tokens = Embedding::new(embed_weight, cfg.hidden_size);
let vb_l = vb.pp("model.layers");
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
for i in 0..cfg.num_hidden_layers {
layers.push(TpDecoderLayer::load(
cfg,
rotary.clone(),
&vb_l.pp(i),
rank,
world_size,
)?);
}
let norm = load_rms_norm(&vb.pp("model.norm"), cfg.hidden_size, cfg.rms_norm_eps)?;
Ok(Self {
embed_tokens,
layers,
norm,
device,
dtype,
})
}
pub fn embed_weight(&self) -> &Tensor {
self.embed_tokens.embeddings()
}
pub fn clear_kv_cache(&mut self) {
for l in &mut self.layers {
l.clear_kv_cache();
}
}
fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> candle_core::Result<Tensor> {
let minf = f32::NEG_INFINITY;
let mask: Vec<_> = (0..tgt)
.flat_map(|i| (0..(tgt + offset)).map(move |j| if j <= i + offset { 0. } else { minf }))
.collect();
Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype)
}
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
let (b, l) = input.dims2()?;
let mut h = self.embed_tokens.forward(input)?;
let causal = if l == 1 {
None
} else {
Some(self.causal_mask(b, l, offset)?)
};
for layer in &mut self.layers {
h = layer.forward(&h, causal.as_ref(), offset)?;
}
self.norm.forward(&h)
}
}
/// TP Qwen3 with a (replicated) language-model head on top.
pub struct TpQwen3ForCausalLM {
base: TpQwen3Model,
lm_head: Linear,
}
impl TpQwen3ForCausalLM {
#[cfg(feature = "cuda")]
pub fn load(
cfg: &Config,
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
comm: Arc<Comm>,
) -> Result<Self> {
let base = TpQwen3Model::load(cfg, vb, rank, world_size, comm)?;
let lm_head = build_lm_head(cfg, vb, &base)?;
Ok(Self { base, lm_head })
}
#[cfg(not(feature = "cuda"))]
pub fn load(cfg: &Config, vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
let base = TpQwen3Model::load(cfg, vb, rank, world_size)?;
let lm_head = build_lm_head(cfg, vb, &base)?;
Ok(Self { base, lm_head })
}
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
let (_, l) = input.dims2()?;
let hidden = self.base.forward(input, offset)?;
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
}
pub fn clear_kv_cache(&mut self) {
self.base.clear_kv_cache();
}
pub fn device(&self) -> &Device {
&self.base.device
}
pub fn dtype(&self) -> DType {
self.base.dtype
}
}
fn build_lm_head(cfg: &Config, vb: &ShardedVarBuilder, base: &TpQwen3Model) -> Result<Linear> {
if cfg.tie_word_embeddings {
Ok(Linear::new(base.embed_weight().clone(), None))
} else {
let weight = load_replicated(
&vb.pp("lm_head"),
(cfg.vocab_size, cfg.hidden_size),
"weight",
)?;
Ok(Linear::new(weight, None))
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,502 +0,0 @@
//! Entry point for `neuron --worker`.
//!
//! The worker reads one newline-delimited JSON `WorkerRequest` from
//! stdin per loop iteration, dispatches synchronously, and writes
//! exactly one `WorkerResponse` JSON line to stdout. tracing goes to
//! stderr so it doesn't collide with the RPC stream.
//!
//! NCCL operations (`Init`, `NcclSanityCheck`) and model lifecycle ops
//! (`LoadDenseShard`, `GenerateStep`, `ClearKvCache`, `UnloadModel`)
//! are real when built with the `cuda` feature; without it they reply
//! with `Error{kind="cuda_feature_not_enabled"}` so the leader can tell
//! the difference between a misconfigured build and a genuine NCCL or
//! model failure.
use anyhow::Result;
use std::collections::HashMap;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use super::nccl_state::NcclState;
use super::rpc::{WorkerRequest, WorkerResponse};
#[cfg(feature = "cuda")]
use super::tp_qwen3::TpQwen3ForCausalLM;
#[cfg(feature = "cuda")]
use super::tp_qwen3_5::TpQwen3_5ForCausalLM;
/// Worker-side discriminator over the architectures we can load via
/// `LoadDenseShard`. Mirrors `super::TpLeaderModel` on the leader
/// side — the dispatch happens on the `model_type` extracted from the
/// config JSON.
#[cfg(feature = "cuda")]
enum WorkerModel {
Qwen3(TpQwen3ForCausalLM),
Qwen3_5(TpQwen3_5ForCausalLM),
}
#[cfg(feature = "cuda")]
impl WorkerModel {
fn forward(
&mut self,
input: &candle_core::Tensor,
offset: usize,
) -> candle_core::Result<candle_core::Tensor> {
match self {
WorkerModel::Qwen3(m) => m.forward(input, offset),
WorkerModel::Qwen3_5(m) => m.forward(input, offset),
}
}
fn clear_kv_cache(&mut self) {
match self {
WorkerModel::Qwen3(m) => m.clear_kv_cache(),
WorkerModel::Qwen3_5(m) => m.clear_kv_cache(),
}
}
fn device(&self) -> &candle_core::Device {
match self {
WorkerModel::Qwen3(m) => m.device(),
WorkerModel::Qwen3_5(m) => m.device(),
}
}
}
#[derive(Debug, Clone, Copy)]
pub struct WorkerConfig {
pub rank: u32,
pub world_size: u32,
pub cuda_device: u32,
}
/// Drive the worker RPC loop until `Shutdown` or EOF on stdin.
pub async fn run(config: WorkerConfig) -> Result<()> {
tracing::info!(
rank = config.rank,
world_size = config.world_size,
cuda_device = config.cuda_device,
"tp worker starting"
);
let mut state = WorkerState::new(config);
let stdin = tokio::io::stdin();
let mut reader = BufReader::new(stdin).lines();
let mut stdout = tokio::io::stdout();
while let Some(line) = reader.next_line().await? {
if line.trim().is_empty() {
continue;
}
let req: WorkerRequest = match serde_json::from_str(&line) {
Ok(r) => r,
Err(e) => {
let resp = WorkerResponse::Error {
kind: "bad_request".into(),
message: format!("parse {line:?}: {e}"),
};
write_response(&mut stdout, &resp).await?;
continue;
}
};
let resp = state.handle(req).await;
let is_bye = matches!(resp, WorkerResponse::Bye);
write_response(&mut stdout, &resp).await?;
if is_bye {
break;
}
}
tracing::info!(rank = config.rank, "tp worker exiting");
Ok(())
}
async fn write_response(stdout: &mut tokio::io::Stdout, resp: &WorkerResponse) -> Result<()> {
let mut line = serde_json::to_string(resp)?;
line.push('\n');
stdout.write_all(line.as_bytes()).await?;
stdout.flush().await?;
Ok(())
}
/// One rank's local state. Owns the rank's NCCL communicator (via
/// `NcclState`) and the rank's shard of every loaded model.
struct WorkerState {
config: WorkerConfig,
nccl: NcclState,
/// Loaded model shards keyed by `model_id`. Each entry wraps the
/// rank's TP architecture handle (Qwen3 or Qwen3-Next) — the
/// column/row-parallel layers hold an `Arc<Comm>` cloned from
/// `nccl`. Cuda-only: the underlying types reference cudarc types
/// that don't exist without the cuda feature.
#[cfg(feature = "cuda")]
models: HashMap<String, WorkerModel>,
/// Placeholder so the non-cuda build keeps the same field name set
/// and `WorkerState::new` reads the same on both.
#[cfg(not(feature = "cuda"))]
#[allow(dead_code)]
models: HashMap<String, ()>,
}
impl WorkerState {
fn new(config: WorkerConfig) -> Self {
Self {
config,
nccl: NcclState::new(),
models: HashMap::new(),
}
}
async fn handle(&mut self, req: WorkerRequest) -> WorkerResponse {
match req {
WorkerRequest::Ping => WorkerResponse::Pong {
rank: self.config.rank,
world_size: self.config.world_size,
cuda_device: self.config.cuda_device,
},
WorkerRequest::Init { comm_id } => self.nccl.init(self.config, &comm_id),
WorkerRequest::NcclSanityCheck => self.nccl.sanity_check(),
WorkerRequest::LoadDenseShard {
model_id,
config_json,
safetensors_paths,
quant,
} => self.handle_load_dense_shard(model_id, config_json, safetensors_paths, quant),
WorkerRequest::GenerateStep {
model_id,
tokens,
offset,
} => self.handle_generate_step(&model_id, tokens, offset),
WorkerRequest::ClearKvCache { model_id } => self.handle_clear_kv_cache(&model_id),
WorkerRequest::UnloadModel { model_id } => self.handle_unload_model(&model_id),
WorkerRequest::Shutdown => WorkerResponse::Bye,
}
}
#[cfg(feature = "cuda")]
fn handle_load_dense_shard(
&mut self,
model_id: String,
config_json: String,
safetensors_paths: Vec<String>,
quant: Option<String>,
) -> WorkerResponse {
use crate::harness::arch::qwen3_5 as qwen3_5_arch;
use candle_core::{DType, Device};
use candle_nn::var_builder::ShardedSafeTensors;
use candle_transformers::models::qwen3 as qwen3_dense;
use std::path::PathBuf;
let quant_dtype = match parse_quant_string(quant.as_deref()) {
Ok(q) => q,
Err(e) => {
return WorkerResponse::Error {
kind: "bad_request".into(),
message: format!("parse quant: {e}"),
};
}
};
if self.models.contains_key(&model_id) {
return WorkerResponse::Error {
kind: "already_loaded".into(),
message: format!("model '{model_id}' already loaded on this rank"),
};
}
let comm = match self.nccl.comm() {
Some(c) => c,
None => {
return WorkerResponse::Error {
kind: "nccl_not_initialised".into(),
message: "LoadDenseShard requires Init to have completed first".into(),
};
}
};
// Peek at model_type so we know which architecture to build.
let model_type = serde_json::from_str::<serde_json::Value>(&config_json)
.ok()
.as_ref()
.and_then(|v| v.get("model_type"))
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let device = match Device::new_cuda(self.config.cuda_device as usize) {
Ok(d) => d,
Err(e) => {
return WorkerResponse::Error {
kind: "cuda_unavailable".into(),
message: format!("Device::new_cuda({}) failed: {e}", self.config.cuda_device),
};
}
};
let paths: Vec<PathBuf> = safetensors_paths.into_iter().map(PathBuf::from).collect();
// SAFETY: same invariant as the single-GPU dense path — the HF
// cache files are treated as immutable while the mmap is held.
let vb = match unsafe { ShardedSafeTensors::var_builder(&paths, DType::BF16, &device) } {
Ok(v) => v,
Err(e) => {
return WorkerResponse::Error {
kind: "load_failed".into(),
message: format!("ShardedSafeTensors::var_builder: {e}"),
};
}
};
// Separate mmap of the same paths for the direct fused-region
// loader in `fused_load`. Linux's page cache shares the
// underlying pages between the two mmaps; the cost is one
// extra set of safetensors-header parses.
let mmap = match unsafe { candle_core::safetensors::MmapedSafetensors::multi(&paths) } {
Ok(m) => m,
Err(e) => {
return WorkerResponse::Error {
kind: "load_failed".into(),
message: format!("MmapedSafetensors::multi: {e}"),
};
}
};
let loaded = match model_type.as_str() {
"qwen3" => {
let cfg: qwen3_dense::Config = match serde_json::from_str(&config_json) {
Ok(c) => c,
Err(e) => {
return WorkerResponse::Error {
kind: "bad_request".into(),
message: format!("parse Qwen3 Config JSON: {e}"),
};
}
};
match TpQwen3ForCausalLM::load(
&cfg,
&vb,
self.config.rank,
self.config.world_size,
comm,
) {
Ok(m) => WorkerModel::Qwen3(m),
Err(e) => {
return WorkerResponse::Error {
kind: "load_failed".into(),
message: format!("TpQwen3ForCausalLM::load: {e:#}"),
};
}
}
}
"qwen3_5" => {
let cfg: qwen3_5_arch::Config = match serde_json::from_str(&config_json) {
Ok(c) => c,
Err(e) => {
return WorkerResponse::Error {
kind: "bad_request".into(),
message: format!("parse Qwen3-Next Config JSON: {e}"),
};
}
};
match TpQwen3_5ForCausalLM::load(
cfg,
&vb,
&mmap,
self.config.rank,
self.config.world_size,
comm,
quant_dtype,
) {
Ok(m) => WorkerModel::Qwen3_5(m),
Err(e) => {
return WorkerResponse::Error {
kind: "load_failed".into(),
message: format!("TpQwen3_5ForCausalLM::load: {e:#}"),
};
}
}
}
other => {
return WorkerResponse::Error {
kind: "unsupported_arch".into(),
message: format!(
"worker: unsupported model_type '{other}' (supported: qwen3, qwen3_5)"
),
};
}
};
self.models.insert(model_id.clone(), loaded);
tracing::info!(
rank = self.config.rank,
model = %model_id,
model_type = %model_type,
"loaded TP shard"
);
WorkerResponse::LoadDenseShardOk
}
#[cfg(not(feature = "cuda"))]
fn handle_load_dense_shard(
&mut self,
_model_id: String,
_config_json: String,
_safetensors_paths: Vec<String>,
_quant: Option<String>,
) -> WorkerResponse {
WorkerResponse::Error {
kind: "cuda_feature_not_enabled".into(),
message: "LoadDenseShard requires --features cuda".into(),
}
}
#[cfg(feature = "cuda")]
fn handle_generate_step(
&mut self,
model_id: &str,
tokens: Vec<u32>,
offset: usize,
) -> WorkerResponse {
use candle_core::Tensor;
let Some(model) = self.models.get_mut(model_id) else {
return WorkerResponse::Error {
kind: "model_not_loaded".into(),
message: format!("model '{model_id}' not loaded on rank {}", self.config.rank),
};
};
let device = model.device().clone();
let input = match Tensor::new(tokens.as_slice(), &device).and_then(|t| t.unsqueeze(0)) {
Ok(t) => t,
Err(e) => {
return WorkerResponse::Error {
kind: "forward_failed".into(),
message: format!("build input tensor: {e}"),
};
}
};
let start = std::time::Instant::now();
tracing::debug!(
rank = self.config.rank,
model = %model_id,
tokens = tokens.len(),
offset,
"worker GenerateStep: forward starting"
);
// Drop the resulting logits — the leader uses its own copy from
// rank 0. The forward's value here is the NCCL collectives it
// issues, which let the leader's rank-0 forward make progress.
if let Err(e) = model.forward(&input, offset) {
tracing::warn!(
rank = self.config.rank,
model = %model_id,
elapsed_ms = start.elapsed().as_millis(),
error = %e,
"worker GenerateStep: forward failed"
);
return WorkerResponse::Error {
kind: "forward_failed".into(),
message: format!("TP forward: {e}"),
};
}
tracing::debug!(
rank = self.config.rank,
model = %model_id,
elapsed_ms = start.elapsed().as_millis(),
"worker GenerateStep: forward done"
);
WorkerResponse::GenerateStepOk
}
#[cfg(not(feature = "cuda"))]
fn handle_generate_step(
&mut self,
_model_id: &str,
_tokens: Vec<u32>,
_offset: usize,
) -> WorkerResponse {
WorkerResponse::Error {
kind: "cuda_feature_not_enabled".into(),
message: "GenerateStep requires --features cuda".into(),
}
}
#[cfg(feature = "cuda")]
fn handle_clear_kv_cache(&mut self, model_id: &str) -> WorkerResponse {
let Some(model) = self.models.get_mut(model_id) else {
return WorkerResponse::Error {
kind: "model_not_loaded".into(),
message: format!("model '{model_id}' not loaded on rank {}", self.config.rank),
};
};
model.clear_kv_cache();
WorkerResponse::KvCacheCleared
}
#[cfg(not(feature = "cuda"))]
fn handle_clear_kv_cache(&mut self, _model_id: &str) -> WorkerResponse {
WorkerResponse::Error {
kind: "cuda_feature_not_enabled".into(),
message: "ClearKvCache requires --features cuda".into(),
}
}
#[cfg(feature = "cuda")]
fn handle_unload_model(&mut self, model_id: &str) -> WorkerResponse {
if self.models.remove(model_id).is_none() {
return WorkerResponse::Error {
kind: "model_not_loaded".into(),
message: format!("model '{model_id}' not loaded on rank {}", self.config.rank),
};
}
tracing::info!(rank = self.config.rank, model = %model_id, "unloaded TP shard");
WorkerResponse::Unloaded
}
#[cfg(not(feature = "cuda"))]
fn handle_unload_model(&mut self, _model_id: &str) -> WorkerResponse {
WorkerResponse::Error {
kind: "cuda_feature_not_enabled".into(),
message: "UnloadModel requires --features cuda".into(),
}
}
}
/// Parse a `ModelSpec.quant` string into a `GgmlDType`. Accepts the
/// common ggml format names (case-insensitive). `None` and `Some("")`
/// both map to "no quantization".
///
/// Supported: `q4_0`, `q4_1`, `q5_0`, `q5_1`, `q8_0`, `q8_1`,
/// `q2k`/`q2_k`, `q3k`/`q3_k`, `q4k`/`q4_k`, `q5k`/`q5_k`,
/// `q6k`/`q6_k`, `q8k`/`q8_k`, `f16`, `bf16`, `f32`. The underscore
/// is optional and the prefix is case-insensitive.
#[cfg(feature = "cuda")]
pub(crate) fn parse_quant_string(
s: Option<&str>,
) -> anyhow::Result<Option<candle_core::quantized::GgmlDType>> {
use candle_core::quantized::GgmlDType;
let s = match s {
Some(s) if !s.is_empty() => s,
_ => return Ok(None),
};
let normalised = s.to_ascii_lowercase().replace('_', "");
let dtype = match normalised.as_str() {
"q40" => GgmlDType::Q4_0,
"q41" => GgmlDType::Q4_1,
"q50" => GgmlDType::Q5_0,
"q51" => GgmlDType::Q5_1,
"q80" => GgmlDType::Q8_0,
"q81" => GgmlDType::Q8_1,
"q2k" => GgmlDType::Q2K,
"q3k" => GgmlDType::Q3K,
"q4k" | "q4km" => GgmlDType::Q4K,
"q5k" | "q5km" => GgmlDType::Q5K,
"q6k" => GgmlDType::Q6K,
"q8k" => GgmlDType::Q8K,
"f16" => GgmlDType::F16,
"bf16" => GgmlDType::BF16,
"f32" => GgmlDType::F32,
other => anyhow::bail!(
"unknown quant '{other}' (expected one of: q4_0, q4_1, q5_0, q5_1, q8_0, \
q8_1, q2k, q3k, q4k, q5k, q6k, q8k, f16, bf16, f32)"
),
};
Ok(Some(dtype))
}

View File

@@ -1,7 +1,5 @@
pub mod api;
pub mod config;
pub mod cuda;
pub mod discovery;
pub mod harness;
pub mod health;
pub mod startup;

View File

@@ -1,66 +1,21 @@
use anyhow::{Context, Result};
use anyhow::Result;
use clap::Parser;
use neuron::{
api,
config::NeuronConfig,
discovery,
harness::{HarnessRegistry, tp},
health, startup,
};
use neuron::{api, config::NeuronConfig, discovery, harness::HarnessRegistry, health};
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::RwLock;
use tracing_subscriber::EnvFilter;
/// Top-level CLI. The same binary runs as either the public neuron
/// daemon (default), a tensor-parallel worker subprocess (when
/// `--worker` is set, spawned by the leader on the same host), or a
/// one-shot TP NCCL handshake check (when `--tp-smoke` is set).
#[derive(Parser)]
#[command(name = "neuron")]
#[command(about = "Per-node daemon for cortex inference clusters")]
#[command(version)]
struct Args {
/// Run in tensor-parallel worker mode. The leader process spawns
/// one of these per non-zero NCCL rank and drives it over
/// newline-delimited JSON on stdin/stdout. Worker mode skips
/// discovery, the HTTP listener, and the health poller — it's a
/// pure RPC loop.
#[arg(long, default_value_t = false)]
worker: bool,
/// Run a one-shot TP smoke test: spawn `--tp-size - 1` worker
/// subprocesses on `--cuda-devices`, build the NCCL communicator,
/// run an `AllReduce` sanity check across every rank, and exit.
/// Used to validate the TP plumbing in isolation from model load
/// and inference. Diagnostic-only — not exposed through the daemon
/// HTTP API.
#[arg(long, default_value_t = false)]
tp_smoke: bool,
/// NCCL rank for worker mode. Ignored when `--worker` is not set.
#[arg(long, default_value_t = 0)]
rank: u32,
/// Total NCCL world size for worker mode or TP smoke mode.
#[arg(long, default_value_t = 1)]
tp_size: u32,
/// CUDA device index for worker mode. Ignored when `--worker` is
/// not set.
#[arg(long, default_value_t = 0)]
cuda_device: u32,
/// Comma-separated CUDA device indices for TP smoke mode (one per
/// rank, starting with rank 0). Must have `tp_size` entries.
#[arg(long, value_delimiter = ',')]
cuda_devices: Vec<u32>,
/// Port to listen on (overrides config file). Daemon mode only.
/// Port to listen on (overrides config file).
#[arg(short, long)]
port: Option<u16>,
/// Path to the neuron config file. Daemon mode only.
/// Path to the neuron config file.
#[arg(short, long, default_value = "neuron.toml")]
config: String,
}
@@ -68,7 +23,6 @@ struct Args {
#[tokio::main]
async fn main() -> Result<()> {
tracing_subscriber::fmt()
.with_writer(std::io::stderr)
.with_env_filter(
EnvFilter::try_from_default_env()
.unwrap_or_else(|_| EnvFilter::new("info,neuron=debug")),
@@ -77,85 +31,12 @@ async fn main() -> Result<()> {
let args = Args::parse();
if args.worker {
return tp::worker::run(tp::worker::WorkerConfig {
rank: args.rank,
world_size: args.tp_size,
cuda_device: args.cuda_device,
})
.await;
}
if args.tp_smoke {
return tp_smoke(args.tp_size, args.cuda_devices).await;
}
daemon(args).await
}
/// One-shot tensor-parallel handshake. Spawns N-1 worker subprocesses
/// (rank 0 stays in this process), builds the NCCL communicator across
/// the full world, runs an AllReduce sanity check, and shuts everyone
/// down. Output is plain log lines on stderr + a final summary on
/// stdout in `key=value` form so an outer script can parse it.
async fn tp_smoke(tp_size: u32, cuda_devices: Vec<u32>) -> Result<()> {
if tp_size < 2 {
anyhow::bail!("--tp-size must be at least 2 (got {tp_size})");
}
if cuda_devices.len() as u32 != tp_size {
anyhow::bail!(
"--cuda-devices must list exactly {tp_size} entries (got {})",
cuda_devices.len()
);
}
let exe = std::env::current_exe().context("resolve current_exe for worker spawn")?;
let leader_device = cuda_devices[0];
tracing::info!(
tp_size,
?cuda_devices,
binary = %exe.display(),
"tp-smoke: spawning worker pool"
);
let mut pool = tp::WorkerPool::spawn(&exe, tp_size, &cuda_devices).await?;
tracing::info!("tp-smoke: pinging every worker");
let pongs = pool.ping_all().await?;
for p in &pongs {
tracing::info!(?p, "tp-smoke: pong");
}
tracing::info!(leader_device, "tp-smoke: initialising NCCL");
pool.init_nccl(leader_device).await?;
tracing::info!("tp-smoke: running AllReduce sanity check");
pool.nccl_sanity_check().await?;
tracing::info!("tp-smoke: shutting down pool");
pool.shutdown().await?;
println!("status=ok");
println!("tp_size={tp_size}");
println!(
"cuda_devices={}",
cuda_devices
.iter()
.map(|d| d.to_string())
.collect::<Vec<_>>()
.join(",")
);
Ok(())
}
async fn daemon(args: Args) -> Result<()> {
let cfg = NeuronConfig::load(&args.config).unwrap_or_else(|e| {
tracing::warn!(path = %args.config, error = %e, "config not found, using defaults");
NeuronConfig::default()
});
let port = args.port.unwrap_or(cfg.port);
let bind_url = format!("http://localhost:{port}");
let start_time = Instant::now();
tracing::info!("running hardware discovery");
@@ -166,18 +47,9 @@ async fn daemon(args: Args) -> Result<()> {
"discovery complete"
);
// Build harness registry from config. In-process harnesses (candle)
// need to know neuron's own bind URL so they can return it from
// inference_endpoint.
let registry = HarnessRegistry::from_configs(&cfg.harnesses, &bind_url, &cfg.harness);
// Build harness registry from config.
let registry = HarnessRegistry::from_configs(&cfg.harnesses);
discovery_result.harnesses = registry.names();
let candle = registry.candle();
// Activation: load default models before binding the listener.
// Each load may take tens of seconds to several minutes depending
// on model size and HF cache state — keep TimeoutStartSec in the
// systemd unit generous enough to cover the slowest entry.
startup::load_default_models(&registry, &cfg.default_models).await;
let health_cache = Arc::new(health::HealthCache::new());
health_cache
@@ -193,24 +65,13 @@ async fn daemon(args: Args) -> Result<()> {
discovery: discovery_result,
health_cache,
registry: RwLock::new(registry),
candle,
});
let app = api::neuron_routes().with_state(Arc::clone(&state));
let app = api::neuron_routes().with_state(state);
let addr: std::net::SocketAddr = format!("0.0.0.0:{port}").parse()?;
tracing::info!("neuron listening on {addr}");
let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(listener, app)
.with_graceful_shutdown(startup::shutdown_signal())
.await?;
// Deactivation: serve has returned (graceful shutdown signal
// received and connections drained). Release CUDA contexts / VRAM
// by unloading every model before exiting; systemd's TimeoutStopSec
// bounds how long this phase may take.
let registry = state.registry.read().await;
startup::unload_all_models(&registry).await;
tracing::info!("shutdown complete");
axum::serve(listener, app).await?;
Ok(())
}

View File

@@ -1,97 +0,0 @@
//! Activation- and deactivation-time orchestration.
//!
//! Wired from `main.rs` around the HTTP listener — activation runs
//! before bind, deactivation runs after axum returns from its
//! graceful-shutdown future. Kept in its own module so the logic is
//! unit-testable without spinning up a full neuron process.
use crate::harness::HarnessRegistry;
use cortex_core::harness::ModelSpec;
use std::time::Instant;
use tokio::signal;
/// Load each spec sequentially against the registry, treating
/// individual failures as warnings rather than fatal errors.
///
/// VRAM contention makes parallel loads risky; the sequential path is
/// boring but correct. The function logs elapsed time per load so an
/// operator can see which model is hogging activation.
pub async fn load_default_models(registry: &HarnessRegistry, specs: &[ModelSpec]) {
if specs.is_empty() {
return;
}
tracing::info!(count = specs.len(), "loading default models");
for spec in specs {
let start = Instant::now();
match registry.load_model(spec).await {
Ok(()) => tracing::info!(
model = %spec.model_id,
elapsed_ms = start.elapsed().as_millis() as u64,
"loaded default model"
),
Err(e) => tracing::warn!(
model = %spec.model_id,
error = %e,
elapsed_ms = start.elapsed().as_millis() as u64,
"failed to load default model, continuing"
),
}
}
}
/// Future that resolves on SIGINT (Ctrl-C) or SIGTERM (systemd stop).
///
/// Wired into `axum::serve(...).with_graceful_shutdown(shutdown_signal())`
/// so the HTTP listener stops accepting new connections, lets in-flight
/// requests drain, and then yields control back to main for cleanup.
pub async fn shutdown_signal() {
let ctrl_c = async {
signal::ctrl_c().await.ok();
};
let terminate = async {
signal::unix::signal(signal::unix::SignalKind::terminate())
.expect("install SIGTERM handler")
.recv()
.await;
};
tokio::select! {
_ = ctrl_c => tracing::info!("received SIGINT, shutting down"),
_ = terminate => tracing::info!("received SIGTERM, shutting down"),
}
}
/// Unload every model currently registered. Called from `main.rs` after
/// axum's graceful shutdown future resolves, so CUDA contexts and VRAM
/// are released before the process exits rather than left to the OS to
/// reclaim. Per-model failures are logged and skipped — keep cleanup
/// going even when one harness is unhealthy.
pub async fn unload_all_models(registry: &HarnessRegistry) {
let listed = match registry.list_all_models().await {
Ok(m) => m,
Err(e) => {
tracing::warn!(error = %e, "failed to list models during shutdown");
return;
}
};
if listed.is_empty() {
return;
}
tracing::info!(count = listed.len(), "unloading models for shutdown");
for model in listed {
let start = Instant::now();
match registry.unload_model(&model.id).await {
Ok(()) => tracing::info!(
model = %model.id,
elapsed_ms = start.elapsed().as_millis() as u64,
"unloaded"
),
Err(e) => tracing::warn!(
model = %model.id,
error = %e,
"unload failed during shutdown"
),
}
}
}

View File

@@ -1,56 +0,0 @@
//! Activation-time behaviour: load_default_models continues past
//! individual failures so a single broken catalogue entry doesn't
//! prevent the rest of the fleet from starting.
use cortex_core::harness::{HarnessConfig, ModelSpec};
use neuron::config::HarnessSettings;
use neuron::harness::HarnessRegistry;
use neuron::startup;
#[tokio::test]
async fn test_load_default_models_skips_unknown_harness() {
let registry = HarnessRegistry::from_configs(
&[HarnessConfig {
name: "candle".into(),
}],
"http://localhost:0",
&HarnessSettings::default(),
);
// Both entries fail synchronously inside the registry — no network
// call escapes (the harness lookup mismatches before hf-hub is
// touched). The function should still return cleanly.
let specs = vec![
ModelSpec {
model_id: "model-a".into(),
harness: "no-such-harness".into(),
quant: None,
tensor_parallel: None,
devices: None,
},
ModelSpec {
model_id: "model-b".into(),
harness: "no-such-harness".into(),
quant: None,
tensor_parallel: None,
devices: None,
},
];
startup::load_default_models(&registry, &specs).await;
let listed = registry
.list_all_models()
.await
.expect("list_all_models should succeed");
assert!(
listed.is_empty(),
"no models should be loaded after failed entries"
);
}
#[tokio::test]
async fn test_load_default_models_empty_is_noop() {
let registry = HarnessRegistry::new();
startup::load_default_models(&registry, &[]).await;
}

View File

@@ -14,7 +14,6 @@ async fn spawn_neuron(discovery: DiscoveryResponse) -> String {
discovery,
health_cache,
registry: RwLock::new(registry),
candle: None,
});
let app = api::neuron_routes().with_state(state);
@@ -136,30 +135,56 @@ async fn test_models_empty_registry() {
assert!(body.as_array().unwrap().is_empty());
}
/// Verify the candle harness registers, list is empty by default, and a
/// load attempt for an obviously-bogus model id returns a 4xx error
/// without crashing the daemon. Real load/unload exercising actual GGUF
/// download is covered by `tests/candle_lifecycle.rs` (cuda-integration).
/// Spawn a mock mistral.rs backend and a neuron with the mistralrs harness
/// pointing at it, then test the full model lifecycle through neuron's API.
#[tokio::test]
async fn test_candle_harness_registers_and_rejects_bogus_model() {
async fn test_models_via_mistralrs_harness() {
use axum::routing::{get, post};
use axum::{Json, Router};
use cortex_core::harness::HarnessConfig;
use neuron::config::HarnessSettings;
use serde_json::Value;
let registry = HarnessRegistry::from_configs(
&[HarnessConfig {
name: "candle".into(),
}],
"http://localhost:13131",
&HarnessSettings::default(),
);
// Mock mistral.rs backend.
let mock_app = Router::new()
.route(
"/v1/models",
get(|| async {
Json(json!({
"data": [
{"id": "test-model", "status": "loaded"},
{"id": "other-model", "status": "unloaded"}
]
}))
}),
)
.route(
"/v1/models/unload",
post(|Json(_body): Json<Value>| async { Json(json!({"status": "ok"})) }),
)
.route(
"/v1/models/reload",
post(|Json(_body): Json<Value>| async { Json(json!({"status": "ok"})) }),
);
let mock_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let mock_addr = mock_listener.local_addr().unwrap();
tokio::spawn(async move {
axum::serve(mock_listener, mock_app).await.unwrap();
});
let mock_url = format!("http://{mock_addr}");
// Build neuron with mistralrs harness pointing at mock.
let registry = HarnessRegistry::from_configs(&[HarnessConfig {
name: "mistralrs".into(),
endpoint: Some(mock_url.clone()),
systemd_unit: None,
}]);
let candle = registry.candle();
let health_cache = Arc::new(HealthCache::new());
let state = Arc::new(NeuronState {
discovery: fake_discovery(),
health_cache,
registry: RwLock::new(registry),
candle,
});
let app = api::neuron_routes().with_state(state);
@@ -172,6 +197,7 @@ async fn test_candle_harness_registers_and_rejects_bogus_model() {
let client = reqwest::Client::new();
// GET /models — should return models from mock mistralrs.
let resp = client
.get(format!("{neuron_url}/models"))
.send()
@@ -179,140 +205,45 @@ async fn test_candle_harness_registers_and_rejects_bogus_model() {
.unwrap();
assert_eq!(resp.status(), 200);
let models: Vec<serde_json::Value> = resp.json().await.unwrap();
assert!(models.is_empty());
assert_eq!(models.len(), 2);
assert_eq!(models[0]["id"], "test-model");
assert_eq!(models[0]["harness"], "mistralrs");
assert_eq!(models[0]["status"], "loaded");
assert_eq!(models[1]["id"], "other-model");
assert_eq!(models[1]["status"], "unloaded");
// Sending a wrong-harness spec should be rejected synchronously
// without touching the network or the model registry.
// GET /models/test-model/endpoint — should return mock URL.
let resp = client
.get(format!("{neuron_url}/models/test-model/endpoint"))
.send()
.await
.unwrap();
assert_eq!(resp.status(), 200);
let body: serde_json::Value = resp.json().await.unwrap();
assert_eq!(body["url"], mock_url);
// POST /models/unload — should succeed.
let resp = client
.post(format!("{neuron_url}/models/unload"))
.json(&json!({"model_id": "test-model"}))
.send()
.await
.unwrap();
assert_eq!(resp.status(), 200);
let body: serde_json::Value = resp.json().await.unwrap();
assert_eq!(body["status"], "unloaded");
// POST /models/load — should succeed.
let resp = client
.post(format!("{neuron_url}/models/load"))
.json(&json!({"model_id": "definitely/not-real", "harness": "not-candle"}))
.send()
.await
.unwrap();
assert_eq!(resp.status(), 400);
// Registry still empty.
let resp = client
.get(format!("{neuron_url}/models"))
.send()
.await
.unwrap();
let models: Vec<serde_json::Value> = resp.json().await.unwrap();
assert!(models.is_empty());
}
/// `/v1/chat/completions` returns 503 when no candle harness is registered.
#[tokio::test]
async fn test_chat_completions_no_candle_harness() {
let registry = HarnessRegistry::new();
let health_cache = Arc::new(HealthCache::new());
let state = Arc::new(NeuronState {
discovery: fake_discovery(),
health_cache,
registry: RwLock::new(registry),
candle: None,
});
let app = api::neuron_routes().with_state(state);
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});
let url = format!("http://{addr}");
let resp = reqwest::Client::new()
.post(format!("{url}/v1/chat/completions"))
.json(&json!({
"model": "anything",
"messages": [{"role": "user", "content": "hi"}]
"model_id": "test-model",
"harness": "mistralrs"
}))
.send()
.await
.unwrap();
assert_eq!(resp.status(), 503);
}
/// `/v1/chat/completions` returns 404 when the requested model isn't loaded.
#[tokio::test]
async fn test_chat_completions_model_not_loaded() {
use cortex_core::harness::HarnessConfig;
use neuron::config::HarnessSettings;
let registry = HarnessRegistry::from_configs(
&[HarnessConfig {
name: "candle".into(),
}],
"http://localhost:0",
&HarnessSettings::default(),
);
let candle = registry.candle();
let health_cache = Arc::new(HealthCache::new());
let state = Arc::new(NeuronState {
discovery: fake_discovery(),
health_cache,
registry: RwLock::new(registry),
candle,
});
let app = api::neuron_routes().with_state(state);
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});
let url = format!("http://{addr}");
let resp = reqwest::Client::new()
.post(format!("{url}/v1/chat/completions"))
.json(&json!({
"model": "definitely/not-loaded",
"messages": [{"role": "user", "content": "hi"}]
}))
.send()
.await
.unwrap();
assert_eq!(resp.status(), 404);
}
/// `/v1/chat/completions` with `stream: true` returns 404 when the
/// model isn't loaded — same surface as the non-streaming path. The
/// streaming code only kicks in once the model lookup succeeds.
#[tokio::test]
async fn test_chat_completions_streaming_model_not_loaded() {
use cortex_core::harness::HarnessConfig;
use neuron::config::HarnessSettings;
let registry = HarnessRegistry::from_configs(
&[HarnessConfig {
name: "candle".into(),
}],
"http://localhost:0",
&HarnessSettings::default(),
);
let candle = registry.candle();
let health_cache = Arc::new(HealthCache::new());
let state = Arc::new(NeuronState {
discovery: fake_discovery(),
health_cache,
registry: RwLock::new(registry),
candle,
});
let app = api::neuron_routes().with_state(state);
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});
let url = format!("http://{addr}");
let resp = reqwest::Client::new()
.post(format!("{url}/v1/chat/completions"))
.json(&json!({
"model": "definitely/not-loaded",
"messages": [{"role": "user", "content": "hi"}],
"stream": true
}))
.send()
.await
.unwrap();
assert_eq!(resp.status(), 404);
assert_eq!(resp.status(), 200);
let body: serde_json::Value = resp.json().await.unwrap();
assert_eq!(body["status"], "loaded");
}

View File

@@ -1,87 +0,0 @@
//! Real model load/unload lifecycle through the candle harness.
//!
//! Gated behind the `cuda-integration` feature because it downloads a
//! real (small) GGUF from HuggingFace and materialises tensors on the
//! configured device. Run on a host with network access and either a
//! CUDA GPU (when built with `--features cuda`) or enough CPU RAM to
//! hold the model.
//!
//! Usage:
//! cargo test -p neuron --features cuda-integration --test candle_lifecycle
//!
//! Optional environment variables:
//! NEURON_TEST_MODEL_ID — HuggingFace repo to load (default: a small
//! public Qwen3 GGUF repo).
//! NEURON_TEST_QUANT — quant substring matched against GGUF
//! filenames (default: "Q4_K_M").
//! HF_HOME — HuggingFace cache directory.
#![cfg(feature = "cuda-integration")]
use cortex_core::harness::{HarnessConfig, ModelSpec};
use neuron::config::HarnessSettings;
use neuron::harness::HarnessRegistry;
use std::path::PathBuf;
#[tokio::test]
async fn test_candle_qwen3_load_unload_lifecycle() {
let _ = tracing_subscriber::fmt()
.with_test_writer()
.with_env_filter("info,neuron=debug")
.try_init();
let model_id = std::env::var("NEURON_TEST_MODEL_ID")
.unwrap_or_else(|_| "Qwen/Qwen3-0.6B-GGUF".to_string());
let quant = std::env::var("NEURON_TEST_QUANT").unwrap_or_else(|_| "Q4_K_M".to_string());
let mut settings = HarnessSettings::default();
if let Ok(home) = std::env::var("HF_HOME") {
settings.candle.hf_cache = Some(PathBuf::from(home));
}
let registry = HarnessRegistry::from_configs(
&[HarnessConfig {
name: "candle".into(),
}],
"http://localhost:13131",
&settings,
);
let spec = ModelSpec {
model_id: model_id.clone(),
harness: "candle".into(),
quant: Some(quant),
tensor_parallel: None,
devices: Some(vec![0]),
};
registry
.load_model(&spec)
.await
.expect("load_model should succeed");
let models = registry.list_all_models().await.expect("list_all_models");
assert_eq!(models.len(), 1, "expected exactly one loaded model");
assert_eq!(models[0].id, model_id);
assert_eq!(models[0].harness, "candle");
assert_eq!(models[0].status, "loaded");
let url = registry.inference_endpoint(&model_id).await;
assert_eq!(url, Some("http://localhost:13131".into()));
// Re-loading the same model should be rejected.
let again = registry.load_model(&spec).await;
assert!(again.is_err(), "second load should error");
registry
.unload_model(&model_id)
.await
.expect("unload_model should succeed");
let models = registry.list_all_models().await.expect("list_all_models");
assert!(models.is_empty(), "registry should be empty after unload");
// Unloading a model that isn't loaded should error.
let err = registry.unload_model(&model_id).await;
assert!(err.is_err(), "unload of missing model should error");
}

View File

@@ -1,32 +0,0 @@
//! Deactivation behaviour: unload_all_models tolerates an empty
//! registry and continues past per-model unload failures.
use cortex_core::harness::HarnessConfig;
use neuron::config::HarnessSettings;
use neuron::harness::HarnessRegistry;
use neuron::startup;
#[tokio::test]
async fn test_unload_all_models_empty_registry_is_noop() {
let registry = HarnessRegistry::new();
startup::unload_all_models(&registry).await;
}
#[tokio::test]
async fn test_unload_all_models_with_no_loaded_models() {
let registry = HarnessRegistry::from_configs(
&[HarnessConfig {
name: "candle".into(),
}],
"http://localhost:0",
&HarnessSettings::default(),
);
startup::unload_all_models(&registry).await;
let listed = registry
.list_all_models()
.await
.expect("list_all_models should still succeed after shutdown cleanup");
assert!(listed.is_empty());
}

View File

@@ -1,145 +0,0 @@
//! Stage 7a-i: confirm the TP worker subprocess lifecycle round-trips.
//!
//! Spawns two worker subprocesses via the leader→worker stdio RPC,
//! pings each, and cleanly shuts them down. No CUDA required —
//! `Init` and `NcclSanityCheck` are stubbed in 7a-i, so this test
//! runs on any host the workspace builds on.
use neuron::harness::tp::{WorkerPool, rpc::WorkerResponse};
/// Path to the neuron binary built by cargo for this test process.
/// cargo populates `CARGO_BIN_EXE_neuron` at compile time for sibling-
/// binary tests; production paths in main.rs use `/proc/self/exe`.
const NEURON_BIN: &str = env!("CARGO_BIN_EXE_neuron");
/// Two workers (so we spawn one subprocess: rank 0 is in-process,
/// rank 1 is the child). Verify the spawned worker responds to Ping
/// with its own identity, then shut it down cleanly.
#[tokio::test]
async fn test_spawn_ping_shutdown() {
// cuda_devices: rank 0 → device 0 (leader, unused here),
// rank 1 → device 1 (worker; not actually opened in 7a-i).
let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 2, &[0, 1])
.await
.expect("spawn worker pool");
let pongs = pool.ping_all().await.expect("ping all workers");
assert_eq!(pongs.len(), 1, "expected one Pong (rank 1 only)");
match &pongs[0] {
WorkerResponse::Pong {
rank,
world_size,
cuda_device,
} => {
assert_eq!(*rank, 1);
assert_eq!(*world_size, 2);
assert_eq!(*cuda_device, 1);
}
other => panic!("expected Pong, got {other:?}"),
}
pool.shutdown().await.expect("clean shutdown");
}
/// Three workers — exercise the loop in `ping_all` / `shutdown`.
#[tokio::test]
async fn test_spawn_three_workers() {
let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 3, &[0, 1, 2])
.await
.expect("spawn worker pool");
let pongs = pool.ping_all().await.expect("ping all workers");
assert_eq!(pongs.len(), 2, "expected two Pongs (ranks 1 and 2)");
for (i, resp) in pongs.iter().enumerate() {
match resp {
WorkerResponse::Pong {
rank,
world_size,
cuda_device,
} => {
let expected_rank = (i + 1) as u32;
assert_eq!(*rank, expected_rank);
assert_eq!(*world_size, 3);
assert_eq!(*cuda_device, expected_rank);
}
other => panic!("expected Pong, got {other:?}"),
}
}
pool.shutdown().await.expect("clean shutdown");
}
/// 7a-ii: without the cuda feature, Init must fail with a clear
/// `cuda_feature_not_enabled` marker rather than silently succeeding.
/// This is the local-dev-box test; the real NCCL handshake is exercised
/// by `tp_worker_lifecycle_cuda.rs` (gated on `cuda-integration`).
#[tokio::test]
async fn test_init_returns_cuda_feature_not_enabled_without_cuda() {
use neuron::harness::tp::rpc::WorkerRequest;
use std::process::Stdio;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::Command;
// Spawn a single worker by hand to send Init directly (the pool's
// public API doesn't expose Init yet — that lands in 7a-ii).
let mut child = Command::new(NEURON_BIN)
.arg("--worker")
.arg("--rank")
.arg("1")
.arg("--tp-size")
.arg("2")
.arg("--cuda-device")
.arg("1")
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::null())
.kill_on_drop(true)
.spawn()
.expect("spawn worker");
let mut stdin = child.stdin.take().expect("stdin");
let stdout = child.stdout.take().expect("stdout");
let mut lines = BufReader::new(stdout).lines();
let req = WorkerRequest::Init {
comm_id: "ff".repeat(128),
};
let mut payload = serde_json::to_string(&req).unwrap();
payload.push('\n');
stdin.write_all(payload.as_bytes()).await.unwrap();
stdin.flush().await.unwrap();
let reply = lines
.next_line()
.await
.expect("read line")
.expect("got line");
let resp: WorkerResponse = serde_json::from_str(&reply).expect("parse reply");
match resp {
WorkerResponse::Error { kind, .. } => {
#[cfg(feature = "cuda")]
{
// With cuda enabled the response depends on whether
// CUDA hardware is actually present. Accept either
// the success contract or a real NCCL failure.
let _ = kind;
}
#[cfg(not(feature = "cuda"))]
assert_eq!(kind, "cuda_feature_not_enabled");
}
WorkerResponse::InitOk => {
// Real NCCL succeeded — only possible with cuda feature
// AND a working NCCL stack AND another rank actually
// joining. Don't fail; just acknowledge.
#[cfg(not(feature = "cuda"))]
panic!("InitOk without cuda feature is impossible");
}
other => panic!("expected Error or InitOk, got {other:?}"),
}
// Clean shutdown.
stdin.write_all(b"{\"op\":\"shutdown\"}\n").await.unwrap();
stdin.flush().await.unwrap();
let _ = lines.next_line().await; // Bye
let _ = child.wait().await;
}

View File

@@ -1,43 +0,0 @@
//! Stage 7a-ii: real NCCL handshake across the worker pool.
//!
//! Gated behind the `cuda-integration` feature because it requires
//! libnccl AND multiple CUDA devices on the running host. Run on
//! beast (2× RTX 5090) via:
//!
//! cargo test -p neuron --features cuda-integration \
//! --test tp_worker_lifecycle_cuda
//!
//! Steps: spawn N-1 workers, call `init_nccl`, run `nccl_sanity_check`
//! (every rank `all_reduce`s `1u32` with Sum; expected total =
//! world_size), shut down cleanly.
#![cfg(feature = "cuda-integration")]
use neuron::harness::tp::WorkerPool;
const NEURON_BIN: &str = env!("CARGO_BIN_EXE_neuron");
#[tokio::test]
async fn test_init_and_sanity_check_two_ranks() {
let _ = tracing_subscriber::fmt()
.with_test_writer()
.with_env_filter("info,neuron=debug")
.try_init();
// 2 ranks: leader = rank 0 on device 0, worker = rank 1 on device 1.
let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 2, &[0, 1])
.await
.expect("spawn worker pool");
pool.ping_all().await.expect("pong all workers");
pool.init_nccl(0)
.await
.expect("init_nccl: NCCL handshake across all ranks");
pool.nccl_sanity_check()
.await
.expect("nccl_sanity_check: observed_sum == world_size on all ranks");
pool.shutdown().await.expect("clean shutdown");
}

View File

@@ -1,7 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<service>
<short>cortex</short>
<description>Cortex — inference gateway for multi-node GPU clusters</description>
<port protocol="tcp" port="31313"/>
<port protocol="tcp" port="31314"/>
</service>

View File

@@ -1,6 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<service>
<short>helexa-neuron</short>
<description>Neuron — per-node GPU discovery and harness daemon for cortex</description>
<port protocol="tcp" port="13131"/>
</service>

View File

@@ -10,22 +10,6 @@ Restart=on-failure
RestartSec=5
User=neuron
Group=neuron
# /var/lib/neuron is the neuron user's $HOME — hf-hub writes its
# default cache there (~/.cache/huggingface/hub). Without this directive
# systemd doesn't create the directory and hf-hub downloads fail with
# "fetch GGUF <file>: failed to create cache dir".
StateDirectory=neuron
StateDirectoryMode=0755
# Loading default_models from neuron.toml happens before the HTTP
# listener binds; large models can take many minutes to download and
# materialise on first activation. systemd's default TimeoutStartSec
# (90s) is far too short; allow 30 minutes.
TimeoutStartSec=1800s
# On stop, neuron drains in-flight requests then unloads every model
# to release CUDA contexts cleanly. Allow generous time for big-model
# unloads; systemd will SIGKILL after this bound.
TimeoutStopSec=120s
KillSignal=SIGTERM
[Install]
WantedBy=multi-user.target

View File

@@ -2,49 +2,28 @@
#
# Copy to /etc/cortex/models.toml and adjust for your environment.
# Describes how to serve each model. Cortex matches these profiles
# against discovered neuron topologies for placement decisions; the
# resulting `(catalogue × topology)` set is what `GET /v1/models`
# returns and what the router can cold-load on demand.
#
# Field reference:
# id - HuggingFace model id, exact match.
# harness - which engine handles inference (currently "candle").
# quant - GGUF quantisation tag for the file in the HF repo
# (e.g. "Q4_K_M"). Omit/empty for the dense
# safetensors path. TP requires dense.
# vram_mb - rough estimate; advisory only, not enforced.
# min_devices - GPU count this profile needs. TP profiles use
# the same value as the tensor-parallel size.
# min_device_vram_mb - each device must meet this VRAM floor for the
# neuron to be considered "feasible".
# pinned_on - optional whitelist of neuron names. Non-empty
# narrows feasibility to just those neurons and
# protects the model from LRU eviction there.
# against discovered neuron topologies for placement decisions.
# Tensor-parallel target — needs a neuron with at least 2 large GPUs.
# The example pins to a specific neuron name; adjust or remove the
# pinned_on entry for your own fleet.
[[models]]
id = "Qwen/Qwen3.6-27B"
harness = "candle"
vram_mb = 54000
min_devices = 2
min_device_vram_mb = 24000
pinned_on = ["your-multi-gpu-neuron"]
# Mid-size dense model — fits on any single GPU with ≥16 GB VRAM.
[[models]]
id = "Qwen/Qwen3-8B"
harness = "candle"
vram_mb = 18000
min_devices = 1
min_device_vram_mb = 16000
# Small GGUF quantised — runs on any small GPU.
[[models]]
id = "unsloth/Qwen3-0.6B-GGUF"
harness = "candle"
id = "your-org/large-model"
harness = "mistralrs"
quant = "Q4_K_M"
vram_mb = 500
vram_mb = 19000
min_devices = 2
min_device_vram_mb = 10000
pinned_on = ["gpu-large"]
[[models]]
id = "your-org/medium-model"
harness = "mistralrs"
quant = "Q6_K"
vram_mb = 12000
min_devices = 1
pinned_on = ["gpu-medium"]
[[models]]
id = "your-org/embedding-model"
harness = "mistralrs"
quant = "Q8_0"
vram_mb = 8000
min_devices = 1
min_device_vram_mb = 4000

View File

@@ -3,51 +3,14 @@
# Copy to /etc/neuron/neuron.toml and adjust for your environment.
#
# Environment variable overrides use NEURON_ prefix with __ separators:
# NEURON_PORT=13131
# NEURON_PORT=9090
port = 13131
port = 9090
# -- Harnesses ---------------------------------------------------------------
# Each [[harnesses]] entry enables an inference engine. Currently only
# "candle" is supported — it runs in-process and uses huggingface/candle
# for inference on local CUDA devices (or CPU when CUDA is unavailable).
# Each [[harnesses]] entry declares an inference engine managed by neuron.
[[harnesses]]
name = "candle"
# -- Candle harness settings -------------------------------------------------
# Optional tuning for the candle harness.
[harness.candle]
# HuggingFace cache directory for model weights.
#
# Resolution order (first hit wins):
# 1. `hf_cache` here in this file.
# 2. `HF_HUB_CACHE` env var — same convention as the Python
# `huggingface_hub` library, so an existing cache directory shared
# with other tooling can be reused without per-tool config.
# 3. `HF_HOME` env var (cache appended as `$HF_HOME/hub`).
# 4. hf-hub's default (`~/.cache/huggingface/hub`).
#
# For per-host overrides (e.g. one neuron has an SSD with prefetched
# weights), prefer a systemd drop-in over editing this file:
# /etc/systemd/system/neuron.service.d/local.conf:
# [Service]
# Environment=HF_HUB_CACHE=/archive/hf-cache
# hf_cache = "/var/lib/neuron/hf-cache"
# -- Default models ----------------------------------------------------------
# Models listed here are loaded automatically when the neuron service
# activates. Loading is sequential — a slow or failing entry doesn't
# block the rest of the fleet, but it does push out the time before
# neuron starts serving HTTP, so keep the list short. Operators can
# load additional models on demand via POST /models/load.
#
# Make sure data/neuron.service's TimeoutStartSec is generous enough to
# cover the slowest entry's first-time download + materialisation.
# [[default_models]]
# model_id = "Qwen/Qwen3-0.6B-GGUF"
# harness = "candle"
# quant = "Q4_K_M"
# devices = [0]
name = "mistralrs"
endpoint = "http://localhost:8080"
systemd_unit = "mistralrs.service"

View File

@@ -1,10 +1,7 @@
Name: helexa-neuron
Version: 0.1.16
Name: neuron
Version: 0.1.8
Release: 1%{?dist}
Summary: Per-node GPU discovery and harness management daemon for cortex
# Package name disambiguates from Fedora's existing "neuron" package
# (NEURON neural simulation environment from Yale). Binary, systemd
# unit, and system user are still called "neuron" for brevity.
License: GPL-3.0-or-later
URL: https://git.lair.cafe/helexa/cortex
@@ -24,7 +21,6 @@ BuildRequires: systemd-rpm-macros
Requires(pre): shadow-utils
Requires: systemd
Requires: firewalld-filesystem
# systemd-rpm-macros ships a unit dep generator that parses User=/Group=
# from our .service file and emits Requires: user(neuron)/group(neuron).
@@ -37,9 +33,8 @@ Provides: user(neuron)
%description
Neuron is a per-node daemon for cortex inference clusters. It discovers
local GPU hardware via nvidia-smi, runs in-process inference via
huggingface/candle, and exposes an HTTP API for model lifecycle
management (load, unload, list, inference endpoint).
local GPU hardware via nvidia-smi, manages inference harnesses (mistral.rs,
llama.cpp), and exposes an HTTP API for model lifecycle management.
%prep
%autosetup
@@ -60,7 +55,6 @@ cargo build --release -p neuron
install -Dm755 target/release/neuron %{buildroot}%{_bindir}/neuron
install -Dm644 data/neuron.service %{buildroot}%{_unitdir}/neuron.service
install -Dm644 data/neuron-sysusers.conf %{buildroot}%{_sysusersdir}/neuron.conf
install -Dm644 data/neuron-firewalld.xml %{buildroot}%{_prefix}/lib/firewalld/services/helexa-neuron.xml
install -dm755 %{buildroot}%{_sysconfdir}/neuron
install -Dm644 neuron.example.toml %{buildroot}%{_sysconfdir}/neuron/neuron.toml
@@ -82,20 +76,9 @@ install -Dm644 neuron.example.toml %{buildroot}%{_sysconfdir}/neuron/neuron.toml
%{_bindir}/neuron
%{_unitdir}/neuron.service
%{_sysusersdir}/neuron.conf
%{_prefix}/lib/firewalld/services/helexa-neuron.xml
%dir %{_sysconfdir}/neuron
%config(noreplace) %{_sysconfdir}/neuron/neuron.toml
%changelog
* Thu Apr 16 2026 Gitea Actions <actions@git.lair.cafe> - 0.1.16-1
- chore: ignore local deploy script
- chore: move default ports out of common-collision ranges
- ci: drop actions/cache for cargo registry and target
* Thu Apr 16 2026 Gitea Actions <actions@git.lair.cafe> - 0.1.14-1
- ci: publish both packages to a single helexa/helexa COPR project
- fix(rpm): rename neuron package to helexa-neuron
- ci: commit generated %changelog entries back to main
* Wed Apr 15 2026 Rob Thijssen <grenade@rob.tn> - 0.1.0-1
- Initial package

View File

@@ -1,106 +0,0 @@
# Prebuilt-binary spec for cortex.
#
# Unlike cortex.spec (which builds from source via cargo), this spec
# wraps a pre-built `cortex` binary produced by an upstream CI job and
# packages it for rpm.lair.cafe. The %build phase is a no-op.
#
# Required defines at rpmbuild time:
# cortex_version e.g. "0.1.16"
# cortex_prerelease e.g. "0.1.20260518140530.gitabcdef0"
# ^^^^^^^^^^^^^^^^^^ ^^^^^^^^
# commit time (sec) commit sha
# (used as Release; the timestamp prefix
# keeps same-day builds strictly ordered.)
%global _build_id_links none
%global debug_package %{nil}
%global __strip /usr/bin/true
%{!?cortex_version: %global cortex_version 0.0.0}
%if 0%{?cortex_prerelease:1}
%global cortex_release %{cortex_prerelease}
%else
%global cortex_release 1
%endif
Name: cortex
Version: %{cortex_version}
Release: %{cortex_release}%{?dist}
Summary: Inference gateway for multi-node GPU clusters (prebuilt)
License: GPL-3.0-or-later
URL: https://git.lair.cafe/helexa/cortex
Source0: cortex
Source1: cortex.service
Source2: cortex-sysusers.conf
Source3: cortex-firewalld.xml
Source4: cortex.example.toml
Source5: models.example.toml
Source6: LICENSE
ExclusiveArch: x86_64
Requires(pre): shadow-utils
Requires: systemd
Requires: firewalld-filesystem
Provides: user(cortex)
%description
Cortex is a Rust reverse-proxy that sits in front of multiple neuron
inference daemons and presents a unified OpenAI and Anthropic
compatible API surface.
This package wraps a binary built upstream in CI; the source-build
spec (cortex.spec) remains available for stable releases.
%prep
cp %{SOURCE0} ./cortex
cp %{SOURCE1} .
cp %{SOURCE2} .
cp %{SOURCE3} .
cp %{SOURCE4} .
cp %{SOURCE5} .
cp %{SOURCE6} .
%build
# Already built in the upstream CI build job.
%install
install -Dm755 cortex %{buildroot}%{_bindir}/cortex
install -Dm644 cortex.service %{buildroot}%{_unitdir}/cortex.service
install -Dm644 cortex-sysusers.conf %{buildroot}%{_sysusersdir}/cortex.conf
install -Dm644 cortex-firewalld.xml %{buildroot}%{_prefix}/lib/firewalld/services/cortex.xml
install -dm755 %{buildroot}%{_sysconfdir}/cortex
install -Dm644 cortex.example.toml %{buildroot}%{_sysconfdir}/cortex/cortex.toml
install -Dm644 models.example.toml %{buildroot}%{_sysconfdir}/cortex/models.toml
%pre
getent group cortex >/dev/null || groupadd -r cortex
getent passwd cortex >/dev/null || \
useradd -r -g cortex -d /var/lib/cortex -s /sbin/nologin \
-c "Cortex inference gateway" cortex
%post
%systemd_post cortex.service
%preun
%systemd_preun cortex.service
%postun
%systemd_postun_with_restart cortex.service
%files
%license LICENSE
%{_bindir}/cortex
%{_unitdir}/cortex.service
%{_sysusersdir}/cortex.conf
%{_prefix}/lib/firewalld/services/cortex.xml
%dir %{_sysconfdir}/cortex
%config(noreplace) %{_sysconfdir}/cortex/cortex.toml
%config(noreplace) %{_sysconfdir}/cortex/models.toml
%changelog
* Mon May 18 2026 Gitea Actions <actions@git.lair.cafe> - %{cortex_version}-%{cortex_release}
- Prerelease build from upstream CI binary.

View File

@@ -1,126 +0,0 @@
# Prebuilt-binary spec for helexa-neuron flavoured by CUDA compute capability.
#
# Unlike helexa-neuron.spec (which builds from source via cargo), this
# spec wraps a pre-built `neuron-{flavour}` binary produced by an
# upstream CI job and packages it for rpm.lair.cafe. The %build phase
# is a no-op.
#
# Required defines at rpmbuild time:
# neuron_version e.g. "0.1.16"
# neuron_flavour e.g. "ada", "blackwell" — matches the CI build
# matrix's compute_cap label.
# neuron_prerelease e.g. "0.1.20260518140530.gitabcdef0"
# ^^^^^^^^^^^^^^^^^^ ^^^^^^^^
# commit time (sec) commit sha
# (used as Release; the timestamp prefix
# keeps same-day builds strictly ordered.)
#
# One flavour can be installed at a time on a given host; flavour
# packages Conflict with each other.
%global _build_id_links none
%global debug_package %{nil}
%global __strip /usr/bin/true
%{!?neuron_version: %global neuron_version 0.0.0}
%{!?neuron_flavour: %global neuron_flavour blackwell}
%if 0%{?neuron_prerelease:1}
%global neuron_release %{neuron_prerelease}
%else
%global neuron_release 1
%endif
Name: helexa-neuron-%{neuron_flavour}
Version: %{neuron_version}
Release: %{neuron_release}%{?dist}
Summary: Per-node GPU inference daemon (candle, %{neuron_flavour} flavour)
License: GPL-3.0-or-later
URL: https://git.lair.cafe/helexa/cortex
Source0: neuron-%{neuron_flavour}
Source1: neuron.service
Source2: neuron-sysusers.conf
Source3: neuron-firewalld.xml
Source4: neuron.example.toml
Source5: LICENSE
ExclusiveArch: x86_64
# Binary links against the CUDA runtime, cuDNN, NCCL, etc. Suppress
# auto-detected exact soname deps — users may have CUDA from various
# sources (rpmfusion, nvidia-direct) at different compatible versions;
# a runtime dlopen failure surfaces a clearer error than rpm dep
# resolution would.
%global __requires_exclude ^lib(cuda|cudart|cudnn|cublas|cublasLt|curand|nvrtc|nccl)
Requires(pre): shadow-utils
Requires: systemd
Requires: firewalld-filesystem
Provides: helexa-neuron = %{neuron_version}-%{neuron_release}
Provides: user(neuron)
# Mutual exclusion across flavours and the source-build variant.
Conflicts: helexa-neuron
Conflicts: helexa-neuron-ada
Conflicts: helexa-neuron-ampere
Conflicts: helexa-neuron-blackwell
# (The Conflicts: with self is filtered by rpm at install time.)
%description
Neuron is the per-node daemon for cortex inference clusters. It
discovers local GPU hardware via nvidia-smi, runs in-process
inference via huggingface/candle, and exposes an HTTP API for model
lifecycle management (load, unload, list, inference endpoint).
This is the %{neuron_flavour} flavour, built for that CUDA compute
capability. Install the flavour matching the GPUs on this host.
%prep
cp %{SOURCE0} ./neuron
cp %{SOURCE1} .
cp %{SOURCE2} .
cp %{SOURCE3} .
cp %{SOURCE4} .
cp %{SOURCE5} .
%build
# Already built in the upstream CI build job (with --features cuda).
%install
install -Dm755 neuron %{buildroot}%{_bindir}/neuron
install -Dm644 neuron.service %{buildroot}%{_unitdir}/neuron.service
install -Dm644 neuron-sysusers.conf %{buildroot}%{_sysusersdir}/neuron.conf
install -Dm644 neuron-firewalld.xml %{buildroot}%{_prefix}/lib/firewalld/services/helexa-neuron.xml
install -dm755 %{buildroot}%{_sysconfdir}/neuron
install -Dm644 neuron.example.toml %{buildroot}%{_sysconfdir}/neuron/neuron.toml
%pre
getent group neuron >/dev/null || groupadd -r neuron
getent passwd neuron >/dev/null || \
useradd -r -g neuron -d /var/lib/neuron -s /sbin/nologin \
-G video,render \
-c "Neuron GPU node daemon" neuron
%post
%systemd_post neuron.service
%preun
%systemd_preun neuron.service
%postun
%systemd_postun_with_restart neuron.service
%files
%license LICENSE
%{_bindir}/neuron
%{_unitdir}/neuron.service
%{_sysusersdir}/neuron.conf
%{_prefix}/lib/firewalld/services/helexa-neuron.xml
%dir %{_sysconfdir}/neuron
%config(noreplace) %{_sysconfdir}/neuron/neuron.toml
%changelog
* Mon May 18 2026 Gitea Actions <actions@git.lair.cafe> - %{neuron_version}-%{neuron_release}
- Prerelease build from upstream CI binary (%{neuron_flavour} flavour).

View File

@@ -1 +0,0 @@
%_openpgp_sign_id @GPG_NAME@

View File

@@ -1,275 +0,0 @@
#!/bin/env bash
#
# Rolling deploy across the helexa fleet, driven by asset/manifest.yml.
# Installs / upgrades cortex on the gateway host and the appropriate
# helexa-neuron-<flavour> package on each neuron host, then restarts
# their services.
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_DIR="$(cd "${SCRIPT_DIR}/.." && pwd)"
MANIFEST="${REPO_DIR}/asset/manifest.yml"
if [[ ! -f "${MANIFEST}" ]]; then
echo "fatal: manifest not found at ${MANIFEST}" >&2
exit 1
fi
# Parse the manifest with yq. NOTE: this expects the pip-installed yq
# (a jq wrapper using jq syntax) — `pip install yq`. The Fedora rpm
# `yq` is mikefarah/yq and uses different (yaml-native) syntax; if a
# host has that one instead these queries will fail.
cortex_host=$(yq -r '.cortex.host' "${MANIFEST}")
# Emit one TAB-separated 'host\tflavour' line per neuron.
mapfile -t neuron_entries < <(
yq -r '.neurons[] | .host + "\t" + .flavour' "${MANIFEST}"
)
# Return the installed package's "version-release" string, or
# "(not installed)" when rpm reports the package as absent. Capture
# rpm's output into a variable so its "package X is not installed"
# stdout message (rpm writes that to stdout, not stderr, when -q fails)
# doesn't leak into the result.
installed_nvr() {
local host="$1" pkg="$2"
local nvr
if nvr=$(ssh "${host}" "rpm -q --qf '%{version}-%{release}' ${pkg} 2>/dev/null"); then
echo "${nvr}"
else
echo "(not installed)"
fi
}
# Ensure the rpm.lair.cafe unstable repo is configured AND enabled on
# the remote host.
#
# The upstream .repo file at https://rpm.lair.cafe/lair-cafe-unstable.repo
# ships with `enabled=0` so a host that just fetched it won't start
# pulling unstable packages by accident. We have to explicitly flip
# enabled=1 via `dnf config-manager setopt`. Both addrepo and setopt
# are idempotent.
#
# Non-fatal — if either step fails the subsequent `dnf install` will
# surface a clearer diagnostic on its own.
ensure_lair_repo() {
local host="$1"
if ! ssh "${host}" "test -f /etc/yum.repos.d/lair-cafe-unstable.repo" 2>/dev/null; then
echo "[${host}] adding rpm.lair.cafe unstable repo"
if ! ssh "${host}" sudo dnf config-manager addrepo \
--from-repofile=https://rpm.lair.cafe/lair-cafe-unstable.repo \
>/dev/null 2>&1; then
echo "[${host}] WARNING: failed to add lair.cafe repo file (proceeding anyway)"
return 0
fi
fi
# The .repo file ships enabled=0; flip it on. Cheap, idempotent.
if ! ssh "${host}" sudo dnf config-manager setopt \
lair-cafe-unstable.enabled=1 >/dev/null 2>&1; then
echo "[${host}] WARNING: failed to enable lair-cafe-unstable (proceeding anyway)"
fi
}
# Ensure libcudnn.so.9 is resolvable on the remote host so the
# neuron binary (built with --features cudnn) doesn't fail at startup
# with "cannot open shared object file: No such file or directory".
#
# Probes ldconfig first — if cuDNN was installed manually (.tar/.run
# install), it'll be cached by ldconfig and we don't touch it.
# Otherwise adds NVIDIA's RHEL9 CUDA repo (the Fedora 43 CUDA repo
# doesn't ship cuDNN packages — only the RHEL9 one does) and installs
# libcudnn9-cuda-13.
ensure_cudnn_runtime() {
local host="$1"
if ssh "${host}" "ldconfig -p | grep -q libcudnn.so.9" 2>/dev/null; then
return 0
fi
echo "[${host}] installing cuDNN runtime"
if ! ssh "${host}" "test -f /etc/yum.repos.d/cuda-rhel9-x86_64.repo" 2>/dev/null; then
if ! ssh "${host}" sudo dnf config-manager addrepo \
--from-repofile=https://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo \
>/dev/null 2>&1; then
echo "[${host}] WARNING: failed to add rhel9 CUDA repo (proceeding anyway)"
fi
fi
if ! ssh "${host}" sudo dnf install -y libcudnn9-cuda-13 >/dev/null 2>&1; then
echo "[${host}] WARNING: failed to install libcudnn9-cuda-13"
echo "[${host}] neuron may fail to start; install cuDNN manually if so"
fi
}
# True when the named package needs to be installed or upgraded on the
# remote host — either it's not present, or a newer version exists in
# the repo. False only when the installed version is current.
#
# `dnf check-update <pkg>` returns 0 when the package isn't installed
# at all (there's nothing to update), so we have to probe with rpm -q
# first to distinguish "absent" from "current". Other dnf failures
# collapse into "needs update" so the subsequent install step surfaces
# the real diagnostic rather than this check swallowing it.
needs_update() {
local host="$1" pkg="$2"
# Not installed → needs work.
if ! ssh "${host}" "rpm -q ${pkg}" >/dev/null 2>&1; then
return 0
fi
# Installed; ask dnf whether the repo has something newer.
if ssh "${host}" sudo dnf check-update --refresh -q "${pkg}" >/dev/null 2>&1; then
return 1
else
return 0
fi
}
# True if the named package is currently installed on the remote host.
# Used to decide between `dnf install` (fresh) and `dnf upgrade` (stale):
# dnf5's `install` is a no-op when the package is already present at
# any version — it does NOT auto-upgrade to the latest available — so
# the wrong command silently leaves the host on an old build.
is_installed() {
local host="$1" pkg="$2"
ssh "${host}" "rpm -q ${pkg}" >/dev/null 2>&1
}
# Install or upgrade the named package on the remote, picking the
# right dnf verb based on the installed-or-not state. Returns 0 with
# dnf's combined stdout/stderr captured in __DNF_OUTPUT__ on success,
# and 1 with the same captured output on failure.
__DNF_OUTPUT__=""
install_or_upgrade() {
local host="$1" pkg="$2"
local cmd
if is_installed "${host}" "${pkg}"; then
cmd="upgrade"
else
cmd="install"
fi
if __DNF_OUTPUT__=$(
ssh "${host}" sudo dnf "${cmd}" --refresh --allowerasing -y "${pkg}" 2>&1
); then
return 0
else
return 1
fi
}
# ---------------------------------------------------------------------------
# cortex (gateway)
# ---------------------------------------------------------------------------
ensure_lair_repo "${cortex_host}"
cortex_nvr=$(installed_nvr "${cortex_host}" cortex)
if needs_update "${cortex_host}" cortex; then
echo "[${cortex_host}] cortex update available (current: ${cortex_nvr})"
# Stop the service only if the unit file exists — fresh installs
# don't have it, and `systemctl stop` on a missing unit returns
# non-zero, which would otherwise short-circuit the install branch
# under set -e.
if ssh "${cortex_host}" "[ ! -f /usr/lib/systemd/system/cortex.service ] || sudo systemctl stop cortex.service"; then
echo "[${cortex_host}] stopped cortex service"
if install_or_upgrade "${cortex_host}" cortex; then
cortex_nvr=$(installed_nvr "${cortex_host}" cortex)
echo "[${cortex_host}] installed/upgraded cortex to ${cortex_nvr}"
else
echo "[${cortex_host}] failed to install/upgrade cortex:"
echo "${__DNF_OUTPUT__}" | sed "s/^/[${cortex_host}] /"
fi
else
echo "[${cortex_host}] failed to stop cortex service"
fi
else
echo "[${cortex_host}] cortex is up to date (${cortex_nvr})"
ssh "${cortex_host}" sudo systemctl stop cortex.service || true
fi
# Sync cortex.toml whether the package was upgraded or not — the config
# can change without a package bump.
if rsync \
--archive \
--compress \
--rsync-path 'sudo rsync' \
--chown root:root \
--chmod 644 \
"${REPO_DIR}/cortex.toml" \
"${cortex_host}:/etc/cortex/cortex.toml"; then
echo "[${cortex_host}] sync'd cortex.toml"
else
echo "[${cortex_host}] failed to sync cortex.toml"
fi
# Sync models.toml on the same lifecycle as cortex.toml — operator-owned,
# gitignored, drives /v1/models catalogue × topology resolution.
if [[ -f "${REPO_DIR}/models.toml" ]]; then
if rsync \
--archive \
--compress \
--rsync-path 'sudo rsync' \
--chown root:root \
--chmod 644 \
"${REPO_DIR}/models.toml" \
"${cortex_host}:/etc/cortex/models.toml"; then
echo "[${cortex_host}] sync'd models.toml"
else
echo "[${cortex_host}] failed to sync models.toml"
fi
else
echo "[${cortex_host}] no local models.toml — leaving /etc/cortex/models.toml untouched"
fi
ssh "${cortex_host}" sudo systemctl daemon-reload
if ssh "${cortex_host}" systemctl is-active --quiet cortex.service; then
echo "[${cortex_host}] cortex service is active"
elif ssh "${cortex_host}" sudo systemctl start cortex.service; then
echo "[${cortex_host}] started cortex service"
else
echo "[${cortex_host}] failed to start cortex service"
fi
# ---------------------------------------------------------------------------
# neuron (per-host, flavour from manifest)
# ---------------------------------------------------------------------------
for entry in "${neuron_entries[@]}"; do
IFS=$'\t' read -r neuron_host neuron_flavour <<< "${entry}"
package="helexa-neuron-${neuron_flavour}"
ensure_lair_repo "${neuron_host}"
ensure_cudnn_runtime "${neuron_host}"
neuron_nvr=$(installed_nvr "${neuron_host}" "${package}")
if needs_update "${neuron_host}" "${package}"; then
echo "[${neuron_host}] ${package} update available (current: ${neuron_nvr})"
if ssh "${neuron_host}" "[ ! -f /usr/lib/systemd/system/neuron.service ] || sudo systemctl stop neuron.service"; then
echo "[${neuron_host}] stopped neuron service"
# --allowerasing lets dnf swap out a previously-installed
# bare helexa-neuron or a different flavour without manual
# intervention. The Conflicts: clauses in the spec ensure
# only one flavour is ever resident.
if install_or_upgrade "${neuron_host}" "${package}"; then
neuron_nvr=$(installed_nvr "${neuron_host}" "${package}")
echo "[${neuron_host}] installed/upgraded ${package} to ${neuron_nvr}"
# Ensure firewalld allows neuron port
ssh "${neuron_host}" "sudo firewall-cmd --query-service=helexa-neuron --quiet 2>/dev/null || sudo firewall-cmd --add-service=helexa-neuron --permanent && sudo firewall-cmd --reload" 2>/dev/null || true
if ssh "${neuron_host}" "sudo systemctl daemon-reload && sudo systemctl start neuron.service"; then
echo "[${neuron_host}] started neuron service"
else
echo "[${neuron_host}] failed to start neuron service"
fi
else
echo "[${neuron_host}] failed to install ${package}:"
echo "${__DNF_OUTPUT__}" | sed "s/^/[${neuron_host}] /"
fi
else
echo "[${neuron_host}] failed to stop neuron service"
fi
else
echo "[${neuron_host}] ${package} is up to date (${neuron_nvr})"
if ssh "${neuron_host}" systemctl is-active --quiet neuron.service; then
echo "[${neuron_host}] neuron service is active"
elif ssh "${neuron_host}" sudo systemctl start neuron.service; then
echo "[${neuron_host}] started neuron service"
else
echo "[${neuron_host}] failed to start neuron service"
fi
fi
done

View File

@@ -1,154 +0,0 @@
#!/usr/bin/env python3
"""Parse RPM repodata and emit a packages.json manifest for the UI."""
import argparse
import gzip
import json
import os
import subprocess
import sys
import xml.etree.ElementTree as ET
from datetime import datetime, timezone
RPM_NS = "http://linux.duke.edu/metadata/common"
OTHER_NS = "http://linux.duke.edu/metadata/other"
REPO_NS = "http://linux.duke.edu/metadata/repo"
def find_repodata_file(repodata_dir, data_type):
"""Read repomd.xml and return the path to a specific data type's file."""
repomd_path = os.path.join(repodata_dir, "repomd.xml")
tree = ET.parse(repomd_path)
root = tree.getroot()
for data in root.findall(f"{{{REPO_NS}}}data"):
if data.get("type") == data_type:
location = data.find(f"{{{REPO_NS}}}location")
if location is not None:
href = location.get("href", "")
return os.path.join(os.path.dirname(repodata_dir), href)
return None
def open_compressed(path):
"""Open a gzip or zstd compressed file for reading."""
if path.endswith(".zst"):
result = subprocess.run(
["zstdcat", path], capture_output=True, check=True
)
import io
return io.BytesIO(result.stdout)
else:
return gzip.open(path, "rb")
def parse_primary(repodata_dir):
"""Parse primary.xml.{gz,zst} and return package metadata."""
path = find_repodata_file(repodata_dir, "primary")
if not path:
print("error: primary metadata not found in repomd.xml", file=sys.stderr)
sys.exit(1)
packages = {}
with open_compressed(path) as f:
tree = ET.parse(f)
for pkg in tree.getroot().findall(f"{{{RPM_NS}}}package"):
if pkg.get("type") != "rpm":
continue
name = pkg.findtext(f"{{{RPM_NS}}}name", "")
version_el = pkg.find(f"{{{RPM_NS}}}version")
ver = version_el.get("ver", "") if version_el is not None else ""
rel = version_el.get("rel", "") if version_el is not None else ""
arch = pkg.findtext(f"{{{RPM_NS}}}arch", "")
size_el = pkg.find(f"{{{RPM_NS}}}size")
size = int(size_el.get("package", "0")) if size_el is not None else 0
time_el = pkg.find(f"{{{RPM_NS}}}time")
build_time = int(time_el.get("build", "0")) if time_el is not None else 0
location_el = pkg.find(f"{{{RPM_NS}}}location")
filename = os.path.basename(location_el.get("href", "")) if location_el is not None else ""
key = f"{name}-{ver}-{rel}"
packages[key] = {
"name": name,
"version": ver,
"release": rel,
"arch": arch,
"summary": pkg.findtext(f"{{{RPM_NS}}}summary", ""),
"size": size,
"buildTime": build_time,
"rpmFilename": filename,
"changelog": [],
}
return packages
def parse_other(repodata_dir, packages):
"""Parse other.xml.gz and attach changelog entries to packages."""
path = find_repodata_file(repodata_dir, "other")
if not path:
return
with open_compressed(path) as f:
tree = ET.parse(f)
for pkg in tree.getroot().findall(f"{{{OTHER_NS}}}package"):
name = pkg.get("name", "")
version_el = pkg.find(f"{{{OTHER_NS}}}version")
ver = version_el.get("ver", "") if version_el is not None else ""
rel = version_el.get("rel", "") if version_el is not None else ""
key = f"{name}-{ver}-{rel}"
if key not in packages:
continue
for entry in pkg.findall(f"{{{OTHER_NS}}}changelog"):
packages[key]["changelog"].append({
"author": entry.get("author", ""),
"date": int(entry.get("date", "0")),
"text": (entry.text or "").strip(),
})
def main():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--repodata-dir",
required=True,
help="path to the repodata/ directory",
)
parser.add_argument(
"--output",
required=True,
help="path to write packages.json",
)
parser.add_argument(
"--base-url",
required=True,
help="public base URL for the repo (e.g. https://rpm.lair.cafe/fedora/43/x86_64)",
)
args = parser.parse_args()
packages = parse_primary(args.repodata_dir)
parse_other(args.repodata_dir, packages)
manifest = {
"generated": datetime.now(timezone.utc).isoformat(),
"baseUrl": args.base_url,
"packages": list(packages.values()),
}
with open(args.output, "w") as f:
json.dump(manifest, f, indent=2)
print(f"wrote {len(packages)} packages to {args.output}")
if __name__ == "__main__":
main()

View File

@@ -1,60 +0,0 @@
#!/bin/env bash
#
# TP smoke test against a deployed neuron host.
#
# SSHes into the target host and runs `neuron --tp-smoke --tp-size N
# --cuda-devices ...` directly — no HTTP API involved. The smoke
# subcommand spawns N-1 worker subprocesses, joins them in an NCCL
# communicator, runs one AllReduce(Sum) of `1u32` across every rank, and
# verifies the observed sum equals world_size on every rank.
#
# This validates the lower-half of the TP stack (NCCL + IPC topology +
# subprocess lifecycle) without touching model load, inference, or HTTP.
# A failure here means the host cannot run any TP model and there is no
# point debugging the higher layers.
#
# Usage:
# script/tp-smoke.sh [host] [tp_size] [cuda_devices]
#
# Defaults:
# host = beast.hanzalova.internal (only fleet host with 2 GPUs)
# tp_size = 2
# cuda_devices = 0,1
set -euo pipefail
HOST="${1:-beast.hanzalova.internal}"
TP_SIZE="${2:-2}"
CUDA_DEVICES="${3:-0,1}"
say() { printf '[%s] %s\n' "${HOST}" "$*" >&2; }
die() { say "FAIL: $*"; exit 1; }
say "running neuron --tp-smoke --tp-size ${TP_SIZE} --cuda-devices ${CUDA_DEVICES}"
# Run as root via sudo because:
# - cuda contexts under a user account require either the nvidia
# uvm/peer devices to be world-readable or the user to be in a
# priviliged group (neither is true on stock fc43);
# - the installed binary lives at /usr/bin/neuron with no setuid;
# Running through root is the simplest path that matches how
# systemd-managed neuron sees the GPUs in production.
#
# The smoke command is read-only — it allocates a transient NCCL comm
# and a 1u32 buffer per rank, then tears it all down.
if ! ssh -o BatchMode=yes "${HOST}" \
sudo /usr/bin/neuron \
--tp-smoke \
--tp-size "${TP_SIZE}" \
--cuda-devices "${CUDA_DEVICES}" 2>&1 | tee /tmp/tp-smoke-"${HOST}".log
then
die "tp-smoke exited non-zero (see /tmp/tp-smoke-${HOST}.log)"
fi
# Final stdout line is `status=ok` on success.
if grep -q '^status=ok$' /tmp/tp-smoke-"${HOST}".log; then
say "PASS — NCCL handshake + AllReduce sanity check OK across ${TP_SIZE} ranks"
exit 0
else
die "no status=ok line in output"
fi

View File

@@ -1,188 +0,0 @@
#!/bin/env bash
#
# End-to-end smoke test for a deployed neuron.
#
# Confirms the daemon is reachable, loads a small public Qwen3 GGUF,
# fires a reasoning probe at /v1/chat/completions, and prints the
# answer. Used to validate the candle harness on a real GPU host
# before trusting it for production traffic, and as a regression test
# after pushing new neuron builds.
#
# Usage:
# script/validate-neuron.sh [host] [model_id] [quant] [tp_size]
#
# Defaults:
# host = beast.hanzalova.internal
# model_id = unsloth/Qwen3-0.6B-GGUF (official Qwen3-*-GGUF repos
# ship Q8_0 only; unsloth's mirror ships the full Q-spectrum
# including Q4_K_M)
# quant = Q4_K_M (empty = dense safetensors path)
# tp_size = unset (= 1 = single-GPU; pass 2 to drive the TP path)
set -euo pipefail
HOST="${1:-beast.hanzalova.internal}"
MODEL_ID="${2:-unsloth/Qwen3-0.6B-GGUF}"
# `${3-Q4_K_M}` (no colon) only uses the default when the arg is
# UNSET — passing an explicit empty string drives the dense path.
QUANT="${3-Q4_K_M}"
# tp_size > 1 forces the dense path (TP requires safetensors) and adds
# `tensor_parallel: N` to the load payload. The harness picks device
# indices 0..N-1 by default; override by passing NEURON_DEVICES="0,1,..."
# in the environment.
TP_SIZE="${4-1}"
PORT="${NEURON_PORT:-13131}"
BASE="http://${HOST}:${PORT}"
# Reasoning probe — concrete, low-temperature answer that small models
# can still get right. "Paris" is a strong signal of basic competence
# beyond gibberish.
PROBE_PROMPT='What is the capital of Georgia (Caucasus)? Respond with the city name only, no punctuation.'
EXPECT_SUBSTR='Tbilisi'
# Qwen3 prepends <think>...</think> reasoning before the answer when the
# chat template enables thinking mode, which eats most of a small token
# budget. 256 leaves enough room for thinking + final answer.
MAX_TOKENS=256
# /models/load is synchronous — neuron blocks the response until the
# hf-hub download + (GGUF parse or safetensors mmap) + tensor
# materialisation is done. Small GGUF (0.6B-Q4_K_M, ~400 MB) is
# typically a minute on a warm cache, several on a cold one. A
# Qwen3.6-class dense model is ~54 GB and can easily take an hour to
# download cold over a residential link, so default high. Override
# with NEURON_LOAD_TIMEOUT=N (seconds) for smaller targets if you'd
# rather fail fast.
LOAD_TIMEOUT="${NEURON_LOAD_TIMEOUT:-3600}"
INFER_TIMEOUT="${NEURON_INFER_TIMEOUT:-120}"
# Status messages go to stderr so command substitutions like
# `raw=$(run_probe)` capture only the function's intended return value
# (an HTTP body), not the progress chatter.
say() { printf '[%s] %s\n' "${HOST}" "$*" >&2; }
die() { say "FAIL: $*"; exit 1; }
probe_health() {
curl --silent --fail --max-time 5 "${BASE}/health" >/dev/null \
|| die "neuron not reachable at ${BASE}/health"
}
list_loaded_ids() {
# The manifest is YAML and uses yq; HTTP responses are JSON and use
# jq directly. pip-yq parses input as YAML by default, which trips
# on JSON content that happens to look like YAML aliases (chatcmpl
# ids, escaped quotes inside `<think>...</think>` blocks, etc.).
curl --silent --fail "${BASE}/models" | jq -r '.[].id'
}
is_loaded() {
list_loaded_ids 2>/dev/null | grep -Fxq "${MODEL_ID}"
}
trigger_load() {
# Build the per-rank CUDA device list as a JSON array. Either
# honour NEURON_DEVICES (`0,1,2`) verbatim or default to
# `[0, 1, ..., tp_size - 1]`.
local devices_json
if [[ -n "${NEURON_DEVICES:-}" ]]; then
devices_json=$(jq -n -c --arg s "${NEURON_DEVICES}" \
'$s | split(",") | map(tonumber)')
else
devices_json=$(jq -n -c --argjson n "${TP_SIZE}" '[range(0; $n)]')
fi
say "POST /models/load ${MODEL_ID} (quant=${QUANT:-<dense>}, tp=${TP_SIZE}, devices=${devices_json})"
say " (synchronous; may take a minute on first run while HF downloads)"
# Build the payload via jq so optional fields are omitted entirely
# when not in use. `tensor_parallel` is dropped when tp_size == 1;
# `quant` is dropped when empty. Both can coexist: tp_size > 1 +
# ISQ quant (q5k/q8_0/etc.) loads safetensors and quantizes the
# per-rank shard at load time. GGUF quants (Q4_K_M) are incompatible
# with TP — but the harness rejects that combination at load time
# rather than here.
local payload
local base
base=$(jq -n -c \
--arg id "${MODEL_ID}" \
--argjson devices "${devices_json}" \
'{model_id: $id, harness: "candle", devices: $devices}')
if [[ -n "${QUANT}" ]]; then
base=$(echo "${base}" | jq -c --arg q "${QUANT}" '. + {quant: $q}')
fi
if (( TP_SIZE > 1 )); then
base=$(echo "${base}" | jq -c --argjson tp "${TP_SIZE}" '. + {tensor_parallel: $tp}')
fi
payload="${base}"
# --write-out captures the response code on a separate line so we
# can surface a real diagnostic instead of relying on --fail.
local resp http_code body
resp=$(curl --silent --show-error --max-time "${LOAD_TIMEOUT}" \
--write-out '\n__HTTP__%{http_code}' \
-X POST "${BASE}/models/load" \
-H 'content-type: application/json' \
--data "${payload}") || die "curl /models/load failed: $?"
http_code=$(echo "${resp}" | grep -oP '(?<=__HTTP__)\d+$' | tail -1)
body=$(echo "${resp}" | sed '$ s/__HTTP__.*$//')
if [[ "${http_code}" != "200" ]]; then
die "load returned HTTP ${http_code}: ${body}"
fi
say "load returned ${http_code}: ${body}"
}
run_probe() {
say "POST /v1/chat/completions (probe: ${PROBE_PROMPT})"
local payload
payload=$(jq -n -c \
--arg model "${MODEL_ID}" \
--arg content "${PROBE_PROMPT}" \
--argjson tokens "${MAX_TOKENS}" \
'{
model: $model,
messages: [{role: "user", content: $content}],
temperature: 0.1,
max_tokens: $tokens
}')
local resp http_code body
resp=$(curl --silent --show-error --max-time "${INFER_TIMEOUT}" \
--write-out '\n__HTTP__%{http_code}' \
-X POST "${BASE}/v1/chat/completions" \
-H 'content-type: application/json' \
--data "${payload}") || die "curl /v1/chat/completions failed: $?"
http_code=$(echo "${resp}" | grep -oP '(?<=__HTTP__)\d+$' | tail -1)
body=$(echo "${resp}" | sed '$ s/__HTTP__.*$//')
if [[ "${http_code}" != "200" ]]; then
die "inference returned HTTP ${http_code}: ${body}"
fi
echo "${body}"
}
say "validating neuron at ${BASE}"
probe_health
say "/health OK"
if is_loaded; then
say "${MODEL_ID} already loaded"
else
trigger_load
fi
raw=$(run_probe)
echo "---"
# Dump the raw JSON. Don't pipe through `yq -r '.'` — yq's default
# YAML output mode chokes on JSON strings that contain `<` (and the
# `<think>` markers Qwen3 emits during reasoning are a perfect
# example). The targeted `yq -r '.path'` calls below work fine
# because jq's path filter mode bypasses the YAML re-emit.
echo "${raw}"
echo "---"
content=$(echo "${raw}" | jq -r '.choices[0].message.content // empty')
if [[ -z "${content}" ]]; then
die "no content in chat completion response"
fi
say "assistant said: ${content}"
if echo "${content}" | grep -qiF "${EXPECT_SUBSTR}"; then
say "PASS — response contains expected substring '${EXPECT_SUBSTR}'"
exit 0
else
die "response did not contain '${EXPECT_SUBSTR}'"
fi