Compare commits
100 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
aa88d37509
|
|||
|
0f00f72b47
|
|||
|
9b0ed0b57f
|
|||
|
dc2a803266
|
|||
|
e71181499e
|
|||
|
ee663e5e99
|
|||
|
34f9b77d9d
|
|||
|
f084aaab8e
|
|||
|
68a606a79c
|
|||
|
4aa71902d0
|
|||
|
bef159b21c
|
|||
|
8d7b099b36
|
|||
|
89d98d1fb2
|
|||
|
cc95fe28d9
|
|||
|
09c945f81e
|
|||
|
05dc0bad18
|
|||
|
10c151efa5
|
|||
|
44ae927e38
|
|||
|
1ebbe87651
|
|||
|
70eb6af42b
|
|||
|
d1a4aad91d
|
|||
|
95dc8745eb
|
|||
|
495d3f7c05
|
|||
|
5c4c8e0eba
|
|||
|
07c44d5db1
|
|||
|
e7eb3dab6a
|
|||
|
180274548d
|
|||
|
a70f317729
|
|||
|
c6022aa6b9
|
|||
|
9e31d8deca
|
|||
|
b400e8b704
|
|||
|
62ca125a68
|
|||
|
735945ee81
|
|||
|
f72dee094f
|
|||
|
d46d8d4f6c
|
|||
|
9b8bd146f6
|
|||
|
96d8755245
|
|||
|
12549c9aed
|
|||
|
46527d7804
|
|||
|
8d3194f992
|
|||
|
5436af9c73
|
|||
|
8e882c0757
|
|||
|
93421f48e2
|
|||
|
05e15f3597
|
|||
|
da068ded6d
|
|||
|
2a7ede0232
|
|||
|
18ae3c30ee
|
|||
|
1a0400131e
|
|||
|
1866b99a89
|
|||
|
60176e7c2e
|
|||
|
602e8e1471
|
|||
|
e9d0a75dd5
|
|||
|
6cf87e328f
|
|||
|
f9f5fa41b6
|
|||
|
ed4d71db09
|
|||
|
39010c779f
|
|||
|
57d7ef8d3c
|
|||
|
0e9671dd7d
|
|||
|
e29c9e35f0
|
|||
|
8a2334eacb
|
|||
|
aad314cdfa
|
|||
|
6779b7526a
|
|||
|
84f5662df1
|
|||
|
249c9442e8
|
|||
|
5e17081fb4
|
|||
|
03bed93fee
|
|||
|
4a5211d830
|
|||
|
6d2dc5ff1a
|
|||
|
b713dbe669
|
|||
|
5c957d08ec
|
|||
|
729317d1ef
|
|||
|
5c2bd1a1da
|
|||
|
3cccc2c56b
|
|||
|
7f797b0265
|
|||
|
5a0360c1d5
|
|||
|
472c0e8737
|
|||
|
|
b9d8e30058 | ||
|
25f75fe552
|
|||
|
3f94c50817
|
|||
|
3e1fb60076
|
|||
|
|
9bf987888c | ||
|
abe4ff7ccc
|
|||
|
7c3390a4e1
|
|||
|
2ff062da0e
|
|||
|
|
357f858a29 | ||
|
556e5293dc
|
|||
|
1d90238b01
|
|||
|
d99b25fb8a
|
|||
|
034da319f1
|
|||
|
|
7ece281617 | ||
|
3bb5b3c425
|
|||
|
|
9fa51ad874 | ||
|
9697fbae73
|
|||
|
|
2ce1060cb8 | ||
|
142e91c3f7
|
|||
|
|
52c8b4c983 | ||
|
4a9a4fc775
|
|||
|
53a3c1e157
|
|||
|
5c7d63c658
|
|||
|
|
f161412f91 |
@@ -1,61 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
# Submit an SRPM to COPR, watch the build, and dump per-chroot build logs
|
|
||||||
# to stdout so they are captured in CI output.
|
|
||||||
#
|
|
||||||
# Usage: copr-build.sh <project> <srpm> [srpm...]
|
|
||||||
# Example: copr-build.sh helexa/cortex ./cortex-0.1.2-1.fc43.src.rpm
|
|
||||||
|
|
||||||
set -o pipefail
|
|
||||||
|
|
||||||
PROJECT="$1"
|
|
||||||
shift
|
|
||||||
|
|
||||||
if [ -z "$PROJECT" ] || [ "$#" -eq 0 ]; then
|
|
||||||
echo "usage: $0 <project> <srpm> [srpm...]" >&2
|
|
||||||
exit 2
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Submit without waiting; capture the build ID from stdout.
|
|
||||||
SUBMIT_OUT=$(copr-cli build --nowait "$PROJECT" "$@")
|
|
||||||
echo "$SUBMIT_OUT"
|
|
||||||
BUILD_ID=$(echo "$SUBMIT_OUT" | grep -oP 'Created builds: \K[0-9]+' | head -n1)
|
|
||||||
|
|
||||||
if [ -z "$BUILD_ID" ]; then
|
|
||||||
echo "error: could not parse build ID from copr-cli output" >&2
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo
|
|
||||||
echo "Build $BUILD_ID submitted to $PROJECT"
|
|
||||||
echo "Follow live: https://copr.fedorainfracloud.org/coprs/build/$BUILD_ID"
|
|
||||||
echo
|
|
||||||
|
|
||||||
# Watch the build; captures status transitions to stdout. Exit non-zero
|
|
||||||
# on build failure, but defer propagating that until after we've fetched
|
|
||||||
# logs so the CI output contains diagnostics either way.
|
|
||||||
if copr-cli watch-build "$BUILD_ID"; then
|
|
||||||
STATUS=0
|
|
||||||
else
|
|
||||||
STATUS=$?
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Fetch per-chroot results (logs + rpms). Anonymous download — no auth needed.
|
|
||||||
mkdir -p copr-logs
|
|
||||||
copr-cli download-build --dest copr-logs "$BUILD_ID" || {
|
|
||||||
echo "warning: failed to download build artifacts" >&2
|
|
||||||
}
|
|
||||||
|
|
||||||
# Dump each chroot's builder-live.log as a collapsible group.
|
|
||||||
for chroot_dir in copr-logs/*/; do
|
|
||||||
[ -d "$chroot_dir" ] || continue
|
|
||||||
chroot=$(basename "$chroot_dir")
|
|
||||||
log="${chroot_dir}builder-live.log"
|
|
||||||
if [ -f "$log" ]; then
|
|
||||||
echo
|
|
||||||
echo "::group::${chroot} builder-live.log"
|
|
||||||
cat "$log"
|
|
||||||
echo "::endgroup::"
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
|
|
||||||
exit "$STATUS"
|
|
||||||
342
.gitea/workflows/build-prerelease.yml
Normal file
342
.gitea/workflows/build-prerelease.yml
Normal file
@@ -0,0 +1,342 @@
|
|||||||
|
name: build-prerelease
|
||||||
|
|
||||||
|
# Manually-dispatched workflow that builds CUDA-flavoured neuron binaries
|
||||||
|
# (and a single cortex binary), packages each as a Fedora RPM, signs
|
||||||
|
# them, and publishes to the `unstable` channel at rpm.lair.cafe.
|
||||||
|
#
|
||||||
|
# Trigger from the Gitea UI: Actions → build-prerelease → Run workflow.
|
||||||
|
# Optionally provide a `ref` to build from a non-default branch.
|
||||||
|
#
|
||||||
|
# The published packages are versioned as e.g.
|
||||||
|
# helexa-neuron-blackwell-0.1.16-0.1.20260518T140530.gitabcdef0.fc43.x86_64
|
||||||
|
# ^^^^^^^^^^^^^^^^^^ ^^^^^^^^
|
||||||
|
# commit time (s) commit sha
|
||||||
|
# so they sort BELOW the eventual 0.1.16-1 stable release, and so two
|
||||||
|
# commits on the same day are still strictly ordered by their commit
|
||||||
|
# timestamps (rather than by RPM-vercmp's alpha-vs-digit precedence
|
||||||
|
# on the SHA fragment).
|
||||||
|
|
||||||
|
on:
|
||||||
|
# Auto-build on every push to main so the unstable channel tracks
|
||||||
|
# head without a manual dispatch step.
|
||||||
|
push:
|
||||||
|
branches: [main]
|
||||||
|
# Manual dispatch still available to build from a non-main ref.
|
||||||
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
ref:
|
||||||
|
description: "Git ref to build (branch / tag / commit). Defaults to the workflow's branch."
|
||||||
|
required: false
|
||||||
|
default: ""
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
# Share the group with ci.yml so the two workflows can't run
|
||||||
|
# concurrently on the same `rust` runner (act reuses the workspace
|
||||||
|
# cache and races destroy each other's build files mid-compile).
|
||||||
|
# cancel-in-progress=false → workflows queue; if a newer push lands,
|
||||||
|
# the older run is still picked up by ci.yml's own ref-keyed
|
||||||
|
# concurrency (same group, queued).
|
||||||
|
group: cortex-runner-pool-${{ github.ref }}
|
||||||
|
cancel-in-progress: false
|
||||||
|
|
||||||
|
env:
|
||||||
|
CARGO_INCREMENTAL: "0"
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
prepare:
|
||||||
|
name: Resolve version stamps
|
||||||
|
runs-on: rust
|
||||||
|
outputs:
|
||||||
|
version: ${{ steps.info.outputs.version }}
|
||||||
|
release: ${{ steps.info.outputs.release }}
|
||||||
|
short_sha: ${{ steps.info.outputs.short_sha }}
|
||||||
|
commit_timestamp: ${{ steps.info.outputs.commit_timestamp }}
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ inputs.ref }}
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- id: info
|
||||||
|
run: |
|
||||||
|
set -eux
|
||||||
|
VERSION=$(awk -F\" '/^version[[:space:]]*=/ { print $2; exit }' Cargo.toml)
|
||||||
|
SHORT_SHA=$(git rev-parse --short=7 HEAD)
|
||||||
|
# Second-precise commit timestamp gives the release stamp a
|
||||||
|
# strictly monotonic numeric prefix. The earlier %Y%m%d-only
|
||||||
|
# form let same-day builds be ordered by RPM's rpmvercmp
|
||||||
|
# rules over the SHA, which is non-chronological — e.g.
|
||||||
|
# "git602e8e1" sorts newer than "gitf9f5fa4" purely because
|
||||||
|
# rpmvercmp ranks digit-prefixed segments above alpha ones.
|
||||||
|
# The SHA stays only as a debug identifier; sort order is
|
||||||
|
# decided entirely by the timestamp.
|
||||||
|
COMMIT_TIMESTAMP=$(git log -1 --format=%cd --date=format:%Y%m%d%H%M%S HEAD)
|
||||||
|
RELEASE="0.1.${COMMIT_TIMESTAMP}.git${SHORT_SHA}"
|
||||||
|
echo "version=${VERSION}" >> "$GITHUB_OUTPUT"
|
||||||
|
echo "release=${RELEASE}" >> "$GITHUB_OUTPUT"
|
||||||
|
echo "short_sha=${SHORT_SHA}" >> "$GITHUB_OUTPUT"
|
||||||
|
echo "commit_timestamp=${COMMIT_TIMESTAMP}" >> "$GITHUB_OUTPUT"
|
||||||
|
|
||||||
|
build-cortex:
|
||||||
|
name: Build cortex binary
|
||||||
|
needs: prepare
|
||||||
|
# runner-rust image already provides rust/cargo/clippy/rustfmt via
|
||||||
|
# dnf — no rustup install step needed.
|
||||||
|
runs-on: rust
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ inputs.ref }}
|
||||||
|
|
||||||
|
- name: Build cortex (release)
|
||||||
|
run: cargo build --release -p cortex-cli
|
||||||
|
|
||||||
|
- name: Stage binary
|
||||||
|
run: |
|
||||||
|
mkdir --parents artifacts
|
||||||
|
cp target/release/cortex artifacts/cortex
|
||||||
|
./artifacts/cortex --version || true
|
||||||
|
|
||||||
|
- uses: actions/upload-artifact@v3
|
||||||
|
with:
|
||||||
|
name: cortex-fc43
|
||||||
|
path: artifacts/cortex
|
||||||
|
retention-days: 1
|
||||||
|
|
||||||
|
build-neuron:
|
||||||
|
name: Build neuron-${{ matrix.flavour }}
|
||||||
|
needs: prepare
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- flavour: ampere
|
||||||
|
compute_cap: "86"
|
||||||
|
runner: cuda-13.0
|
||||||
|
cuda_home: /usr/local/cuda-13.0
|
||||||
|
build_jobs: 8
|
||||||
|
nvcc_threads: 4
|
||||||
|
cargo_features: "cuda cudnn flash-attn"
|
||||||
|
- flavour: ada
|
||||||
|
compute_cap: "89"
|
||||||
|
runner: cuda-13.0
|
||||||
|
cuda_home: /usr/local/cuda-13.0
|
||||||
|
build_jobs: 8
|
||||||
|
nvcc_threads: 4
|
||||||
|
cargo_features: "cuda cudnn flash-attn"
|
||||||
|
- flavour: blackwell
|
||||||
|
compute_cap: "120"
|
||||||
|
runner: cuda-13.0
|
||||||
|
cuda_home: /usr/local/cuda-13.0
|
||||||
|
build_jobs: 8
|
||||||
|
nvcc_threads: 4
|
||||||
|
cargo_features: "cuda cudnn flash-attn"
|
||||||
|
runs-on: ${{ matrix.runner }}
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ inputs.ref }}
|
||||||
|
|
||||||
|
- name: Build neuron with CUDA (${{ matrix.flavour }})
|
||||||
|
run: |
|
||||||
|
set -eux
|
||||||
|
export PATH="${{ matrix.cuda_home }}/bin:${PATH}"
|
||||||
|
export LD_LIBRARY_PATH="${{ matrix.cuda_home }}/targets/x86_64-linux/lib:${{ matrix.cuda_home }}/lib64:${LD_LIBRARY_PATH:-}"
|
||||||
|
export LIBRARY_PATH="${{ matrix.cuda_home }}/targets/x86_64-linux/lib:${{ matrix.cuda_home }}/lib64:${LIBRARY_PATH:-}"
|
||||||
|
cargo build --release -p neuron --features "${{ matrix.cargo_features }}"
|
||||||
|
env:
|
||||||
|
CUDA_COMPUTE_CAP: ${{ matrix.compute_cap }}
|
||||||
|
CARGO_BUILD_JOBS: ${{ matrix.build_jobs }}
|
||||||
|
NVCC_THREADS: ${{ matrix.nvcc_threads }}
|
||||||
|
|
||||||
|
- name: Stage binary
|
||||||
|
run: |
|
||||||
|
mkdir --parents artifacts
|
||||||
|
cp target/release/neuron artifacts/neuron-${{ matrix.flavour }}
|
||||||
|
file "artifacts/neuron-${{ matrix.flavour }}"
|
||||||
|
|
||||||
|
- uses: actions/upload-artifact@v3
|
||||||
|
with:
|
||||||
|
name: neuron-${{ matrix.flavour }}-fc43
|
||||||
|
path: artifacts/neuron-${{ matrix.flavour }}
|
||||||
|
retention-days: 1
|
||||||
|
|
||||||
|
package-cortex:
|
||||||
|
name: Package cortex RPM
|
||||||
|
needs: [prepare, build-cortex]
|
||||||
|
runs-on: rpm
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ inputs.ref }}
|
||||||
|
|
||||||
|
- uses: actions/download-artifact@v3
|
||||||
|
with:
|
||||||
|
name: cortex-fc43
|
||||||
|
path: artifacts/
|
||||||
|
|
||||||
|
- name: Build RPM
|
||||||
|
run: |
|
||||||
|
set -eux
|
||||||
|
rm -f ~/.rpmmacros
|
||||||
|
rpmdev-setuptree
|
||||||
|
cp artifacts/cortex ~/rpmbuild/SOURCES/
|
||||||
|
cp data/cortex.service ~/rpmbuild/SOURCES/
|
||||||
|
cp data/cortex-sysusers.conf ~/rpmbuild/SOURCES/
|
||||||
|
cp data/cortex-firewalld.xml ~/rpmbuild/SOURCES/
|
||||||
|
cp cortex.example.toml ~/rpmbuild/SOURCES/
|
||||||
|
cp models.example.toml ~/rpmbuild/SOURCES/
|
||||||
|
cp LICENSE ~/rpmbuild/SOURCES/
|
||||||
|
rpmbuild -bb rpm/cortex-prerelease.spec \
|
||||||
|
--define "cortex_version ${{ needs.prepare.outputs.version }}" \
|
||||||
|
--define "cortex_prerelease ${{ needs.prepare.outputs.release }}" \
|
||||||
|
--undefine dist \
|
||||||
|
--define "dist .fc43"
|
||||||
|
|
||||||
|
- uses: actions/upload-artifact@v3
|
||||||
|
with:
|
||||||
|
name: rpm-cortex-fc43
|
||||||
|
path: ~/rpmbuild/RPMS/x86_64/*.rpm
|
||||||
|
retention-days: 7
|
||||||
|
|
||||||
|
package-neuron:
|
||||||
|
name: Package helexa-neuron-${{ matrix.flavour }} RPM
|
||||||
|
needs: [prepare, build-neuron]
|
||||||
|
runs-on: rpm
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- flavour: ampere
|
||||||
|
- flavour: ada
|
||||||
|
- flavour: blackwell
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ inputs.ref }}
|
||||||
|
|
||||||
|
- uses: actions/download-artifact@v3
|
||||||
|
with:
|
||||||
|
name: neuron-${{ matrix.flavour }}-fc43
|
||||||
|
path: artifacts/
|
||||||
|
|
||||||
|
- name: Build RPM
|
||||||
|
run: |
|
||||||
|
set -eux
|
||||||
|
rm -f ~/.rpmmacros
|
||||||
|
rpmdev-setuptree
|
||||||
|
cp artifacts/neuron-${{ matrix.flavour }} ~/rpmbuild/SOURCES/
|
||||||
|
cp data/neuron.service ~/rpmbuild/SOURCES/
|
||||||
|
cp data/neuron-sysusers.conf ~/rpmbuild/SOURCES/
|
||||||
|
cp data/neuron-firewalld.xml ~/rpmbuild/SOURCES/
|
||||||
|
cp neuron.example.toml ~/rpmbuild/SOURCES/
|
||||||
|
cp LICENSE ~/rpmbuild/SOURCES/
|
||||||
|
rpmbuild -bb rpm/helexa-neuron-prerelease.spec \
|
||||||
|
--define "neuron_version ${{ needs.prepare.outputs.version }}" \
|
||||||
|
--define "neuron_flavour ${{ matrix.flavour }}" \
|
||||||
|
--define "neuron_prerelease ${{ needs.prepare.outputs.release }}" \
|
||||||
|
--undefine dist \
|
||||||
|
--define "dist .fc43"
|
||||||
|
|
||||||
|
- uses: actions/upload-artifact@v3
|
||||||
|
with:
|
||||||
|
name: rpm-neuron-${{ matrix.flavour }}-fc43
|
||||||
|
path: ~/rpmbuild/RPMS/x86_64/*.rpm
|
||||||
|
retention-days: 7
|
||||||
|
|
||||||
|
publish:
|
||||||
|
name: Publish to rpm.lair.cafe (unstable)
|
||||||
|
needs: [package-cortex, package-neuron]
|
||||||
|
runs-on: rpm
|
||||||
|
concurrency:
|
||||||
|
group: rpm-publish
|
||||||
|
cancel-in-progress: false
|
||||||
|
env:
|
||||||
|
RPM_REPO_HOST: oolon.kosherinata.internal
|
||||||
|
FEDORA_VERSION: "43"
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ inputs.ref }}
|
||||||
|
|
||||||
|
- name: Download all built RPMs
|
||||||
|
uses: actions/download-artifact@v3
|
||||||
|
with:
|
||||||
|
path: rpms/
|
||||||
|
pattern: rpm-*-fc43
|
||||||
|
|
||||||
|
- name: Flatten RPM artifacts
|
||||||
|
run: |
|
||||||
|
set -eux
|
||||||
|
find rpms/ -name '*.rpm' -exec mv --target-directory=rpms/ {} +
|
||||||
|
find rpms/ -mindepth 1 -type d -empty -delete
|
||||||
|
ls -la rpms/
|
||||||
|
|
||||||
|
- name: Check for sequoia-sq
|
||||||
|
run: |
|
||||||
|
if ! command -v sq &> /dev/null; then
|
||||||
|
echo "ERROR: sequoia-sq is not installed. Install with: sudo dnf install sequoia-sq"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Import signing key
|
||||||
|
env:
|
||||||
|
# Pass secrets via env so values stay out of the rendered shell
|
||||||
|
# script (which Gitea includes in step logs). Template
|
||||||
|
# expansion of ${{ secrets.X }} inside `run:` writes the literal
|
||||||
|
# value into the script and depends on Gitea's log masker to
|
||||||
|
# scrub it — fragile for multi-line keys.
|
||||||
|
RPM_SIGNING_KEY: ${{ secrets.RPM_SIGNING_KEY }}
|
||||||
|
RPM_SIGNING_KEY_ID: ${{ secrets.RPM_SIGNING_KEY_ID }}
|
||||||
|
run: |
|
||||||
|
echo "$RPM_SIGNING_KEY" | gpg --batch --import
|
||||||
|
fpr=$(gpg --batch --with-colons --list-keys "$RPM_SIGNING_KEY_ID" | awk -F: '/^fpr:/ { print $10; exit }')
|
||||||
|
echo "${fpr}:6:" | gpg --batch --import-ownertrust
|
||||||
|
sed "s/@GPG_NAME@/$RPM_SIGNING_KEY_ID/" rpm/rpmmacros > ~/.rpmmacros
|
||||||
|
|
||||||
|
- name: Sign RPMs
|
||||||
|
run: |
|
||||||
|
set -eux
|
||||||
|
for rpm in rpms/*.rpm; do
|
||||||
|
echo "signing ${rpm}..."
|
||||||
|
rpm --addsign "${rpm}"
|
||||||
|
done
|
||||||
|
|
||||||
|
- name: Set up SSH for rsync
|
||||||
|
run: |
|
||||||
|
install --directory --mode 700 ~/.ssh
|
||||||
|
echo "${RSYNC_SSH_KEY}" | install --mode 600 /dev/stdin ~/.ssh/id_ed25519
|
||||||
|
env:
|
||||||
|
RSYNC_SSH_KEY: ${{ secrets.RSYNC_SSH_KEY }}
|
||||||
|
|
||||||
|
- name: Test SSH connectivity
|
||||||
|
run: |
|
||||||
|
ssh -o StrictHostKeyChecking=accept-new "gitea_ci@${RPM_REPO_HOST}" exit
|
||||||
|
|
||||||
|
- name: Ensure unstable repo directory exists
|
||||||
|
run: |
|
||||||
|
ssh "gitea_ci@${RPM_REPO_HOST}" \
|
||||||
|
"mkdir --parents /var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable"
|
||||||
|
|
||||||
|
- name: Sync RPMs to unstable repo
|
||||||
|
run: |
|
||||||
|
rsync \
|
||||||
|
--archive \
|
||||||
|
--verbose \
|
||||||
|
--chmod D755,F644 \
|
||||||
|
rpms/*.rpm \
|
||||||
|
"gitea_ci@${RPM_REPO_HOST}:/var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable/"
|
||||||
|
|
||||||
|
- name: Update unstable repo metadata
|
||||||
|
run: |
|
||||||
|
ssh "gitea_ci@${RPM_REPO_HOST}" \
|
||||||
|
"cd /var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable && createrepo_c --update ."
|
||||||
|
|
||||||
|
- name: Generate packages.json manifest
|
||||||
|
run: |
|
||||||
|
scp script/generate-packages-json.py "gitea_ci@${RPM_REPO_HOST}:/tmp/"
|
||||||
|
ssh "gitea_ci@${RPM_REPO_HOST}" \
|
||||||
|
"python3 /tmp/generate-packages-json.py \
|
||||||
|
--repodata-dir /var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable/repodata \
|
||||||
|
--output /var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable/packages.json \
|
||||||
|
--base-url https://rpm.lair.cafe/fedora/${FEDORA_VERSION}/x86_64/unstable"
|
||||||
@@ -7,6 +7,16 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
branches: [main]
|
branches: [main]
|
||||||
|
|
||||||
|
# Share a concurrency group with build-prerelease.yml so the two
|
||||||
|
# workflows don't race on the same `rust` runner workspace (act's
|
||||||
|
# /root/.cache/act/<hash>/hostexecutor/ is shared across concurrent
|
||||||
|
# jobs and one job's checkout step nukes another's in-flight build
|
||||||
|
# files). cancel-in-progress=false → they queue; same-ref pushes
|
||||||
|
# coalesce per workflow via cancel-in-progress on each.
|
||||||
|
concurrency:
|
||||||
|
group: cortex-runner-pool-${{ github.ref }}
|
||||||
|
cancel-in-progress: false
|
||||||
|
|
||||||
env:
|
env:
|
||||||
CARGO_INCREMENTAL: "0"
|
CARGO_INCREMENTAL: "0"
|
||||||
RUSTC_WRAPPER: sccache
|
RUSTC_WRAPPER: sccache
|
||||||
@@ -16,56 +26,47 @@ env:
|
|||||||
SCCACHE_S3_USE_SSL: "false"
|
SCCACHE_S3_USE_SSL: "false"
|
||||||
AWS_ACCESS_KEY_ID: ${{ secrets.SCCACHE_S3_ACCESS_KEY }}
|
AWS_ACCESS_KEY_ID: ${{ secrets.SCCACHE_S3_ACCESS_KEY }}
|
||||||
AWS_SECRET_ACCESS_KEY: ${{ secrets.SCCACHE_S3_SECRET_KEY }}
|
AWS_SECRET_ACCESS_KEY: ${{ secrets.SCCACHE_S3_SECRET_KEY }}
|
||||||
|
# fmt, clippy, and test all run in parallel on the same `rust` runner
|
||||||
|
# and would otherwise share /root/.cache/act/<hash>/hostexecutor/target/,
|
||||||
|
# racing each other's cargo temp files (.tmpXXXXXX) and failing builds
|
||||||
|
# mid-compile. Give each job its own target directory so the invocations
|
||||||
|
# don't collide. sccache still backs the actual rustc cache, so the
|
||||||
|
# rebuild penalty is small.
|
||||||
|
CARGO_TARGET_DIR: target-${{ github.job }}
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
check:
|
fmt:
|
||||||
name: Format, lint, build, test
|
name: Format
|
||||||
runs-on: fedora
|
runs-on: rust
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
- run: cargo fmt --check --all
|
||||||
|
|
||||||
- name: Cache cargo registry and target
|
clippy:
|
||||||
uses: actions/cache@v4
|
name: Clippy
|
||||||
with:
|
runs-on: rust
|
||||||
path: |
|
steps:
|
||||||
~/.cargo/bin
|
- uses: actions/checkout@v4
|
||||||
~/.cargo/registry/index
|
- run: cargo clippy --workspace -- -D warnings
|
||||||
~/.cargo/registry/cache
|
- run: sccache --show-stats
|
||||||
~/.cargo/git/db
|
|
||||||
target
|
|
||||||
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
|
|
||||||
restore-keys: |
|
|
||||||
${{ runner.os }}-cargo-
|
|
||||||
|
|
||||||
- name: Ensure sccache with S3 support
|
test:
|
||||||
env:
|
name: Test
|
||||||
RUSTC_WRAPPER: ""
|
runs-on: rust
|
||||||
run: |
|
steps:
|
||||||
if sccache --version 2>/dev/null && sccache --show-stats 2>/dev/null; then
|
- uses: actions/checkout@v4
|
||||||
echo "sccache with S3 support already installed"
|
- run: cargo test --workspace
|
||||||
else
|
- run: sccache --show-stats
|
||||||
cargo install sccache --features s3 --locked
|
|
||||||
fi
|
|
||||||
|
|
||||||
- name: Check formatting
|
|
||||||
run: cargo fmt --check --all
|
|
||||||
|
|
||||||
- name: Clippy
|
|
||||||
run: cargo clippy --workspace -- -D warnings
|
|
||||||
|
|
||||||
- name: Test
|
|
||||||
run: cargo test --workspace
|
|
||||||
|
|
||||||
- name: Show sccache stats
|
|
||||||
run: sccache --show-stats
|
|
||||||
|
|
||||||
srpm-cortex:
|
srpm-cortex:
|
||||||
name: Build cortex SRPM
|
name: Build cortex SRPM
|
||||||
runs-on: fedora
|
runs-on: rpm
|
||||||
needs: check
|
needs: [fmt, clippy, test]
|
||||||
if: startsWith(github.ref, 'refs/tags/v')
|
if: startsWith(github.ref, 'refs/tags/v')
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
- name: Determine version
|
- name: Determine version
|
||||||
id: version
|
id: version
|
||||||
@@ -79,6 +80,12 @@ jobs:
|
|||||||
sed -i '/\[workspace\.package\]/,/\[/{ s/^version = ".*"/version = "'"${VERSION}"'"/ }' Cargo.toml
|
sed -i '/\[workspace\.package\]/,/\[/{ s/^version = ".*"/version = "'"${VERSION}"'"/ }' Cargo.toml
|
||||||
sed -i "s/^Version:.*/Version: ${VERSION}/" cortex.spec
|
sed -i "s/^Version:.*/Version: ${VERSION}/" cortex.spec
|
||||||
|
|
||||||
|
- name: Generate changelog entry
|
||||||
|
uses: https://git.lair.cafe/actions/rpm-changelog@v1
|
||||||
|
with:
|
||||||
|
spec: cortex.spec
|
||||||
|
version: ${{ steps.version.outputs.VERSION }}
|
||||||
|
|
||||||
- name: Generate source tarball
|
- name: Generate source tarball
|
||||||
run: |
|
run: |
|
||||||
set -ex
|
set -ex
|
||||||
@@ -113,11 +120,13 @@ jobs:
|
|||||||
|
|
||||||
srpm-neuron:
|
srpm-neuron:
|
||||||
name: Build neuron SRPM
|
name: Build neuron SRPM
|
||||||
runs-on: fedora
|
runs-on: rpm
|
||||||
needs: check
|
needs: [fmt, clippy, test]
|
||||||
if: startsWith(github.ref, 'refs/tags/v')
|
if: startsWith(github.ref, 'refs/tags/v')
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
- name: Determine version
|
- name: Determine version
|
||||||
id: version
|
id: version
|
||||||
@@ -129,31 +138,37 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
VERSION="${{ steps.version.outputs.VERSION }}"
|
VERSION="${{ steps.version.outputs.VERSION }}"
|
||||||
sed -i '/\[workspace\.package\]/,/\[/{ s/^version = ".*"/version = "'"${VERSION}"'"/ }' Cargo.toml
|
sed -i '/\[workspace\.package\]/,/\[/{ s/^version = ".*"/version = "'"${VERSION}"'"/ }' Cargo.toml
|
||||||
sed -i "s/^Version:.*/Version: ${VERSION}/" neuron.spec
|
sed -i "s/^Version:.*/Version: ${VERSION}/" helexa-neuron.spec
|
||||||
|
|
||||||
|
- name: Generate changelog entry
|
||||||
|
uses: https://git.lair.cafe/actions/rpm-changelog@v1
|
||||||
|
with:
|
||||||
|
spec: helexa-neuron.spec
|
||||||
|
version: ${{ steps.version.outputs.VERSION }}
|
||||||
|
|
||||||
- name: Generate source tarball
|
- name: Generate source tarball
|
||||||
run: |
|
run: |
|
||||||
set -ex
|
set -ex
|
||||||
VERSION="${{ steps.version.outputs.VERSION }}"
|
VERSION="${{ steps.version.outputs.VERSION }}"
|
||||||
tar czf /tmp/neuron-${VERSION}.tar.gz \
|
tar czf /tmp/helexa-neuron-${VERSION}.tar.gz \
|
||||||
--transform "s,^\.,neuron-${VERSION}," \
|
--transform "s,^\.,helexa-neuron-${VERSION}," \
|
||||||
--exclude='./target' \
|
--exclude='./target' \
|
||||||
--exclude='./.git' \
|
--exclude='./.git' \
|
||||||
--exclude='*.tar.gz' \
|
--exclude='*.tar.gz' \
|
||||||
--exclude='*.src.rpm' \
|
--exclude='*.src.rpm' \
|
||||||
.
|
.
|
||||||
mv /tmp/neuron-${VERSION}.tar.gz .
|
mv /tmp/helexa-neuron-${VERSION}.tar.gz .
|
||||||
|
|
||||||
- name: Vendor Rust dependencies
|
- name: Vendor Rust dependencies
|
||||||
run: |
|
run: |
|
||||||
VERSION="${{ steps.version.outputs.VERSION }}"
|
VERSION="${{ steps.version.outputs.VERSION }}"
|
||||||
cargo vendor vendor/
|
cargo vendor vendor/
|
||||||
tar czf neuron-${VERSION}-vendor.tar.gz vendor/
|
tar czf helexa-neuron-${VERSION}-vendor.tar.gz vendor/
|
||||||
rm -rf vendor/
|
rm -rf vendor/
|
||||||
|
|
||||||
- name: Build SRPM
|
- name: Build SRPM
|
||||||
run: |
|
run: |
|
||||||
rpmbuild -bs neuron.spec \
|
rpmbuild -bs helexa-neuron.spec \
|
||||||
--define "_sourcedir $(pwd)" \
|
--define "_sourcedir $(pwd)" \
|
||||||
--define "_srcrpmdir $(pwd)"
|
--define "_srcrpmdir $(pwd)"
|
||||||
|
|
||||||
@@ -165,65 +180,81 @@ jobs:
|
|||||||
|
|
||||||
copr-cortex:
|
copr-cortex:
|
||||||
name: Publish cortex to COPR
|
name: Publish cortex to COPR
|
||||||
runs-on: fedora
|
runs-on: fedora-43
|
||||||
needs: srpm-cortex
|
needs: srpm-cortex
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Download SRPM
|
- name: Download SRPM
|
||||||
uses: actions/download-artifact@v3
|
uses: actions/download-artifact@v3
|
||||||
with:
|
with:
|
||||||
name: srpm-cortex
|
name: srpm-cortex
|
||||||
|
|
||||||
- name: Configure copr-cli
|
- name: Publish to COPR
|
||||||
run: |
|
uses: https://git.lair.cafe/actions/copr-publish@v1
|
||||||
mkdir -p ~/.config
|
with:
|
||||||
echo "${{ secrets.COPR_CONFIG }}" > ~/.config/copr
|
project: helexa/helexa
|
||||||
|
srpm: "*.src.rpm"
|
||||||
- name: Submit build to COPR
|
copr-config: ${{ secrets.COPR_CONFIG }}
|
||||||
run: bash .gitea/scripts/copr-build.sh helexa/cortex *.src.rpm
|
|
||||||
|
|
||||||
copr-neuron:
|
copr-neuron:
|
||||||
name: Publish neuron to COPR
|
name: Publish neuron to COPR
|
||||||
runs-on: fedora
|
runs-on: fedora-43
|
||||||
needs: srpm-neuron
|
needs: srpm-neuron
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Download SRPM
|
- name: Download SRPM
|
||||||
uses: actions/download-artifact@v3
|
uses: actions/download-artifact@v3
|
||||||
with:
|
with:
|
||||||
name: srpm-neuron
|
name: srpm-neuron
|
||||||
|
|
||||||
- name: Configure copr-cli
|
- name: Publish to COPR
|
||||||
run: |
|
uses: https://git.lair.cafe/actions/copr-publish@v1
|
||||||
mkdir -p ~/.config
|
with:
|
||||||
echo "${{ secrets.COPR_CONFIG }}" > ~/.config/copr
|
project: helexa/helexa
|
||||||
|
srpm: "*.src.rpm"
|
||||||
- name: Submit build to COPR
|
copr-config: ${{ secrets.COPR_CONFIG }}
|
||||||
run: bash .gitea/scripts/copr-build.sh helexa/neuron *.src.rpm
|
|
||||||
|
|
||||||
bump-version:
|
bump-version:
|
||||||
name: Bump version in source
|
name: Bump version in source
|
||||||
runs-on: fedora
|
runs-on: rust
|
||||||
needs: [copr-cortex, copr-neuron]
|
needs: [copr-cortex, copr-neuron]
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
- name: Stamp version and push
|
- name: Determine version
|
||||||
|
id: version
|
||||||
|
run: echo "VERSION=${GITHUB_REF#refs/tags/v}" >> "$GITHUB_OUTPUT"
|
||||||
|
|
||||||
|
- name: Stamp version
|
||||||
|
run: |
|
||||||
|
VERSION="${{ steps.version.outputs.VERSION }}"
|
||||||
|
sed -i '/\[workspace\.package\]/,/\[/{ s/^version = ".*"/version = "'"${VERSION}"'"/ }' Cargo.toml
|
||||||
|
sed -i "s/^Version:.*/Version: ${VERSION}/" cortex.spec
|
||||||
|
sed -i "s/^Version:.*/Version: ${VERSION}/" helexa-neuron.spec
|
||||||
|
cargo check --workspace 2>/dev/null || true
|
||||||
|
|
||||||
|
- name: Generate cortex changelog entry
|
||||||
|
uses: https://git.lair.cafe/actions/rpm-changelog@v1
|
||||||
|
with:
|
||||||
|
spec: cortex.spec
|
||||||
|
version: ${{ steps.version.outputs.VERSION }}
|
||||||
|
|
||||||
|
- name: Generate helexa-neuron changelog entry
|
||||||
|
uses: https://git.lair.cafe/actions/rpm-changelog@v1
|
||||||
|
with:
|
||||||
|
spec: helexa-neuron.spec
|
||||||
|
version: ${{ steps.version.outputs.VERSION }}
|
||||||
|
|
||||||
|
- name: Commit and push
|
||||||
env:
|
env:
|
||||||
GITEA_TOKEN: ${{ secrets.GITEA_TOKEN }}
|
GITEA_TOKEN: ${{ secrets.GITEA_TOKEN }}
|
||||||
run: |
|
run: |
|
||||||
VERSION="${GITHUB_REF#refs/tags/v}"
|
VERSION="${{ steps.version.outputs.VERSION }}"
|
||||||
sed -i '/\[workspace\.package\]/,/\[/{ s/^version = ".*"/version = "'"${VERSION}"'"/ }' Cargo.toml
|
|
||||||
sed -i "s/^Version:.*/Version: ${VERSION}/" cortex.spec
|
|
||||||
sed -i "s/^Version:.*/Version: ${VERSION}/" neuron.spec
|
|
||||||
cargo check --workspace 2>/dev/null || true
|
|
||||||
git config user.name "Gitea Actions"
|
git config user.name "Gitea Actions"
|
||||||
git config user.email "actions@git.lair.cafe"
|
git config user.email "actions@git.lair.cafe"
|
||||||
git add Cargo.toml Cargo.lock cortex.spec neuron.spec
|
git add Cargo.toml Cargo.lock cortex.spec helexa-neuron.spec
|
||||||
if git diff --cached --quiet; then
|
if git diff --cached --quiet; then
|
||||||
echo "Version already at ${VERSION}"
|
echo "Nothing to commit for ${VERSION}"
|
||||||
else
|
else
|
||||||
git commit -m "chore: bump version to ${VERSION}"
|
git commit -m "chore: bump version to ${VERSION}"
|
||||||
git remote set-url origin "https://gitea-actions:${GITEA_TOKEN}@git.lair.cafe/helexa/cortex.git"
|
git remote set-url origin "https://gitea-actions:${GITEA_TOKEN}@git.lair.cafe/helexa/cortex.git"
|
||||||
|
|||||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -4,4 +4,6 @@
|
|||||||
.idea/
|
.idea/
|
||||||
.vscode/
|
.vscode/
|
||||||
cortex.toml
|
cortex.toml
|
||||||
|
models.toml
|
||||||
doc/plan/*
|
doc/plan/*
|
||||||
|
/target-cuda/
|
||||||
|
|||||||
116
CLAUDE.md
116
CLAUDE.md
@@ -125,7 +125,8 @@ automatically. Clippy warnings must be resolved, not suppressed with
|
|||||||
- One or more GPU nodes running mistral.rs on port 8080
|
- One or more GPU nodes running mistral.rs on port 8080
|
||||||
- Optionally a metrics-only node (no GPU) for Prometheus/Grafana
|
- Optionally a metrics-only node (no GPU) for Prometheus/Grafana
|
||||||
- Each node runs `mistralrs serve` on port 8080
|
- Each node runs `mistralrs serve` on port 8080
|
||||||
- Gateway listens on port 8000 (API) and 9100 (metrics)
|
- Gateway listens on port 31313 (API) and 31314 (metrics)
|
||||||
|
- neuron listens on port 13131 on each GPU host
|
||||||
- TLS terminated at gateway or via nginx; internal traffic is plaintext over WireGuard
|
- TLS terminated at gateway or via nginx; internal traffic is plaintext over WireGuard
|
||||||
|
|
||||||
## Conventions
|
## Conventions
|
||||||
@@ -380,7 +381,7 @@ processes (one process per loaded model, each on its own port).
|
|||||||
|
|
||||||
## neuron API
|
## neuron API
|
||||||
|
|
||||||
neuron exposes an HTTP API on port 9090 that cortex polls and calls.
|
neuron exposes an HTTP API on port 13131 that cortex polls and calls.
|
||||||
|
|
||||||
```
|
```
|
||||||
GET /discovery
|
GET /discovery
|
||||||
@@ -424,8 +425,8 @@ endpoint. cortex.toml shrinks to:
|
|||||||
|
|
||||||
```toml
|
```toml
|
||||||
[gateway]
|
[gateway]
|
||||||
listen = "0.0.0.0:8000"
|
listen = "0.0.0.0:31313"
|
||||||
metrics_listen = "0.0.0.0:9100"
|
metrics_listen = "0.0.0.0:31314"
|
||||||
|
|
||||||
[eviction]
|
[eviction]
|
||||||
strategy = "lru"
|
strategy = "lru"
|
||||||
@@ -433,15 +434,15 @@ defrag_after_cycles = 50
|
|||||||
|
|
||||||
[[neurons]]
|
[[neurons]]
|
||||||
name = "beast"
|
name = "beast"
|
||||||
endpoint = "http://beast.hanzalova.internal:9090"
|
endpoint = "http://beast.hanzalova.internal:13131"
|
||||||
|
|
||||||
[[neurons]]
|
[[neurons]]
|
||||||
name = "benjy"
|
name = "benjy"
|
||||||
endpoint = "http://benjy.kosherinata.internal:9090"
|
endpoint = "http://benjy.hanzalova.internal:13131"
|
||||||
|
|
||||||
[[neurons]]
|
[[neurons]]
|
||||||
name = "quadbrat"
|
name = "quadbrat"
|
||||||
endpoint = "http://quadbrat.hanzalova.internal:9090"
|
endpoint = "http://quadbrat.hanzalova.internal:13131"
|
||||||
```
|
```
|
||||||
|
|
||||||
On startup and periodically, cortex calls `GET /discovery` and
|
On startup and periodically, cortex calls `GET /discovery` and
|
||||||
@@ -521,7 +522,7 @@ cortex/
|
|||||||
│ │ └── metrics.rs # prometheus exporter (unchanged)
|
│ │ └── metrics.rs # prometheus exporter (unchanged)
|
||||||
│ ├── neuron/ # node plane (replaces cortex-agent)
|
│ ├── neuron/ # node plane (replaces cortex-agent)
|
||||||
│ │ └── src/
|
│ │ └── src/
|
||||||
│ │ ├── main.rs # binary entrypoint, axum server on :9090
|
│ │ ├── main.rs # binary entrypoint, axum server on :13131
|
||||||
│ │ ├── discovery.rs # nvidia-smi, device enumeration
|
│ │ ├── discovery.rs # nvidia-smi, device enumeration
|
||||||
│ │ ├── health.rs # runtime GPU polling
|
│ │ ├── health.rs # runtime GPU polling
|
||||||
│ │ ├── api.rs # HTTP handlers for /discovery, /models, etc.
|
│ │ ├── api.rs # HTTP handlers for /discovery, /models, etc.
|
||||||
@@ -595,70 +596,65 @@ placement matching can be added incrementally.
|
|||||||
Completed. Both packages have RPM specs, systemd units, and example configs.
|
Completed. Both packages have RPM specs, systemd units, and example configs.
|
||||||
CI builds parallel SRPMs on tag push and publishes to separate COPR repos.
|
CI builds parallel SRPMs on tag push and publishes to separate COPR repos.
|
||||||
|
|
||||||
- `cortex.spec` → `helexa/cortex` COPR: binary, systemd unit, config files
|
- `cortex.spec` — installs the `cortex` binary. Package name keeps the
|
||||||
- `neuron.spec` → `helexa/neuron` COPR: binary, systemd unit, config
|
short `cortex` because no Fedora package collides with it.
|
||||||
|
- `helexa-neuron.spec` — installs the `neuron` binary under package name
|
||||||
|
`helexa-neuron`. Renamed from bare `neuron` to avoid collision with
|
||||||
|
Fedora's NEURON neural-simulation package
|
||||||
|
(https://src.fedoraproject.org/rpms/neuron); binary, systemd unit,
|
||||||
|
system user, and config dir all stay named `neuron` since those are
|
||||||
|
project-local contexts.
|
||||||
- `data/cortex.service`, `data/neuron.service` — systemd units
|
- `data/cortex.service`, `data/neuron.service` — systemd units
|
||||||
- `cortex.example.toml`, `neuron.example.toml`, `models.example.toml`
|
- `cortex.example.toml`, `neuron.example.toml`, `models.example.toml`
|
||||||
- CI: parallel `srpm-cortex` + `srpm-neuron` jobs, then parallel COPR publish
|
- CI: parallel `srpm-cortex` + `srpm-neuron` jobs, then parallel COPR
|
||||||
|
publish to a single project `helexa/helexa` hosting both packages.
|
||||||
|
|
||||||
Install:
|
Install:
|
||||||
```sh
|
```sh
|
||||||
dnf copr enable helexa/cortex && dnf install cortex # gateway host
|
dnf copr enable helexa/helexa
|
||||||
dnf copr enable helexa/neuron && dnf install neuron # GPU nodes
|
dnf install cortex # gateway host
|
||||||
|
dnf install helexa-neuron # GPU nodes
|
||||||
```
|
```
|
||||||
|
|
||||||
### Phase 11: llama.cpp harness stub
|
## 2026-05-18 addendum: candle-native pivot
|
||||||
|
|
||||||
**Goal:** Prove the harness abstraction works with a second engine.
|
Phases 11 (llama.cpp harness) and 12 (mistral.rs COPR) below are
|
||||||
|
**superseded**. The project no longer treats mistral.rs or llama.cpp as
|
||||||
|
dependencies — both are conceptually out of scope. neuron becomes a
|
||||||
|
candle-native inference daemon, with `Harness` retained as an
|
||||||
|
internal seam for adding future engines (vision/audio/diffusion) but
|
||||||
|
its only implementation being in-process candle.
|
||||||
|
|
||||||
**Steps:**
|
The full staged plan for this pivot lives at
|
||||||
1. `crates/neuron/src/harness/llamacpp.rs` — implement the `Harness`
|
`~/.claude/plans/create-a-more-aggressive-calm-naur.md`. Summary:
|
||||||
trait for llama.cpp's `llama-server`.
|
|
||||||
- `start()` — launch `llama-server` with the correct model path,
|
|
||||||
`--port`, `--n-gpu-layers`, `--tensor-split` args. Track the
|
|
||||||
child process.
|
|
||||||
- `stop()` — send SIGTERM to the child process.
|
|
||||||
- `list_models()` — llama-server serves one model per process, so
|
|
||||||
return a single-element list.
|
|
||||||
- `load_model()` — start a new llama-server process for this model.
|
|
||||||
- `unload_model()` — stop the process.
|
|
||||||
- `inference_endpoint()` — return `http://localhost:{assigned_port}`.
|
|
||||||
2. Port allocation: neuron assigns ports from a range (e.g. 8100-8199)
|
|
||||||
to llama-server instances.
|
|
||||||
3. Register in `HarnessRegistry` when configured:
|
|
||||||
```toml
|
|
||||||
[[harnesses]]
|
|
||||||
name = "llamacpp"
|
|
||||||
binary = "/usr/local/bin/llama-server"
|
|
||||||
port_range = [8100, 8199]
|
|
||||||
```
|
|
||||||
4. Tests: mock llama-server (simple HTTP server returning canned
|
|
||||||
responses), test load/unload/endpoint lifecycle.
|
|
||||||
|
|
||||||
**Done when:** A model with `harness = "llamacpp"` in `models.toml` can
|
- **Stage 1 (this commit):** delete `mistralrs.rs` and `llamacpp.rs`,
|
||||||
be loaded and served through cortex. Tests pass with mock llama-server.
|
scaffold inert `CandleHarness`, drop `endpoint`/`systemd_unit` from
|
||||||
|
`HarnessConfig`, default no-op `start`/`stop` on the `Harness` trait.
|
||||||
|
- **Stages 2–4:** wire up candle model load/unload (quantized Qwen3
|
||||||
|
first), add OpenAI-compatible inference endpoint in neuron, then SSE
|
||||||
|
streaming.
|
||||||
|
- **Stages 5–6:** load-on-activation (default models in config) and
|
||||||
|
unload-on-deactivation (graceful shutdown).
|
||||||
|
- **Stages 7–8:** multi-GPU tensor parallelism and broader model/quant
|
||||||
|
coverage.
|
||||||
|
|
||||||
### Phase 12 (lower priority): mistral.rs COPR packaging
|
Sections of this document that describe mistral.rs HTTP behaviour
|
||||||
|
("mistral.rs API gotchas") are retained as historical context for
|
||||||
|
Phases 1–10 — they document what was true while the project depended
|
||||||
|
on mistral.rs. They do not describe current behaviour.
|
||||||
|
|
||||||
**Goal:** Fedora RPMs for mistral.rs built against specific CUDA versions.
|
---
|
||||||
|
|
||||||
**Steps:**
|
### Phase 11 (superseded): llama.cpp harness stub
|
||||||
1. `mistralrs-cuda.spec` — RPM spec that clones a pinned mistral.rs git
|
|
||||||
tag, builds with `--features cuda`, links against the system CUDA
|
|
||||||
toolkit. Produces `mistralrs-cuda13-server` (CUDA 13.x / sm_120) and
|
|
||||||
`mistralrs-cuda12-server` (CUDA 12.x / sm_89). Install binary to
|
|
||||||
`/usr/local/bin/mistralrs`.
|
|
||||||
2. COPR build config: enable the NVIDIA CUDA repo as a build dependency.
|
|
||||||
Pin the CUDA toolkit version in `BuildRequires`.
|
|
||||||
3. Gitea Actions or manual workflow: bump the mistral.rs tag in the spec,
|
|
||||||
trigger COPR rebuild.
|
|
||||||
4. neuron's mistralrs harness config references which binary/package
|
|
||||||
provides the mistral.rs binary. neuron could warn at startup if the
|
|
||||||
installed mistral.rs CUDA version doesn't match the discovered driver.
|
|
||||||
|
|
||||||
**Done when:** `dnf install mistralrs-cuda13-server` on beast provides a
|
~~Originally planned as a second engine to prove the harness
|
||||||
working `mistralrs` binary built for Blackwell GPUs. `dnf install
|
abstraction.~~ Replaced by the candle harness work in the 2026-05-18
|
||||||
mistralrs-cuda12-server` on benjy provides one built for Ada GPUs.
|
addendum above. llama.cpp's any-model/any-hardware breadth is no
|
||||||
|
longer in scope for helexa.
|
||||||
|
|
||||||
This is a separate repo/spec — not part of the cortex workspace — but
|
### Phase 12 (superseded): mistral.rs COPR packaging
|
||||||
tightly coupled operationally. Track it as a sibling project.
|
|
||||||
|
~~Originally planned to ship CUDA-versioned mistral.rs RPMs.~~ Replaced
|
||||||
|
by the candle harness work in the 2026-05-18 addendum above. With
|
||||||
|
mistral.rs out of the dependency tree, there is nothing to package.
|
||||||
|
|||||||
1616
Cargo.lock
generated
1616
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -8,7 +8,7 @@ members = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[workspace.package]
|
[workspace.package]
|
||||||
version = "0.1.2"
|
version = "0.1.16"
|
||||||
edition = "2024"
|
edition = "2024"
|
||||||
license = "GPL-3.0-or-later"
|
license = "GPL-3.0-or-later"
|
||||||
repository = "https://git.lair.cafe/helexa/cortex"
|
repository = "https://git.lair.cafe/helexa/cortex"
|
||||||
@@ -27,7 +27,7 @@ serde = { version = "1", features = ["derive"] }
|
|||||||
serde_json = "1"
|
serde_json = "1"
|
||||||
toml = "0.8"
|
toml = "0.8"
|
||||||
|
|
||||||
# http client (for proxying to mistralrs backends)
|
# http client (for proxying to neuron backends)
|
||||||
reqwest = { version = "0.12", features = ["json", "stream"] }
|
reqwest = { version = "0.12", features = ["json", "stream"] }
|
||||||
|
|
||||||
# observability
|
# observability
|
||||||
|
|||||||
101
README.md
101
README.md
@@ -1,22 +1,23 @@
|
|||||||
# cortex
|
# cortex
|
||||||
|
|
||||||
A Rust reverse-proxy and fleet management layer for multi-node
|
A Rust reverse-proxy and fleet management layer for multi-node GPU inference
|
||||||
[mistral.rs](https://github.com/EricLBuehler/mistral.rs) inference clusters.
|
clusters. Cortex sits in front of one or more `neuron` daemons (each running
|
||||||
|
candle-based inference on a local GPU host) and presents a unified OpenAI +
|
||||||
|
Anthropic compatible API surface.
|
||||||
|
|
||||||
## Problem
|
## Problem
|
||||||
|
|
||||||
Running local LLMs across multiple GPU nodes (different VRAM tiers, different
|
Running local LLMs across multiple GPU nodes (different VRAM tiers, different
|
||||||
model affinities) requires a unified API surface that:
|
model affinities) requires a unified API surface that:
|
||||||
|
|
||||||
- Presents a **single `/v1/models` catalogue** merging every model across every
|
- Presents a **single `/v1/models` catalogue** merging every model that can be
|
||||||
node.
|
served by any neuron in the fleet.
|
||||||
- **Routes requests** to the correct node based on where a model is loaded (or
|
- **Routes requests** to the correct node based on where a model is loaded
|
||||||
*can* be loaded).
|
(or can be loaded), handling cold-load and eviction transparently.
|
||||||
- Manages **model lifecycle** — unload cold models, reload on demand, pin
|
- Manages **model lifecycle** — load on demand, unload cold models, pin
|
||||||
critical ones — using the mistral.rs
|
critical ones — by calling each neuron's `/models/{load,unload}` API.
|
||||||
`/v1/models/{unload,reload,status}` HTTP API (PR #1828+).
|
|
||||||
- Translates between **OpenAI and Anthropic** request/response envelopes so
|
- Translates between **OpenAI and Anthropic** request/response envelopes so
|
||||||
every client in the homelab speaks whichever dialect it prefers.
|
every client speaks whichever dialect it prefers.
|
||||||
- Captures **per-request metrics** (tokens, tok/s, TTFT, latency) and exposes
|
- Captures **per-request metrics** (tokens, tok/s, TTFT, latency) and exposes
|
||||||
them as Prometheus counters/histograms.
|
them as Prometheus counters/histograms.
|
||||||
|
|
||||||
@@ -38,10 +39,9 @@ model affinities) requires a unified API surface that:
|
|||||||
└──┬──────┬────────┬──┘
|
└──┬──────┬────────┬──┘
|
||||||
│ │ │
|
│ │ │
|
||||||
┌──────────▼┐ ┌──▼─────┐ ┌▼──────────┐
|
┌──────────▼┐ ┌──▼─────┐ ┌▼──────────┐
|
||||||
│ gpu-large │ │gpu-med │ │ gpu-small │
|
│ neuron │ │ neuron │ │ neuron │
|
||||||
│ mistralrs │ │mistral │ │ mistralrs │
|
│ :13131 │ │ :13131 │ │ :13131 │
|
||||||
│ serve │ │rs serve│ │ serve │
|
│ candle │ │ candle │ │ candle │
|
||||||
│ :8080 │ │ :8080 │ │ :8080 │
|
|
||||||
└───────────┘ └────────┘ └───────────┘
|
└───────────┘ └────────┘ └───────────┘
|
||||||
private network (.internal)
|
private network (.internal)
|
||||||
```
|
```
|
||||||
@@ -50,70 +50,48 @@ model affinities) requires a unified API surface that:
|
|||||||
|
|
||||||
| Crate | Purpose |
|
| Crate | Purpose |
|
||||||
|---|---|
|
|---|---|
|
||||||
| `cortex-core` | Shared types: config, node/model state, metrics, OpenAI/Anthropic request/response envelopes |
|
| `cortex-core` | Shared types: config, node/model state, metrics, OpenAI/Anthropic envelopes, harness trait, discovery types |
|
||||||
| `cortex-gateway` | Axum HTTP server: proxy, router, evictor, metrics exporter |
|
| `cortex-gateway` | Axum HTTP server: proxy, router, evictor, poller, metrics exporter |
|
||||||
| `cortex-agent` | Per-node sidecar: polls local mistralrs, reports to gateway, handles restart/defrag |
|
| `neuron` | Per-node daemon: GPU discovery, in-process candle inference, model lifecycle API |
|
||||||
| `cortex-cli` | CLI entrypoint (`cortex serve`, `cortex status`, etc.) |
|
| `cortex-cli` | CLI entrypoint (`cortex serve`, `cortex status`, etc.) |
|
||||||
|
|
||||||
## Node setup
|
## Node setup
|
||||||
|
|
||||||
Each GPU node runs `mistralrs serve` with a multi-model config. Models are
|
Each GPU node runs `neuron` (listening on `:13131`). Neuron uses
|
||||||
declared but start **unloaded** — mistral.rs lazy-loads on first request and
|
huggingface/candle for in-process inference — there is no external
|
||||||
the gateway can explicitly unload/reload via the HTTP API.
|
inference subprocess to manage.
|
||||||
|
|
||||||
Example node systemd unit:
|
The neuron RPM (`helexa-neuron`) ships a systemd unit:
|
||||||
|
|
||||||
```ini
|
```sh
|
||||||
# /etc/systemd/system/mistralrs.service
|
dnf copr enable helexa/helexa
|
||||||
[Unit]
|
dnf install helexa-neuron
|
||||||
Description=mistral.rs inference server
|
systemctl enable --now neuron
|
||||||
After=network-online.target
|
|
||||||
Wants=network-online.target
|
|
||||||
|
|
||||||
[Service]
|
|
||||||
Type=simple
|
|
||||||
ExecStart=/usr/local/bin/mistralrs serve \
|
|
||||||
--from-config /etc/mistralrs/config.toml \
|
|
||||||
--port 8080
|
|
||||||
Restart=on-failure
|
|
||||||
RestartSec=5
|
|
||||||
Environment=CUDA_VISIBLE_DEVICES=0,1
|
|
||||||
|
|
||||||
[Install]
|
|
||||||
WantedBy=multi-user.target
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## Gateway config
|
## Gateway config
|
||||||
|
|
||||||
```toml
|
```toml
|
||||||
# cortex.toml
|
# /etc/cortex/cortex.toml
|
||||||
[gateway]
|
[gateway]
|
||||||
listen = "0.0.0.0:8000"
|
listen = "0.0.0.0:31313"
|
||||||
metrics_listen = "0.0.0.0:9100"
|
metrics_listen = "0.0.0.0:31314"
|
||||||
|
|
||||||
[eviction]
|
[eviction]
|
||||||
strategy = "lru" # lru | priority
|
strategy = "lru" # lru | priority
|
||||||
defrag_after_cycles = 50
|
defrag_after_cycles = 50
|
||||||
|
|
||||||
[[nodes]]
|
[[neurons]]
|
||||||
name = "gpu-large"
|
name = "beast"
|
||||||
endpoint = "http://gpu-large.internal:8080"
|
endpoint = "http://beast.internal:13131"
|
||||||
vram_mb = 49_152 # e.g. 2x RTX 4090
|
|
||||||
pinned = ["your-org/large-model"]
|
|
||||||
|
|
||||||
[[nodes]]
|
[[neurons]]
|
||||||
name = "gpu-medium"
|
name = "benjy"
|
||||||
endpoint = "http://gpu-medium.internal:8080"
|
endpoint = "http://benjy.internal:13131"
|
||||||
vram_mb = 24_576 # e.g. RTX 4090
|
|
||||||
pinned = ["your-org/medium-model"]
|
|
||||||
|
|
||||||
[[nodes]]
|
|
||||||
name = "gpu-small"
|
|
||||||
endpoint = "http://gpu-small.internal:8080"
|
|
||||||
vram_mb = 12_288 # e.g. RTX 3060
|
|
||||||
pinned = ["your-org/embedding-model"]
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Model placement profiles live in `models.toml` — see `models.example.toml`.
|
||||||
|
|
||||||
## Building
|
## Building
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
@@ -131,19 +109,20 @@ cargo clippy --workspace -- -D warnings # warnings are errors
|
|||||||
cargo test --workspace # all tests must pass
|
cargo test --workspace # all tests must pass
|
||||||
```
|
```
|
||||||
|
|
||||||
Tagged releases (`v*`) additionally build an SRPM and publish to COPR.
|
Tagged releases (`v*`) additionally build SRPMs for both `cortex` and
|
||||||
|
`helexa-neuron` and publish to COPR.
|
||||||
|
|
||||||
## Running
|
## Running
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
# start the gateway
|
# start the gateway
|
||||||
cortex serve --config cortex.toml
|
cortex serve --config /etc/cortex/cortex.toml
|
||||||
|
|
||||||
# check fleet status
|
# check fleet status
|
||||||
cortex status
|
cortex status
|
||||||
|
|
||||||
# list all models across nodes
|
# list all models across nodes
|
||||||
curl http://localhost:8000/v1/models
|
curl http://localhost:31313/v1/models
|
||||||
```
|
```
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|||||||
30
asset/manifest.yml
Normal file
30
asset/manifest.yml
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
# Helexa fleet manifest.
|
||||||
|
#
|
||||||
|
# Drives rolling deploys via script/deploy.sh and serves as the source
|
||||||
|
# of truth for which hosts run cortex vs neuron, and which CUDA
|
||||||
|
# compute-capability flavour each neuron host needs.
|
||||||
|
#
|
||||||
|
# Flavour ↔ NVIDIA generation ↔ compute cap:
|
||||||
|
# ampere sm_86 (RTX 30 series — e.g. 3060)
|
||||||
|
# ada sm_89 (RTX 40 series — e.g. 4090)
|
||||||
|
# blackwell sm_120 (RTX 50 series — e.g. 5090)
|
||||||
|
#
|
||||||
|
# The flavour determines which RPM is installed on a given neuron host:
|
||||||
|
# helexa-neuron-<flavour>. Only one flavour may be installed at a time
|
||||||
|
# (the packages Conflict: with each other).
|
||||||
|
|
||||||
|
cortex:
|
||||||
|
host: hanzalova.internal
|
||||||
|
|
||||||
|
neurons:
|
||||||
|
- host: beast.hanzalova.internal
|
||||||
|
flavour: blackwell
|
||||||
|
gpu: "2x RTX 5090"
|
||||||
|
|
||||||
|
- host: benjy.hanzalova.internal
|
||||||
|
flavour: ada
|
||||||
|
gpu: "RTX 4090"
|
||||||
|
|
||||||
|
- host: quadbrat.hanzalova.internal
|
||||||
|
flavour: ampere
|
||||||
|
gpu: "RTX 3060"
|
||||||
@@ -3,22 +3,22 @@
|
|||||||
# Copy to cortex.toml and adjust for your environment.
|
# Copy to cortex.toml and adjust for your environment.
|
||||||
#
|
#
|
||||||
# Environment variable overrides use CORTEX_ prefix with __ separators:
|
# Environment variable overrides use CORTEX_ prefix with __ separators:
|
||||||
# CORTEX_GATEWAY__LISTEN=0.0.0.0:9000
|
# CORTEX_GATEWAY__LISTEN=0.0.0.0:31313
|
||||||
|
|
||||||
[gateway]
|
[gateway]
|
||||||
listen = "0.0.0.0:8000"
|
listen = "0.0.0.0:31313"
|
||||||
metrics_listen = "0.0.0.0:9100"
|
metrics_listen = "0.0.0.0:31314"
|
||||||
|
|
||||||
[eviction]
|
[eviction]
|
||||||
strategy = "lru"
|
strategy = "lru"
|
||||||
# Restart mistralrs after this many load/unload cycles to defragment VRAM.
|
# Restart neurons after this many load/unload cycles to defragment VRAM.
|
||||||
# Set to 0 to disable.
|
# Set to 0 to disable.
|
||||||
defrag_after_cycles = 50
|
defrag_after_cycles = 50
|
||||||
|
|
||||||
# -- Nodes ---------------------------------------------------------------
|
# -- Nodes ---------------------------------------------------------------
|
||||||
# Each [[nodes]] entry declares a mistral.rs instance in the fleet.
|
# Each [[nodes]] entry declares a neuron daemon in the fleet.
|
||||||
# Models are discovered by polling the node's /v1/models endpoint.
|
# Models are discovered by polling the neuron's /models endpoint.
|
||||||
# Pinned models are never evicted.
|
# Pinned models (see models.toml) are never evicted.
|
||||||
|
|
||||||
[[nodes]]
|
[[nodes]]
|
||||||
name = "gpu-large"
|
name = "gpu-large"
|
||||||
|
|||||||
66
cortex.spec
66
cortex.spec
@@ -1,5 +1,5 @@
|
|||||||
Name: cortex
|
Name: cortex
|
||||||
Version: 0.1.2
|
Version: 0.1.16
|
||||||
Release: 1%{?dist}
|
Release: 1%{?dist}
|
||||||
Summary: Inference gateway for multi-node GPU clusters
|
Summary: Inference gateway for multi-node GPU clusters
|
||||||
|
|
||||||
@@ -21,12 +21,16 @@ BuildRequires: systemd-rpm-macros
|
|||||||
|
|
||||||
Requires(pre): shadow-utils
|
Requires(pre): shadow-utils
|
||||||
Requires: systemd
|
Requires: systemd
|
||||||
|
Requires: firewalld-filesystem
|
||||||
|
|
||||||
# rpm's sysusers provides-generator only emits versioned user(cortex) when
|
# systemd-rpm-macros ships a unit dep generator that parses User=/Group=
|
||||||
# the u-line has GECOS/home/shell fields. %attr(,,cortex) in %files emits
|
# from our .service file and emits Requires: user(cortex)/group(cortex).
|
||||||
# an unversioned Requires: user(cortex), so we provide it explicitly.
|
# rpm's sysusers provides-generator emits the unversioned form for groups
|
||||||
|
# but only a versioned user(cortex) = <base64> for users with GECOS/home/
|
||||||
|
# shell. Provide the unversioned user(cortex) explicitly so dnf can resolve
|
||||||
|
# the auto-generated Requires. Without this, dnf5 silently filters the
|
||||||
|
# package and reports "Nothing to do".
|
||||||
Provides: user(cortex)
|
Provides: user(cortex)
|
||||||
Provides: group(cortex)
|
|
||||||
|
|
||||||
%description
|
%description
|
||||||
Cortex is a Rust reverse-proxy that sits in front of multiple inference
|
Cortex is a Rust reverse-proxy that sits in front of multiple inference
|
||||||
@@ -53,9 +57,10 @@ cargo build --release -p cortex-cli
|
|||||||
install -Dm755 target/release/cortex %{buildroot}%{_bindir}/cortex
|
install -Dm755 target/release/cortex %{buildroot}%{_bindir}/cortex
|
||||||
install -Dm644 data/cortex.service %{buildroot}%{_unitdir}/cortex.service
|
install -Dm644 data/cortex.service %{buildroot}%{_unitdir}/cortex.service
|
||||||
install -Dm644 data/cortex-sysusers.conf %{buildroot}%{_sysusersdir}/cortex.conf
|
install -Dm644 data/cortex-sysusers.conf %{buildroot}%{_sysusersdir}/cortex.conf
|
||||||
install -dm750 %{buildroot}%{_sysconfdir}/cortex
|
install -Dm644 data/cortex-firewalld.xml %{buildroot}%{_prefix}/lib/firewalld/services/cortex.xml
|
||||||
install -Dm640 cortex.example.toml %{buildroot}%{_sysconfdir}/cortex/cortex.toml
|
install -dm755 %{buildroot}%{_sysconfdir}/cortex
|
||||||
install -Dm640 models.example.toml %{buildroot}%{_sysconfdir}/cortex/models.toml
|
install -Dm644 cortex.example.toml %{buildroot}%{_sysconfdir}/cortex/cortex.toml
|
||||||
|
install -Dm644 models.example.toml %{buildroot}%{_sysconfdir}/cortex/models.toml
|
||||||
|
|
||||||
%pre
|
%pre
|
||||||
%sysusers_create_compat %{_builddir}/%{name}-%{version}/data/cortex-sysusers.conf
|
%sysusers_create_compat %{_builddir}/%{name}-%{version}/data/cortex-sysusers.conf
|
||||||
@@ -69,16 +74,53 @@ install -Dm640 models.example.toml %{buildroot}%{_sysconfdir}/cortex/models.toml
|
|||||||
%postun
|
%postun
|
||||||
%systemd_postun_with_restart cortex.service
|
%systemd_postun_with_restart cortex.service
|
||||||
|
|
||||||
|
%posttrans
|
||||||
|
# Migration: older cortex packages shipped the firewalld service as
|
||||||
|
# `helexa-cortex` and (in some build streams) with wrong port numbers
|
||||||
|
# (9301/9302/9304). Operators who enabled that legacy service in their
|
||||||
|
# zone end up with the wrong-port override taking precedence over the
|
||||||
|
# vendor `cortex.xml` now in /usr/lib/firewalld/services/. Clean up the
|
||||||
|
# stale /etc/ override here and migrate any zone bindings to the new
|
||||||
|
# service name.
|
||||||
|
if [ -f /etc/firewalld/services/helexa-cortex.xml ]; then
|
||||||
|
rm -f /etc/firewalld/services/helexa-cortex.xml
|
||||||
|
fi
|
||||||
|
if [ -x /usr/bin/firewall-cmd ] && /usr/bin/firewall-cmd --state >/dev/null 2>&1; then
|
||||||
|
# Drop the legacy service name from every zone where it was enabled
|
||||||
|
# and add the new `cortex` service in its place. Operators who never
|
||||||
|
# ran firewall-cmd against either name see no zone change.
|
||||||
|
for zone in $(/usr/bin/firewall-cmd --get-active-zones 2>/dev/null \
|
||||||
|
| awk '!/^[[:space:]]/ {print $1}'); do
|
||||||
|
if /usr/bin/firewall-cmd --permanent --zone="$zone" --query-service=helexa-cortex >/dev/null 2>&1; then
|
||||||
|
/usr/bin/firewall-cmd --permanent --zone="$zone" --remove-service=helexa-cortex >/dev/null 2>&1 || :
|
||||||
|
/usr/bin/firewall-cmd --permanent --zone="$zone" --add-service=cortex >/dev/null 2>&1 || :
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
/usr/bin/firewall-cmd --reload >/dev/null 2>&1 || :
|
||||||
|
fi
|
||||||
|
:
|
||||||
|
|
||||||
%files
|
%files
|
||||||
%license LICENSE
|
%license LICENSE
|
||||||
%doc README.md
|
%doc README.md
|
||||||
%{_bindir}/cortex
|
%{_bindir}/cortex
|
||||||
%{_unitdir}/cortex.service
|
%{_unitdir}/cortex.service
|
||||||
%{_sysusersdir}/cortex.conf
|
%{_sysusersdir}/cortex.conf
|
||||||
%dir %attr(750,root,cortex) %{_sysconfdir}/cortex
|
%{_prefix}/lib/firewalld/services/cortex.xml
|
||||||
%config(noreplace) %attr(640,root,cortex) %{_sysconfdir}/cortex/cortex.toml
|
%dir %{_sysconfdir}/cortex
|
||||||
%config(noreplace) %attr(640,root,cortex) %{_sysconfdir}/cortex/models.toml
|
%config(noreplace) %{_sysconfdir}/cortex/cortex.toml
|
||||||
|
%config(noreplace) %{_sysconfdir}/cortex/models.toml
|
||||||
|
|
||||||
%changelog
|
%changelog
|
||||||
* Tue Apr 15 2026 Rob Thijssen <grenade@rob.tn> - 0.1.0-1
|
* Thu Apr 16 2026 Gitea Actions <actions@git.lair.cafe> - 0.1.16-1
|
||||||
|
- chore: ignore local deploy script
|
||||||
|
- chore: move default ports out of common-collision ranges
|
||||||
|
- ci: drop actions/cache for cargo registry and target
|
||||||
|
|
||||||
|
* Thu Apr 16 2026 Gitea Actions <actions@git.lair.cafe> - 0.1.14-1
|
||||||
|
- ci: publish both packages to a single helexa/helexa COPR project
|
||||||
|
- fix(rpm): rename neuron package to helexa-neuron
|
||||||
|
- ci: commit generated %changelog entries back to main
|
||||||
|
|
||||||
|
* Wed Apr 15 2026 Rob Thijssen <grenade@rob.tn> - 0.1.0-1
|
||||||
- Initial package
|
- Initial package
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ use tracing_subscriber::EnvFilter;
|
|||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
#[command(name = "cortex")]
|
#[command(name = "cortex")]
|
||||||
#[command(about = "Unified inference gateway for multi-node mistral.rs clusters")]
|
#[command(about = "Unified inference gateway for multi-node GPU clusters")]
|
||||||
#[command(version)]
|
#[command(version)]
|
||||||
struct Cli {
|
struct Cli {
|
||||||
#[command(subcommand)]
|
#[command(subcommand)]
|
||||||
@@ -23,7 +23,7 @@ enum Commands {
|
|||||||
/// Print the fleet status (models, nodes, health).
|
/// Print the fleet status (models, nodes, health).
|
||||||
Status {
|
Status {
|
||||||
/// Gateway API endpoint to query.
|
/// Gateway API endpoint to query.
|
||||||
#[arg(short, long, default_value = "http://localhost:8000")]
|
#[arg(short, long, default_value = "http://localhost:31313")]
|
||||||
endpoint: String,
|
endpoint: String,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
//!
|
//!
|
||||||
//! These mirror the `/v1/messages` format used by the Anthropic API.
|
//! These mirror the `/v1/messages` format used by the Anthropic API.
|
||||||
//! The gateway accepts these, translates to OpenAI format, proxies to
|
//! The gateway accepts these, translates to OpenAI format, proxies to
|
||||||
//! mistral.rs, then translates the response back.
|
//! the inference backend (neuron), then translates the response back.
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
//! Model catalogue — profiles describing how to serve each model.
|
//! Model catalogue — profiles describing how to serve each model.
|
||||||
|
|
||||||
|
use crate::discovery::DeviceInfo;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
|
||||||
@@ -64,4 +65,103 @@ impl ModelCatalogue {
|
|||||||
.iter()
|
.iter()
|
||||||
.any(|p| p.id == model_id && p.pinned_on.contains(&neuron_name.to_string()))
|
.any(|p| p.id == model_id && p.pinned_on.contains(&neuron_name.to_string()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Find a profile by model id.
|
||||||
|
pub fn get(&self, model_id: &str) -> Option<&ModelProfile> {
|
||||||
|
self.models.iter().find(|p| p.id == model_id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ModelProfile {
|
||||||
|
/// True iff this profile's placement constraints can be satisfied
|
||||||
|
/// by the named neuron with the given device topology.
|
||||||
|
///
|
||||||
|
/// Constraints checked:
|
||||||
|
/// - `pinned_on`: non-empty → neuron must be on the list.
|
||||||
|
/// - `min_devices`: neuron must have at least this many devices.
|
||||||
|
/// - `min_device_vram_mb`: at least `min_devices` of the neuron's
|
||||||
|
/// devices must each meet this VRAM floor.
|
||||||
|
pub fn is_feasible_on(&self, neuron_name: &str, devices: &[DeviceInfo]) -> bool {
|
||||||
|
if !self.pinned_on.is_empty() && !self.pinned_on.iter().any(|n| n == neuron_name) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (devices.len() as u32) < self.min_devices {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if let Some(min_vram) = self.min_device_vram_mb {
|
||||||
|
let big_enough = devices
|
||||||
|
.iter()
|
||||||
|
.filter(|d| d.vram_total_mb >= min_vram)
|
||||||
|
.count() as u32;
|
||||||
|
if big_enough < self.min_devices {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::discovery::DeviceInfo;
|
||||||
|
|
||||||
|
fn device(idx: u32, vram_mb: u64) -> DeviceInfo {
|
||||||
|
DeviceInfo {
|
||||||
|
index: idx,
|
||||||
|
name: format!("DEV-{idx}"),
|
||||||
|
vram_total_mb: vram_mb,
|
||||||
|
compute_capability: "8.6".into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn profile() -> ModelProfile {
|
||||||
|
ModelProfile {
|
||||||
|
id: "Qwen/Qwen3.6-27B".into(),
|
||||||
|
harness: "candle".into(),
|
||||||
|
quant: None,
|
||||||
|
vram_mb: Some(45_000),
|
||||||
|
min_devices: 2,
|
||||||
|
min_device_vram_mb: Some(24_000),
|
||||||
|
pinned_on: vec![],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn feasible_when_two_devices_meet_vram_floor() {
|
||||||
|
let p = profile();
|
||||||
|
let devices = [device(0, 32_000), device(1, 32_000)];
|
||||||
|
assert!(p.is_feasible_on("beast", &devices));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn infeasible_when_only_one_device() {
|
||||||
|
let p = profile();
|
||||||
|
let devices = [device(0, 64_000)];
|
||||||
|
assert!(!p.is_feasible_on("benjy", &devices));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn infeasible_when_one_device_underspec() {
|
||||||
|
let p = profile();
|
||||||
|
let devices = [device(0, 32_000), device(1, 12_000)];
|
||||||
|
assert!(!p.is_feasible_on("mixed", &devices));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn pinned_on_excludes_other_neurons() {
|
||||||
|
let mut p = profile();
|
||||||
|
p.pinned_on = vec!["beast".into()];
|
||||||
|
let devices = [device(0, 32_000), device(1, 32_000)];
|
||||||
|
assert!(p.is_feasible_on("beast", &devices));
|
||||||
|
assert!(!p.is_feasible_on("benjy", &devices));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn no_vram_floor_just_needs_min_devices() {
|
||||||
|
let mut p = profile();
|
||||||
|
p.min_device_vram_mb = None;
|
||||||
|
let devices = [device(0, 1_000), device(1, 1_000)];
|
||||||
|
assert!(p.is_feasible_on("anywhere", &devices));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,9 +22,9 @@ fn default_models_path() -> String {
|
|||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct GatewaySettings {
|
pub struct GatewaySettings {
|
||||||
/// Address to listen on for API requests (e.g. "0.0.0.0:8000")
|
/// Address to listen on for API requests (e.g. "0.0.0.0:31313")
|
||||||
pub listen: String,
|
pub listen: String,
|
||||||
/// Address to listen on for Prometheus metrics (e.g. "0.0.0.0:9100")
|
/// Address to listen on for Prometheus metrics (e.g. "0.0.0.0:31314")
|
||||||
pub metrics_listen: String,
|
pub metrics_listen: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -50,7 +50,7 @@ pub enum EvictionStrategy {
|
|||||||
pub struct NeuronEndpoint {
|
pub struct NeuronEndpoint {
|
||||||
/// Human-readable node name (e.g. "beast")
|
/// Human-readable node name (e.g. "beast")
|
||||||
pub name: String,
|
pub name: String,
|
||||||
/// Base URL of the neuron daemon (e.g. "http://beast.internal:9090")
|
/// Base URL of the neuron daemon (e.g. "http://beast.internal:13131")
|
||||||
pub endpoint: String,
|
pub endpoint: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -70,8 +70,8 @@ impl Default for GatewayConfig {
|
|||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
gateway: GatewaySettings {
|
gateway: GatewaySettings {
|
||||||
listen: "0.0.0.0:8000".into(),
|
listen: "0.0.0.0:31313".into(),
|
||||||
metrics_listen: "0.0.0.0:9100".into(),
|
metrics_listen: "0.0.0.0:31314".into(),
|
||||||
},
|
},
|
||||||
eviction: EvictionSettings {
|
eviction: EvictionSettings {
|
||||||
strategy: EvictionStrategy::Lru,
|
strategy: EvictionStrategy::Lru,
|
||||||
|
|||||||
@@ -9,13 +9,13 @@ use async_trait::async_trait;
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// Configuration for a harness instance on a neuron.
|
/// Configuration for a harness instance on a neuron.
|
||||||
|
///
|
||||||
|
/// All current harnesses are in-process (candle); per-harness tuning
|
||||||
|
/// (cache paths, device policies, etc.) lives in dedicated config
|
||||||
|
/// blocks rather than on this struct.
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct HarnessConfig {
|
pub struct HarnessConfig {
|
||||||
pub name: String,
|
pub name: String,
|
||||||
/// Base URL of the harness (e.g. "http://localhost:8080" for mistral.rs).
|
|
||||||
pub endpoint: Option<String>,
|
|
||||||
/// Systemd unit name, if the harness is managed via systemd.
|
|
||||||
pub systemd_unit: Option<String>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Health status of a harness process.
|
/// Health status of a harness process.
|
||||||
@@ -47,16 +47,24 @@ pub struct ModelInfo {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// What an inference harness must do, from neuron's perspective.
|
/// What an inference harness must do, from neuron's perspective.
|
||||||
|
///
|
||||||
|
/// All current harnesses are in-process — they share neuron's address
|
||||||
|
/// space and lifecycle. `start`/`stop` therefore default to no-ops; a
|
||||||
|
/// future process-supervising harness would override them.
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait Harness: Send + Sync {
|
pub trait Harness: Send + Sync {
|
||||||
/// Human-readable name (e.g. "mistralrs", "llamacpp", "comfyui").
|
/// Human-readable name (e.g. "candle").
|
||||||
fn name(&self) -> &str;
|
fn name(&self) -> &str;
|
||||||
|
|
||||||
/// Start the harness process if it is not already running.
|
/// Start the harness. Default no-op for in-process harnesses.
|
||||||
async fn start(&self, config: &HarnessConfig) -> Result<()>;
|
async fn start(&self, _config: &HarnessConfig) -> Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
/// Stop the harness process gracefully.
|
/// Stop the harness. Default no-op for in-process harnesses.
|
||||||
async fn stop(&self) -> Result<()>;
|
async fn stop(&self) -> Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
/// Health check. Returns the harness process status.
|
/// Health check. Returns the harness process status.
|
||||||
async fn health(&self) -> HarnessHealth;
|
async fn health(&self) -> HarnessHealth;
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
use crate::discovery::DiscoveryResponse;
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
@@ -6,13 +7,19 @@ use std::collections::HashMap;
|
|||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct NodeState {
|
pub struct NodeState {
|
||||||
pub name: String,
|
pub name: String,
|
||||||
/// Base URL of the neuron daemon (e.g. "http://beast.internal:9090").
|
/// Base URL of the neuron daemon (e.g. "http://beast.internal:13131").
|
||||||
pub endpoint: String,
|
pub endpoint: String,
|
||||||
pub healthy: bool,
|
pub healthy: bool,
|
||||||
pub models: HashMap<String, ModelEntry>,
|
pub models: HashMap<String, ModelEntry>,
|
||||||
/// Number of load/unload cycles since last process restart.
|
/// Number of load/unload cycles since last process restart.
|
||||||
pub lifecycle_cycles: u32,
|
pub lifecycle_cycles: u32,
|
||||||
pub last_poll: Option<DateTime<Utc>>,
|
pub last_poll: Option<DateTime<Utc>>,
|
||||||
|
/// Result of the most recent successful `GET /discovery` against
|
||||||
|
/// this neuron. Cached forever once obtained — device topology is
|
||||||
|
/// invariant for a given neuron process. `None` until the first
|
||||||
|
/// successful poll. Used by the router and `/v1/models` to do
|
||||||
|
/// catalogue × topology feasibility checks.
|
||||||
|
pub discovery: Option<DiscoveryResponse>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A model registered on a node, with its runtime status.
|
/// A model registered on a node, with its runtime status.
|
||||||
@@ -36,12 +43,32 @@ pub enum ModelStatus {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Unified model entry as exposed by the gateway's `/v1/models` endpoint.
|
/// Unified model entry as exposed by the gateway's `/v1/models` endpoint.
|
||||||
/// Includes which node(s) host this model and their status.
|
///
|
||||||
|
/// The first four fields (`id`, `object`, `created`, `owned_by`) match
|
||||||
|
/// OpenAI's `/v1/models` shape verbatim, so existing OpenAI-aware
|
||||||
|
/// tooling deserialises this without custom code. The remaining fields
|
||||||
|
/// are helexa-specific extensions — OpenAI clients ignore unknown
|
||||||
|
/// fields and other consumers can read them for placement / debugging.
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct CortexModelEntry {
|
pub struct CortexModelEntry {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
|
/// Always `"model"` per OpenAI's contract.
|
||||||
pub object: String,
|
pub object: String,
|
||||||
/// Which nodes have this model (and their status).
|
/// Unix-second timestamp; cortex stamps this at response time.
|
||||||
|
pub created: u64,
|
||||||
|
/// OpenAI's "publisher" field — `"helexa"` for everything we serve.
|
||||||
|
pub owned_by: String,
|
||||||
|
/// True if any neuron currently has this model loaded. False for
|
||||||
|
/// catalogue entries that are feasible but not yet loaded.
|
||||||
|
pub loaded: bool,
|
||||||
|
/// Neurons whose discovered topology can satisfy this model's
|
||||||
|
/// catalogue placement constraints. Empty for models that are
|
||||||
|
/// loaded somewhere but not present in the catalogue (cortex has
|
||||||
|
/// no feasibility opinion on those).
|
||||||
|
pub feasible_on: Vec<String>,
|
||||||
|
/// Where this model is actually loaded right now. Subset of (or
|
||||||
|
/// disjoint from) `feasible_on` depending on whether the catalogue
|
||||||
|
/// covers this model.
|
||||||
pub locations: Vec<ModelLocation>,
|
pub locations: Vec<ModelLocation>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
//! These are a subset sufficient for chat completions (streaming + non-streaming).
|
//! These are a subset sufficient for chat completions (streaming + non-streaming).
|
||||||
//! Fields not relevant to proxying are captured as `serde_json::Value` via
|
//! Fields not relevant to proxying are captured as `serde_json::Value` via
|
||||||
//! `#[serde(flatten)]` so we forward them without needing to enumerate every
|
//! `#[serde(flatten)]` so we forward them without needing to enumerate every
|
||||||
//! extension field mistral.rs supports.
|
//! extension field a backend might support.
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
@@ -22,7 +22,7 @@ pub struct ChatCompletionRequest {
|
|||||||
pub max_tokens: Option<u64>,
|
pub max_tokens: Option<u64>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub stream: Option<bool>,
|
pub stream: Option<bool>,
|
||||||
/// All other fields (tools, response_format, mistral.rs extensions, etc.)
|
/// All other fields (tools, response_format, backend extensions, etc.)
|
||||||
#[serde(flatten)]
|
#[serde(flatten)]
|
||||||
pub extra: Value,
|
pub extra: Value,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ tokio-stream.workspace = true
|
|||||||
eventsource-stream.workspace = true
|
eventsource-stream.workspace = true
|
||||||
bytes = "1"
|
bytes = "1"
|
||||||
urlencoding = "2"
|
urlencoding = "2"
|
||||||
|
url = "2"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tokio = { workspace = true, features = ["test-util"] }
|
tokio = { workspace = true, features = ["test-util"] }
|
||||||
|
|||||||
@@ -34,12 +34,30 @@ async fn chat_completions(
|
|||||||
) -> Response {
|
) -> Response {
|
||||||
let model_id = match extract_model(&body) {
|
let model_id = match extract_model(&body) {
|
||||||
Some(m) => m,
|
Some(m) => m,
|
||||||
None => return error_response(400, "missing 'model' field in request body"),
|
None => {
|
||||||
|
tracing::warn!(
|
||||||
|
handler = "chat_completions",
|
||||||
|
"rejected: missing 'model' field in request body"
|
||||||
|
);
|
||||||
|
return error_response(400, "missing 'model' field in request body");
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let route = match router::resolve(&fleet, &model_id).await {
|
let route = match router::resolve(&fleet, &model_id).await {
|
||||||
Ok(r) => r,
|
Ok(r) => r,
|
||||||
Err(e) => return error_response(404, &e.to_string()),
|
Err(e) => {
|
||||||
|
tracing::warn!(
|
||||||
|
handler = "chat_completions",
|
||||||
|
model = %model_id,
|
||||||
|
error = %e,
|
||||||
|
"route resolve failed"
|
||||||
|
);
|
||||||
|
// RouteError's Display strings are short and informative
|
||||||
|
// ("model 'X' not found...", "no healthy nodes available")
|
||||||
|
// — fine to surface to the caller. The warn above carries
|
||||||
|
// any extra context for operators.
|
||||||
|
return error_response(404, &e.to_string());
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
touch_model(&fleet, &route.node_name, &model_id).await;
|
touch_model(&fleet, &route.node_name, &model_id).await;
|
||||||
@@ -63,12 +81,30 @@ async fn completions(
|
|||||||
) -> Response {
|
) -> Response {
|
||||||
let model_id = match extract_model(&body) {
|
let model_id = match extract_model(&body) {
|
||||||
Some(m) => m,
|
Some(m) => m,
|
||||||
None => return error_response(400, "missing 'model' field in request body"),
|
None => {
|
||||||
|
tracing::warn!(
|
||||||
|
handler = "completions",
|
||||||
|
"rejected: missing 'model' field in request body"
|
||||||
|
);
|
||||||
|
return error_response(400, "missing 'model' field in request body");
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let route = match router::resolve(&fleet, &model_id).await {
|
let route = match router::resolve(&fleet, &model_id).await {
|
||||||
Ok(r) => r,
|
Ok(r) => r,
|
||||||
Err(e) => return error_response(404, &e.to_string()),
|
Err(e) => {
|
||||||
|
tracing::warn!(
|
||||||
|
handler = "completions",
|
||||||
|
model = %model_id,
|
||||||
|
error = %e,
|
||||||
|
"route resolve failed"
|
||||||
|
);
|
||||||
|
// RouteError's Display strings are short and informative
|
||||||
|
// ("model 'X' not found...", "no healthy nodes available")
|
||||||
|
// — fine to surface to the caller. The warn above carries
|
||||||
|
// any extra context for operators.
|
||||||
|
return error_response(404, &e.to_string());
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
touch_model(&fleet, &route.node_name, &model_id).await;
|
touch_model(&fleet, &route.node_name, &model_id).await;
|
||||||
@@ -85,7 +121,14 @@ async fn anthropic_messages(
|
|||||||
// Parse as Anthropic request.
|
// Parse as Anthropic request.
|
||||||
let anth_req: cortex_core::anthropic::MessagesRequest = match serde_json::from_slice(&body) {
|
let anth_req: cortex_core::anthropic::MessagesRequest = match serde_json::from_slice(&body) {
|
||||||
Ok(r) => r,
|
Ok(r) => r,
|
||||||
Err(e) => return error_response(400, &format!("invalid Anthropic request: {e}")),
|
Err(e) => {
|
||||||
|
tracing::warn!(
|
||||||
|
handler = "anthropic_messages",
|
||||||
|
error = %e,
|
||||||
|
"rejected: invalid Anthropic request body"
|
||||||
|
);
|
||||||
|
return error_response(400, "invalid Anthropic request body");
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let model_id = anth_req.model.clone();
|
let model_id = anth_req.model.clone();
|
||||||
@@ -95,12 +138,32 @@ async fn anthropic_messages(
|
|||||||
let openai_req = cortex_core::translate::anthropic_to_openai(anth_req);
|
let openai_req = cortex_core::translate::anthropic_to_openai(anth_req);
|
||||||
let openai_body = match serde_json::to_vec(&openai_req) {
|
let openai_body = match serde_json::to_vec(&openai_req) {
|
||||||
Ok(b) => Bytes::from(b),
|
Ok(b) => Bytes::from(b),
|
||||||
Err(e) => return error_response(500, &format!("translation error: {e}")),
|
Err(e) => {
|
||||||
|
tracing::error!(
|
||||||
|
handler = "anthropic_messages",
|
||||||
|
model = %model_id,
|
||||||
|
error = %e,
|
||||||
|
"internal: failed to serialise translated OpenAI request"
|
||||||
|
);
|
||||||
|
return error_response(500, "internal translation error");
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let route = match router::resolve(&fleet, &model_id).await {
|
let route = match router::resolve(&fleet, &model_id).await {
|
||||||
Ok(r) => r,
|
Ok(r) => r,
|
||||||
Err(e) => return error_response(404, &e.to_string()),
|
Err(e) => {
|
||||||
|
tracing::warn!(
|
||||||
|
handler = "anthropic_messages",
|
||||||
|
model = %model_id,
|
||||||
|
error = %e,
|
||||||
|
"route resolve failed"
|
||||||
|
);
|
||||||
|
// RouteError's Display strings are short and informative
|
||||||
|
// ("model 'X' not found...", "no healthy nodes available")
|
||||||
|
// — fine to surface to the caller. The warn above carries
|
||||||
|
// any extra context for operators.
|
||||||
|
return error_response(404, &e.to_string());
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
touch_model(&fleet, &route.node_name, &model_id).await;
|
touch_model(&fleet, &route.node_name, &model_id).await;
|
||||||
@@ -133,14 +196,25 @@ async fn anthropic_messages(
|
|||||||
Ok(resp) => resp,
|
Ok(resp) => resp,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
||||||
|
// forward_request already warn'd with the wire-level
|
||||||
|
// detail; no need to log again here.
|
||||||
e.into_response()
|
e.into_response()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Non-streaming: proxy, buffer full response, translate back to Anthropic.
|
// Non-streaming: proxy, buffer full response, translate back to Anthropic.
|
||||||
|
let target_url = format!("{}/v1/chat/completions", route.endpoint);
|
||||||
|
tracing::info!(
|
||||||
|
handler = "anthropic_messages",
|
||||||
|
model = %model_id,
|
||||||
|
node = %route.node_name,
|
||||||
|
url = %target_url,
|
||||||
|
cold_start = route.cold_start,
|
||||||
|
"proxying request"
|
||||||
|
);
|
||||||
let upstream_resp = fleet
|
let upstream_resp = fleet
|
||||||
.http_client
|
.http_client
|
||||||
.post(format!("{}/v1/chat/completions", route.endpoint))
|
.post(&target_url)
|
||||||
.body(openai_body)
|
.body(openai_body)
|
||||||
.header("content-type", "application/json")
|
.header("content-type", "application/json")
|
||||||
.send()
|
.send()
|
||||||
@@ -150,22 +224,49 @@ async fn anthropic_messages(
|
|||||||
Ok(r) => r,
|
Ok(r) => r,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
||||||
return error_response(502, &format!("upstream request failed: {e}"));
|
tracing::warn!(
|
||||||
|
handler = "anthropic_messages",
|
||||||
|
model = %model_id,
|
||||||
|
node = %route.node_name,
|
||||||
|
url = %target_url,
|
||||||
|
error = %e,
|
||||||
|
"upstream request failed (network)"
|
||||||
|
);
|
||||||
|
return error_response(502, "upstream request failed");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if !upstream_resp.status().is_success() {
|
let upstream_status = upstream_resp.status();
|
||||||
|
if !upstream_status.is_success() {
|
||||||
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
||||||
let status = upstream_resp.status().as_u16();
|
let status = upstream_status.as_u16();
|
||||||
let body = upstream_resp.text().await.unwrap_or_default();
|
let body = upstream_resp.text().await.unwrap_or_default();
|
||||||
return error_response(status, &format!("upstream error: {body}"));
|
let body_snippet = body.chars().take(512).collect::<String>();
|
||||||
|
tracing::warn!(
|
||||||
|
handler = "anthropic_messages",
|
||||||
|
model = %model_id,
|
||||||
|
node = %route.node_name,
|
||||||
|
url = %target_url,
|
||||||
|
status,
|
||||||
|
body = %body_snippet,
|
||||||
|
"upstream returned non-2xx"
|
||||||
|
);
|
||||||
|
return error_response(status, &format!("upstream returned {status}"));
|
||||||
}
|
}
|
||||||
|
|
||||||
let body_bytes = match upstream_resp.bytes().await {
|
let body_bytes = match upstream_resp.bytes().await {
|
||||||
Ok(b) => b,
|
Ok(b) => b,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
||||||
return error_response(502, &format!("failed to read upstream response: {e}"));
|
tracing::warn!(
|
||||||
|
handler = "anthropic_messages",
|
||||||
|
model = %model_id,
|
||||||
|
node = %route.node_name,
|
||||||
|
url = %target_url,
|
||||||
|
error = %e,
|
||||||
|
"failed to read upstream response body"
|
||||||
|
);
|
||||||
|
return error_response(502, "failed to read upstream response");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -174,7 +275,20 @@ async fn anthropic_messages(
|
|||||||
Ok(r) => r,
|
Ok(r) => r,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
||||||
return error_response(502, &format!("failed to parse upstream response: {e}"));
|
let body_snippet = String::from_utf8_lossy(&body_bytes)
|
||||||
|
.chars()
|
||||||
|
.take(512)
|
||||||
|
.collect::<String>();
|
||||||
|
tracing::warn!(
|
||||||
|
handler = "anthropic_messages",
|
||||||
|
model = %model_id,
|
||||||
|
node = %route.node_name,
|
||||||
|
url = %target_url,
|
||||||
|
error = %e,
|
||||||
|
body = %body_snippet,
|
||||||
|
"failed to parse upstream response as OpenAI ChatCompletionResponse"
|
||||||
|
);
|
||||||
|
return error_response(502, "malformed upstream response");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -185,12 +299,62 @@ async fn anthropic_messages(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// `GET /v1/models` — aggregate models from all nodes.
|
/// `GET /v1/models` — union of (catalogue × topology feasibility) and
|
||||||
|
/// (currently loaded somewhere). The result is what the fleet *could*
|
||||||
|
/// serve, not just what's already loaded — so OpenAI-compatible tools
|
||||||
|
/// see every model the operator has provisioned, and cortex
|
||||||
|
/// transparently cold-loads the first time one is requested.
|
||||||
async fn list_models(State(fleet): State<Arc<CortexState>>) -> Json<Value> {
|
async fn list_models(State(fleet): State<Arc<CortexState>>) -> Json<Value> {
|
||||||
|
use std::collections::HashMap;
|
||||||
|
let now = Utc::now().timestamp() as u64;
|
||||||
let nodes = fleet.nodes.read().await;
|
let nodes = fleet.nodes.read().await;
|
||||||
let mut model_map: std::collections::HashMap<String, CortexModelEntry> =
|
let catalogue = &fleet.catalogue;
|
||||||
std::collections::HashMap::new();
|
|
||||||
|
|
||||||
|
let mut entries: HashMap<String, CortexModelEntry> = HashMap::new();
|
||||||
|
|
||||||
|
// Pass 1: catalogue × topology. For every catalogue profile, find
|
||||||
|
// healthy neurons whose discovered devices satisfy the profile.
|
||||||
|
// Catalogue-defined models surface here even if nothing has loaded
|
||||||
|
// them yet — that's the point of the unified endpoint.
|
||||||
|
for profile in &catalogue.models {
|
||||||
|
let mut feasible_on = Vec::new();
|
||||||
|
for node in nodes.values() {
|
||||||
|
if !node.healthy {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let Some(disc) = node.discovery.as_ref() else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
if profile.is_feasible_on(&node.name, &disc.devices) {
|
||||||
|
feasible_on.push(node.name.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if feasible_on.is_empty() {
|
||||||
|
// The catalogue lists this model but no neuron's topology
|
||||||
|
// matches — surface it as not-loaded with no feasible
|
||||||
|
// location. Hides nothing; lets operators see why a
|
||||||
|
// configured model isn't reachable.
|
||||||
|
feasible_on.clear();
|
||||||
|
}
|
||||||
|
entries.insert(
|
||||||
|
profile.id.clone(),
|
||||||
|
CortexModelEntry {
|
||||||
|
id: profile.id.clone(),
|
||||||
|
object: "model".into(),
|
||||||
|
created: now,
|
||||||
|
owned_by: "helexa".into(),
|
||||||
|
loaded: false,
|
||||||
|
feasible_on,
|
||||||
|
locations: Vec::new(),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pass 2: layer the actually-loaded state on top. For each
|
||||||
|
// (node, model) entry, attach a ModelLocation. If the model isn't
|
||||||
|
// in the catalogue, create a new CortexModelEntry from scratch —
|
||||||
|
// cortex doesn't refuse to surface a manually-loaded model just
|
||||||
|
// because the operator didn't enumerate it in models.toml.
|
||||||
for node in nodes.values() {
|
for node in nodes.values() {
|
||||||
for (model_id, entry) in &node.models {
|
for (model_id, entry) in &node.models {
|
||||||
let location = ModelLocation {
|
let location = ModelLocation {
|
||||||
@@ -198,19 +362,30 @@ async fn list_models(State(fleet): State<Arc<CortexState>>) -> Json<Value> {
|
|||||||
status: entry.status,
|
status: entry.status,
|
||||||
vram_estimate_mb: entry.vram_estimate_mb,
|
vram_estimate_mb: entry.vram_estimate_mb,
|
||||||
};
|
};
|
||||||
model_map
|
let was_loaded = matches!(entry.status, cortex_core::node::ModelStatus::Loaded);
|
||||||
|
entries
|
||||||
.entry(model_id.clone())
|
.entry(model_id.clone())
|
||||||
.and_modify(|e| e.locations.push(location.clone()))
|
.and_modify(|e| {
|
||||||
|
e.locations.push(location.clone());
|
||||||
|
if was_loaded {
|
||||||
|
e.loaded = true;
|
||||||
|
}
|
||||||
|
})
|
||||||
.or_insert_with(|| CortexModelEntry {
|
.or_insert_with(|| CortexModelEntry {
|
||||||
id: model_id.clone(),
|
id: model_id.clone(),
|
||||||
object: "model".into(),
|
object: "model".into(),
|
||||||
|
created: now,
|
||||||
|
owned_by: "helexa".into(),
|
||||||
|
loaded: was_loaded,
|
||||||
|
// Not in catalogue — cortex has no opinion on
|
||||||
|
// feasibility; leave empty.
|
||||||
|
feasible_on: Vec::new(),
|
||||||
locations: vec![location],
|
locations: vec![location],
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let data: Vec<Value> = model_map.values().map(|e| json!(e)).collect();
|
let data: Vec<Value> = entries.values().map(|e| json!(e)).collect();
|
||||||
|
|
||||||
Json(json!({
|
Json(json!({
|
||||||
"object": "list",
|
"object": "list",
|
||||||
"data": data,
|
"data": data,
|
||||||
@@ -265,6 +440,9 @@ async fn proxy_with_metrics(
|
|||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
||||||
|
// proxy::forward_request already warn'd with wire-level
|
||||||
|
// detail (target URL, error, status). ProxyError::into_response
|
||||||
|
// now returns a generic message — no body leak.
|
||||||
e.into_response()
|
e.into_response()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
use crate::state::CortexState;
|
use crate::state::CortexState;
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
|
use cortex_core::discovery::DiscoveryResponse;
|
||||||
use cortex_core::harness::ModelInfo;
|
use cortex_core::harness::ModelInfo;
|
||||||
use cortex_core::node::{ModelEntry, ModelStatus};
|
use cortex_core::node::{ModelEntry, ModelStatus};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
@@ -25,7 +26,59 @@ pub async fn poll_once(fleet: &CortexState) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// One-shot fetch of `GET /discovery`. Cached on the NodeState forever
|
||||||
|
/// after the first success — topology is invariant for a given neuron
|
||||||
|
/// process. Skipped when the cache is already populated.
|
||||||
|
async fn maybe_poll_discovery(fleet: &CortexState, name: &str, endpoint: &str) {
|
||||||
|
{
|
||||||
|
let nodes = fleet.nodes.read().await;
|
||||||
|
match nodes.get(name) {
|
||||||
|
Some(n) if n.discovery.is_some() => return,
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let url = format!("{endpoint}/discovery");
|
||||||
|
let resp = match fleet
|
||||||
|
.http_client
|
||||||
|
.get(&url)
|
||||||
|
.timeout(Duration::from_secs(5))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(r) if r.status().is_success() => r,
|
||||||
|
Ok(r) => {
|
||||||
|
tracing::debug!(node = name, status = %r.status(), "discovery probe non-success");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::debug!(node = name, error = %e, "discovery probe unreachable");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
match resp.json::<DiscoveryResponse>().await {
|
||||||
|
Ok(d) => {
|
||||||
|
let mut nodes = fleet.nodes.write().await;
|
||||||
|
if let Some(node) = nodes.get_mut(name) {
|
||||||
|
tracing::info!(
|
||||||
|
node = name,
|
||||||
|
hostname = %d.hostname,
|
||||||
|
devices = d.devices.len(),
|
||||||
|
"discovery cached"
|
||||||
|
);
|
||||||
|
node.discovery = Some(d);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(node = name, error = %e, "failed to parse /discovery response");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn poll_neuron(fleet: &CortexState, name: &str, endpoint: &str) {
|
async fn poll_neuron(fleet: &CortexState, name: &str, endpoint: &str) {
|
||||||
|
// Topology first — cheap once cached, and the router needs it to
|
||||||
|
// route requests against catalogue entries that aren't loaded yet.
|
||||||
|
maybe_poll_discovery(fleet, name, endpoint).await;
|
||||||
|
|
||||||
let url = format!("{endpoint}/models");
|
let url = format!("{endpoint}/models");
|
||||||
|
|
||||||
let result = fleet
|
let result = fleet
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//! Streaming HTTP reverse proxy to mistral.rs backends.
|
//! Streaming HTTP reverse proxy to neuron backends.
|
||||||
//!
|
//!
|
||||||
//! For streaming requests, SSE chunks are forwarded as they arrive.
|
//! For streaming requests, SSE chunks are forwarded as they arrive.
|
||||||
//! The proxy captures timing information for metrics but does not
|
//! The proxy captures timing information for metrics but does not
|
||||||
@@ -12,6 +12,13 @@ use axum::response::{IntoResponse, Response};
|
|||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
|
|
||||||
/// Proxy a request body to the resolved backend node and stream the response.
|
/// Proxy a request body to the resolved backend node and stream the response.
|
||||||
|
///
|
||||||
|
/// Logging contract: every call emits exactly one structured event at
|
||||||
|
/// info / warn level for operator visibility, regardless of outcome.
|
||||||
|
/// Network-level failures and non-2xx upstream statuses are warn'd here
|
||||||
|
/// (closest to the wire); the user-facing response carries only the
|
||||||
|
/// status code and a generic message — implementation detail (body,
|
||||||
|
/// error chain) lives in the log, never in the API surface.
|
||||||
pub async fn forward_request(
|
pub async fn forward_request(
|
||||||
client: &Client,
|
client: &Client,
|
||||||
route: &RouteDecision,
|
route: &RouteDecision,
|
||||||
@@ -37,10 +44,33 @@ pub async fn forward_request(
|
|||||||
req_builder = req_builder.header(key, value);
|
req_builder = req_builder.header(key, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
let upstream_resp = req_builder.send().await.map_err(ProxyError::Upstream)?;
|
let upstream_resp = match req_builder.send().await {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(
|
||||||
|
node = %route.node_name,
|
||||||
|
url = %url,
|
||||||
|
error = %e,
|
||||||
|
"proxy: upstream request failed (network)"
|
||||||
|
);
|
||||||
|
return Err(ProxyError::Upstream(e));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let status =
|
let upstream_status = upstream_resp.status();
|
||||||
StatusCode::from_u16(upstream_resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
|
if !upstream_status.is_success() {
|
||||||
|
// Streaming body — can't snippet without breaking the stream
|
||||||
|
// pass-through. Log status + URL; the client still gets the
|
||||||
|
// upstream status, just without the leaked body.
|
||||||
|
tracing::warn!(
|
||||||
|
node = %route.node_name,
|
||||||
|
url = %url,
|
||||||
|
status = upstream_status.as_u16(),
|
||||||
|
"proxy: upstream returned non-2xx"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let status = StatusCode::from_u16(upstream_status.as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
|
||||||
|
|
||||||
let resp_headers = upstream_resp.headers().clone();
|
let resp_headers = upstream_resp.headers().clone();
|
||||||
let stream = upstream_resp.bytes_stream();
|
let stream = upstream_resp.bytes_stream();
|
||||||
@@ -52,28 +82,37 @@ pub async fn forward_request(
|
|||||||
response = response.header(key, value);
|
response = response.header(key, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
response
|
response.body(body).map_err(|e| {
|
||||||
.body(body)
|
tracing::warn!(
|
||||||
.map_err(|e| ProxyError::ResponseBuild(e.to_string()))
|
node = %route.node_name,
|
||||||
|
url = %url,
|
||||||
|
error = %e,
|
||||||
|
"proxy: failed to build response"
|
||||||
|
);
|
||||||
|
ProxyError::ResponseBuild(e.to_string())
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, thiserror::Error)]
|
#[derive(Debug, thiserror::Error)]
|
||||||
pub enum ProxyError {
|
pub enum ProxyError {
|
||||||
#[error("upstream request failed: {0}")]
|
#[error("upstream request failed")]
|
||||||
Upstream(reqwest::Error),
|
Upstream(reqwest::Error),
|
||||||
#[error("failed to build response: {0}")]
|
#[error("failed to build response")]
|
||||||
ResponseBuild(String),
|
ResponseBuild(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl IntoResponse for ProxyError {
|
impl IntoResponse for ProxyError {
|
||||||
fn into_response(self) -> Response {
|
fn into_response(self) -> Response {
|
||||||
let status = match &self {
|
let (status, message) = match &self {
|
||||||
ProxyError::Upstream(_) => StatusCode::BAD_GATEWAY,
|
ProxyError::Upstream(_) => (StatusCode::BAD_GATEWAY, "upstream request failed"),
|
||||||
ProxyError::ResponseBuild(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
ProxyError::ResponseBuild(_) => (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
"failed to build response",
|
||||||
|
),
|
||||||
};
|
};
|
||||||
let body = serde_json::json!({
|
let body = serde_json::json!({
|
||||||
"error": {
|
"error": {
|
||||||
"message": self.to_string(),
|
"message": message,
|
||||||
"type": "proxy_error",
|
"type": "proxy_error",
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -2,13 +2,21 @@
|
|||||||
//!
|
//!
|
||||||
//! Given a model ID from an inbound request, determine which node should
|
//! Given a model ID from an inbound request, determine which node should
|
||||||
//! handle it. Priority:
|
//! handle it. Priority:
|
||||||
//! 1. Node where the model is currently `Loaded`
|
//! 1. Node where the model is currently `Loaded` → use it.
|
||||||
//! 2. Node where the model is `Unloaded` (will lazy-load on request)
|
//! 2. Node where the model is `Unloaded` → use it; neuron's existing
|
||||||
//! 3. Error: model not found on any node
|
//! lazy-load behaviour will reload before serving the request.
|
||||||
|
//! 3. Model is in the catalogue → pick a feasible neuron, call
|
||||||
|
//! `POST /models/load`, wait for the load to complete, then
|
||||||
|
//! proxy. First-request cold-load latency is acceptable per the
|
||||||
|
//! unified-endpoint contract.
|
||||||
|
//! 4. Not in catalogue, not loaded anywhere → 404.
|
||||||
|
|
||||||
use crate::state::CortexState;
|
use crate::state::CortexState;
|
||||||
|
use cortex_core::catalogue::ModelProfile;
|
||||||
|
use cortex_core::harness::ModelSpec;
|
||||||
use cortex_core::node::ModelStatus;
|
use cortex_core::node::ModelStatus;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
/// The routing decision: which node endpoint to proxy the request to.
|
/// The routing decision: which node endpoint to proxy the request to.
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@@ -16,18 +24,31 @@ pub struct RouteDecision {
|
|||||||
pub node_name: String,
|
pub node_name: String,
|
||||||
/// The inference endpoint to proxy to (from neuron's /models/{id}/endpoint).
|
/// The inference endpoint to proxy to (from neuron's /models/{id}/endpoint).
|
||||||
pub endpoint: String,
|
pub endpoint: String,
|
||||||
/// Whether the model will need to load (cold start).
|
/// Whether the model will need to load (cold start). Set to true
|
||||||
|
/// when we proxied to an `Unloaded` node (lazy load on neuron) or
|
||||||
|
/// when we just triggered an explicit cold-load via the catalogue
|
||||||
|
/// path.
|
||||||
pub cold_start: bool,
|
pub cold_start: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, thiserror::Error)]
|
#[derive(Debug, thiserror::Error)]
|
||||||
pub enum RouteError {
|
pub enum RouteError {
|
||||||
#[error("model '{0}' not found on any node")]
|
#[error("model '{0}' not found on any node and not in catalogue")]
|
||||||
ModelNotFound(String),
|
ModelNotFound(String),
|
||||||
#[error("no healthy nodes available")]
|
#[error("no healthy nodes available")]
|
||||||
NoHealthyNodes,
|
NoHealthyNodes,
|
||||||
#[error("failed to resolve inference endpoint for model '{0}' on node '{1}'")]
|
#[error("failed to resolve inference endpoint for model '{0}' on node '{1}'")]
|
||||||
EndpointResolveFailed(String, String),
|
EndpointResolveFailed(String, String),
|
||||||
|
#[error(
|
||||||
|
"model '{model_id}' is in the catalogue but no healthy neuron's topology satisfies its constraints"
|
||||||
|
)]
|
||||||
|
NoFeasibleNeuron { model_id: String },
|
||||||
|
#[error("cold-load of '{model_id}' on '{node}' failed: {message}")]
|
||||||
|
ColdLoadFailed {
|
||||||
|
model_id: String,
|
||||||
|
node: String,
|
||||||
|
message: String,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Resolve which node should serve a request for the given model.
|
/// Resolve which node should serve a request for the given model.
|
||||||
@@ -36,42 +57,231 @@ pub async fn resolve(
|
|||||||
fleet: &Arc<CortexState>,
|
fleet: &Arc<CortexState>,
|
||||||
model_id: &str,
|
model_id: &str,
|
||||||
) -> Result<RouteDecision, RouteError> {
|
) -> Result<RouteDecision, RouteError> {
|
||||||
let (node_name, neuron_endpoint, cold_start) = {
|
// Snapshot loaded / unloaded state from the poller cache.
|
||||||
|
let (loaded_route, unloaded_route, any_healthy) = {
|
||||||
let nodes = fleet.nodes.read().await;
|
let nodes = fleet.nodes.read().await;
|
||||||
|
let mut loaded_route = None;
|
||||||
let mut loaded_candidate = None;
|
let mut unloaded_route = None;
|
||||||
let mut unloaded_candidate = None;
|
let mut any_healthy = false;
|
||||||
|
|
||||||
for node in nodes.values() {
|
for node in nodes.values() {
|
||||||
if !node.healthy {
|
if !node.healthy {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
any_healthy = true;
|
||||||
if let Some(entry) = node.models.get(model_id) {
|
if let Some(entry) = node.models.get(model_id) {
|
||||||
match entry.status {
|
match entry.status {
|
||||||
ModelStatus::Loaded | ModelStatus::Reloading => {
|
ModelStatus::Loaded | ModelStatus::Reloading => {
|
||||||
loaded_candidate = Some((node.name.clone(), node.endpoint.clone(), false));
|
loaded_route = Some((node.name.clone(), node.endpoint.clone(), false));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
ModelStatus::Unloaded => {
|
ModelStatus::Unloaded => {
|
||||||
if unloaded_candidate.is_none() {
|
if unloaded_route.is_none() {
|
||||||
unloaded_candidate =
|
unloaded_route = Some((node.name.clone(), node.endpoint.clone(), true));
|
||||||
Some((node.name.clone(), node.endpoint.clone(), true));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
(loaded_route, unloaded_route, any_healthy)
|
||||||
loaded_candidate.or(unloaded_candidate).ok_or_else(|| {
|
|
||||||
if nodes.values().any(|n| n.healthy) {
|
|
||||||
RouteError::ModelNotFound(model_id.to_string())
|
|
||||||
} else {
|
|
||||||
RouteError::NoHealthyNodes
|
|
||||||
}
|
|
||||||
})?
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Ask the neuron for the inference endpoint for this model.
|
if !any_healthy {
|
||||||
|
return Err(RouteError::NoHealthyNodes);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Priority 1: already loaded.
|
||||||
|
if let Some((node_name, neuron_endpoint, cold_start)) = loaded_route {
|
||||||
|
return finish(fleet, &node_name, &neuron_endpoint, model_id, cold_start).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Priority 2: known to neuron but unloaded (neuron's lazy load).
|
||||||
|
if let Some((node_name, neuron_endpoint, cold_start)) = unloaded_route {
|
||||||
|
return finish(fleet, &node_name, &neuron_endpoint, model_id, cold_start).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Priority 3: catalogue × topology cold-load.
|
||||||
|
if let Some(profile) = fleet.catalogue.get(model_id) {
|
||||||
|
let (node_name, neuron_endpoint) = pick_feasible_neuron(fleet, profile).await?;
|
||||||
|
cold_load(fleet, &node_name, &neuron_endpoint, profile).await?;
|
||||||
|
return finish(fleet, &node_name, &neuron_endpoint, model_id, true).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(RouteError::ModelNotFound(model_id.to_string()))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Pick a healthy neuron whose discovered topology satisfies the
|
||||||
|
/// profile. Preference order:
|
||||||
|
/// 1. A neuron from `profile.pinned_on` that is healthy + feasible.
|
||||||
|
/// 2. Otherwise, any healthy + feasible neuron, stable by name.
|
||||||
|
async fn pick_feasible_neuron(
|
||||||
|
fleet: &Arc<CortexState>,
|
||||||
|
profile: &ModelProfile,
|
||||||
|
) -> Result<(String, String), RouteError> {
|
||||||
|
let nodes = fleet.nodes.read().await;
|
||||||
|
let mut candidates: Vec<(String, String, bool)> = Vec::new();
|
||||||
|
for node in nodes.values() {
|
||||||
|
if !node.healthy {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let Some(disc) = node.discovery.as_ref() else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
if !profile.is_feasible_on(&node.name, &disc.devices) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let pinned = profile.pinned_on.iter().any(|n| n == &node.name);
|
||||||
|
candidates.push((node.name.clone(), node.endpoint.clone(), pinned));
|
||||||
|
}
|
||||||
|
candidates.sort_by(|a, b| {
|
||||||
|
b.2.cmp(&a.2) // pinned first (true > false)
|
||||||
|
.then(a.0.cmp(&b.0))
|
||||||
|
});
|
||||||
|
let pick = candidates.into_iter().next();
|
||||||
|
pick.map(|(n, e, _)| (n, e))
|
||||||
|
.ok_or_else(|| RouteError::NoFeasibleNeuron {
|
||||||
|
model_id: profile.id.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Issue `POST {endpoint}/models/load` for this profile on this neuron,
|
||||||
|
/// blocking until the load completes (neuron's load endpoint is
|
||||||
|
/// synchronous — it returns 200 once VRAM is materialised). On success
|
||||||
|
/// also inserts a `Loaded` entry into the local NodeState cache so the
|
||||||
|
/// caller's subsequent endpoint lookup sees the new model without
|
||||||
|
/// waiting for the next poll cycle.
|
||||||
|
async fn cold_load(
|
||||||
|
fleet: &Arc<CortexState>,
|
||||||
|
node_name: &str,
|
||||||
|
neuron_endpoint: &str,
|
||||||
|
profile: &ModelProfile,
|
||||||
|
) -> Result<(), RouteError> {
|
||||||
|
let spec = profile_to_spec(fleet, node_name, profile).await;
|
||||||
|
let url = format!("{neuron_endpoint}/models/load");
|
||||||
|
tracing::info!(model = %profile.id, node = node_name, "cold-loading via /models/load");
|
||||||
|
|
||||||
|
// Generous timeout: a fresh download + safetensors mmap + device
|
||||||
|
// copy for a 30B-class dense model can comfortably exceed 5 min on
|
||||||
|
// a slow link. The HTTP client's own default already covers most
|
||||||
|
// of this; pin a longer per-request bound just here.
|
||||||
|
let resp = match fleet
|
||||||
|
.http_client
|
||||||
|
.post(&url)
|
||||||
|
.timeout(Duration::from_secs(1800))
|
||||||
|
.json(&spec)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => {
|
||||||
|
return Err(RouteError::ColdLoadFailed {
|
||||||
|
model_id: profile.id.clone(),
|
||||||
|
node: node_name.to_string(),
|
||||||
|
message: format!("HTTP request failed: {e}"),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let status = resp.status();
|
||||||
|
if !status.is_success() {
|
||||||
|
let body = resp.text().await.unwrap_or_default();
|
||||||
|
// Neuron returns 400 "already loaded" when two concurrent
|
||||||
|
// requests race the same model. Treat that as success — both
|
||||||
|
// requests effectively achieved the same end state.
|
||||||
|
if body.contains("already loaded") {
|
||||||
|
tracing::info!(
|
||||||
|
model = %profile.id,
|
||||||
|
node = node_name,
|
||||||
|
"cold-load saw 'already loaded' — treating as success"
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
return Err(RouteError::ColdLoadFailed {
|
||||||
|
model_id: profile.id.clone(),
|
||||||
|
node: node_name.to_string(),
|
||||||
|
message: format!("HTTP {status}: {body}"),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
tracing::info!(model = %profile.id, node = node_name, "cold-load returned 200");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Warm the cache: insert a Loaded ModelEntry so the next
|
||||||
|
// resolve() finds the model without waiting for the poll loop.
|
||||||
|
{
|
||||||
|
let mut nodes = fleet.nodes.write().await;
|
||||||
|
if let Some(node) = nodes.get_mut(node_name) {
|
||||||
|
node.models.insert(
|
||||||
|
profile.id.clone(),
|
||||||
|
cortex_core::node::ModelEntry {
|
||||||
|
id: profile.id.clone(),
|
||||||
|
status: ModelStatus::Loaded,
|
||||||
|
last_accessed: Some(chrono::Utc::now()),
|
||||||
|
vram_estimate_mb: profile.vram_mb,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Translate a `ModelProfile` to a `ModelSpec` neuron's /models/load
|
||||||
|
/// accepts. Devices are picked from the neuron's discovered topology —
|
||||||
|
/// the first `min_devices` indices that meet `min_device_vram_mb`.
|
||||||
|
async fn profile_to_spec(
|
||||||
|
fleet: &Arc<CortexState>,
|
||||||
|
node_name: &str,
|
||||||
|
profile: &ModelProfile,
|
||||||
|
) -> ModelSpec {
|
||||||
|
let devices = {
|
||||||
|
let nodes = fleet.nodes.read().await;
|
||||||
|
let mut picked: Vec<u32> = Vec::new();
|
||||||
|
if let Some(node) = nodes.get(node_name)
|
||||||
|
&& let Some(disc) = &node.discovery
|
||||||
|
{
|
||||||
|
let min_vram = profile.min_device_vram_mb.unwrap_or(0);
|
||||||
|
for d in &disc.devices {
|
||||||
|
if d.vram_total_mb >= min_vram {
|
||||||
|
picked.push(d.index);
|
||||||
|
if picked.len() as u32 >= profile.min_devices {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if picked.is_empty() {
|
||||||
|
// Fall back to a 0..min_devices default; pick_feasible_neuron
|
||||||
|
// already verified the topology satisfies the constraints,
|
||||||
|
// so this only fires if discovery raced or was lost.
|
||||||
|
(0..profile.min_devices).collect()
|
||||||
|
} else {
|
||||||
|
picked
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let tensor_parallel = if profile.min_devices > 1 {
|
||||||
|
Some(profile.min_devices)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
ModelSpec {
|
||||||
|
model_id: profile.id.clone(),
|
||||||
|
harness: profile.harness.clone(),
|
||||||
|
quant: profile.quant.clone(),
|
||||||
|
tensor_parallel,
|
||||||
|
devices: Some(devices),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Resolve neuron's `/models/{id}/endpoint` to its inference URL and
|
||||||
|
/// build the final `RouteDecision`. Shared by all three priority
|
||||||
|
/// branches above.
|
||||||
|
async fn finish(
|
||||||
|
fleet: &Arc<CortexState>,
|
||||||
|
node_name: &str,
|
||||||
|
neuron_endpoint: &str,
|
||||||
|
model_id: &str,
|
||||||
|
cold_start: bool,
|
||||||
|
) -> Result<RouteDecision, RouteError> {
|
||||||
let endpoint_url = format!(
|
let endpoint_url = format!(
|
||||||
"{}/models/{}/endpoint",
|
"{}/models/{}/endpoint",
|
||||||
neuron_endpoint,
|
neuron_endpoint,
|
||||||
@@ -89,13 +299,82 @@ pub async fn resolve(
|
|||||||
_ => None,
|
_ => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let endpoint = inference_endpoint.ok_or_else(|| {
|
let raw = inference_endpoint.ok_or_else(|| {
|
||||||
RouteError::EndpointResolveFailed(model_id.to_string(), node_name.clone())
|
RouteError::EndpointResolveFailed(model_id.to_string(), node_name.to_string())
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
// Rewrite loopback inference URLs to use the configured neuron host.
|
||||||
|
// Neuron's default bind_url is `http://localhost:13131` (it can't
|
||||||
|
// reliably know its own externally-resolvable name). Cortex sees a
|
||||||
|
// URL that's only meaningful from the neuron host's own perspective;
|
||||||
|
// proxying directly to localhost from a different cortex host would
|
||||||
|
// hit nothing. Keep neuron's port and path (a future harness could
|
||||||
|
// serve inference on a different port than the management API), but
|
||||||
|
// swap the host for the one in cortex.toml.
|
||||||
|
let endpoint = rewrite_loopback_host(&raw, neuron_endpoint).unwrap_or(raw);
|
||||||
|
|
||||||
Ok(RouteDecision {
|
Ok(RouteDecision {
|
||||||
node_name,
|
node_name: node_name.to_string(),
|
||||||
endpoint,
|
endpoint,
|
||||||
cold_start,
|
cold_start,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// If `inference_url`'s host is a loopback name (localhost / 127.0.0.1 /
|
||||||
|
/// 0.0.0.0 / ::1), return a copy with the host replaced by
|
||||||
|
/// `neuron_endpoint`'s host. Otherwise return None and the caller falls
|
||||||
|
/// back to the inference URL as-is.
|
||||||
|
fn rewrite_loopback_host(inference_url: &str, neuron_endpoint: &str) -> Option<String> {
|
||||||
|
let inf = url::Url::parse(inference_url).ok()?;
|
||||||
|
let inf_host = inf.host_str()?;
|
||||||
|
let is_loopback = matches!(inf_host, "localhost" | "127.0.0.1" | "0.0.0.0" | "::1");
|
||||||
|
if !is_loopback {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let neuron = url::Url::parse(neuron_endpoint).ok()?;
|
||||||
|
let new_host = neuron.host_str()?;
|
||||||
|
let mut out = inf.clone();
|
||||||
|
out.set_host(Some(new_host)).ok()?;
|
||||||
|
// url::Url::to_string normalises an empty path to "/", which then
|
||||||
|
// breaks downstream callers that do format!("{endpoint}/v1/...")
|
||||||
|
// and produce a double slash. The proxy URL is treated as a base
|
||||||
|
// string that the caller appends paths to, so strip the trailing
|
||||||
|
// slash here.
|
||||||
|
let s = out.to_string();
|
||||||
|
Some(s.trim_end_matches('/').to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::rewrite_loopback_host;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn rewrites_localhost_keeps_port_and_path() {
|
||||||
|
let out = rewrite_loopback_host(
|
||||||
|
"http://localhost:13131",
|
||||||
|
"http://beast.hanzalova.internal:13131",
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
out.as_deref(),
|
||||||
|
Some("http://beast.hanzalova.internal:13131")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn rewrites_loopback_with_distinct_inference_port() {
|
||||||
|
let out = rewrite_loopback_host("http://127.0.0.1:8080", "http://beast.lan:13131");
|
||||||
|
assert_eq!(out.as_deref(), Some("http://beast.lan:8080"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn leaves_non_loopback_alone() {
|
||||||
|
let out = rewrite_loopback_host("http://other.host:1234", "http://beast.lan:13131");
|
||||||
|
assert_eq!(out, None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn malformed_inference_url_returns_none() {
|
||||||
|
let out = rewrite_loopback_host("not a url", "http://beast.lan:13131");
|
||||||
|
assert_eq!(out, None);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ impl CortexState {
|
|||||||
models: HashMap::new(),
|
models: HashMap::new(),
|
||||||
lifecycle_cycles: 0,
|
lifecycle_cycles: 0,
|
||||||
last_poll: None,
|
last_poll: None,
|
||||||
|
discovery: None,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ use tokio::net::TcpListener;
|
|||||||
/// - GET /models/:id/endpoint (returns the inference URL)
|
/// - GET /models/:id/endpoint (returns the inference URL)
|
||||||
/// - POST /models/unload (accepts unload requests)
|
/// - POST /models/unload (accepts unload requests)
|
||||||
/// - GET /v1/chat/completions + POST /v1/chat/completions (inference)
|
/// - GET /v1/chat/completions + POST /v1/chat/completions (inference)
|
||||||
|
///
|
||||||
/// Returns the neuron base URL.
|
/// Returns the neuron base URL.
|
||||||
pub async fn spawn_mock_neuron() -> String {
|
pub async fn spawn_mock_neuron() -> String {
|
||||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
@@ -54,7 +55,7 @@ pub async fn spawn_mock_neuron() -> String {
|
|||||||
|
|
||||||
async fn mock_neuron_list_models() -> Json<Value> {
|
async fn mock_neuron_list_models() -> Json<Value> {
|
||||||
Json(json!([
|
Json(json!([
|
||||||
{"id": "test-model", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": 8000}
|
{"id": "test-model", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": 8000}
|
||||||
]))
|
]))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ use std::sync::Arc;
|
|||||||
async fn test_poller_discovers_models() {
|
async fn test_poller_discovers_models() {
|
||||||
// Mock neuron reports 2 models via /models endpoint (neuron format).
|
// Mock neuron reports 2 models via /models endpoint (neuron format).
|
||||||
let mock_url = common::spawn_mock_neuron_with_models(json!([
|
let mock_url = common::spawn_mock_neuron_with_models(json!([
|
||||||
{"id": "model-a", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": 8000},
|
{"id": "model-a", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": 8000},
|
||||||
{"id": "model-b", "harness": "mistralrs", "status": "unloaded", "devices": [], "vram_used_mb": null}
|
{"id": "model-b", "harness": "candle", "status": "unloaded", "devices": [], "vram_used_mb": null}
|
||||||
]))
|
]))
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
@@ -63,8 +63,8 @@ async fn test_poller_discovers_models() {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_poller_updates_gateway_models_endpoint() {
|
async fn test_poller_updates_gateway_models_endpoint() {
|
||||||
let mock_url = common::spawn_mock_neuron_with_models(json!([
|
let mock_url = common::spawn_mock_neuron_with_models(json!([
|
||||||
{"id": "model-x", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": null},
|
{"id": "model-x", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": null},
|
||||||
{"id": "model-y", "harness": "mistralrs", "status": "loaded", "devices": [1], "vram_used_mb": null}
|
{"id": "model-y", "harness": "candle", "status": "loaded", "devices": [1], "vram_used_mb": null}
|
||||||
]))
|
]))
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
@@ -152,8 +152,8 @@ async fn test_poller_marks_unreachable_node_unhealthy() {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_poller_removes_stale_models() {
|
async fn test_poller_removes_stale_models() {
|
||||||
let mock_url = common::spawn_mock_neuron_with_models(json!([
|
let mock_url = common::spawn_mock_neuron_with_models(json!([
|
||||||
{"id": "keep-me", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": null},
|
{"id": "keep-me", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": null},
|
||||||
{"id": "drop-me", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": null}
|
{"id": "drop-me", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": null}
|
||||||
]))
|
]))
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
@@ -183,7 +183,7 @@ async fn test_poller_removes_stale_models() {
|
|||||||
|
|
||||||
// New mock with only one model.
|
// New mock with only one model.
|
||||||
let new_mock_url = common::spawn_mock_neuron_with_models(json!([
|
let new_mock_url = common::spawn_mock_neuron_with_models(json!([
|
||||||
{"id": "keep-me", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": null}
|
{"id": "keep-me", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": null}
|
||||||
]))
|
]))
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
|
|||||||
@@ -51,18 +51,18 @@ async fn test_streaming_sse_passthrough() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
assert!(
|
assert!(
|
||||||
chunks.len() >= chunk_count + 1,
|
chunks.len() > chunk_count,
|
||||||
"expected at least {} chunks (got {}): {:?}",
|
"expected more than {} chunks (got {}): {:?}",
|
||||||
chunk_count + 1,
|
chunk_count,
|
||||||
chunks.len(),
|
chunks.len(),
|
||||||
chunks,
|
chunks,
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_eq!(chunks.last().unwrap(), "[DONE]");
|
assert_eq!(chunks.last().unwrap(), "[DONE]");
|
||||||
|
|
||||||
for i in 0..chunk_count {
|
for (i, chunk) in chunks.iter().enumerate().take(chunk_count) {
|
||||||
let chunk_json: serde_json::Value =
|
let chunk_json: serde_json::Value =
|
||||||
serde_json::from_str(&chunks[i]).expect("chunk should be valid JSON");
|
serde_json::from_str(chunk).expect("chunk should be valid JSON");
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
chunk_json["choices"][0]["delta"]["content"],
|
chunk_json["choices"][0]["delta"]["content"],
|
||||||
format!("token{i}")
|
format!("token{i}")
|
||||||
|
|||||||
@@ -12,6 +12,36 @@ path = "src/lib.rs"
|
|||||||
name = "neuron"
|
name = "neuron"
|
||||||
path = "src/main.rs"
|
path = "src/main.rs"
|
||||||
|
|
||||||
|
[features]
|
||||||
|
default = []
|
||||||
|
# Enables CUDA acceleration in candle and the cudarc/nccl bindings the
|
||||||
|
# TP worker pool uses. Without this feature, candle compiles for CPU
|
||||||
|
# only, Device::new_cuda calls fall back to CPU, and TP Init/sanity
|
||||||
|
# requests return Error{kind="cuda_feature_not_enabled"}.
|
||||||
|
cuda = [
|
||||||
|
"candle-core/cuda",
|
||||||
|
"candle-core/nccl",
|
||||||
|
"candle-nn/cuda",
|
||||||
|
"candle-transformers/cuda",
|
||||||
|
"dep:cudarc",
|
||||||
|
"dep:half",
|
||||||
|
"dep:cudaforge",
|
||||||
|
]
|
||||||
|
# Use cuDNN for convolution / attention kernels. Requires CUDA.
|
||||||
|
cudnn = [
|
||||||
|
"cuda",
|
||||||
|
"candle-core/cudnn",
|
||||||
|
"candle-nn/cudnn",
|
||||||
|
"candle-transformers/cudnn",
|
||||||
|
]
|
||||||
|
# FlashAttention kernels. Requires CUDA.
|
||||||
|
flash-attn = [
|
||||||
|
"cuda",
|
||||||
|
"candle-transformers/flash-attn",
|
||||||
|
]
|
||||||
|
# Reserved for GPU-only integration tests in later stages.
|
||||||
|
cuda-integration = ["cuda"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
cortex-core.workspace = true
|
cortex-core.workspace = true
|
||||||
tokio.workspace = true
|
tokio.workspace = true
|
||||||
@@ -24,9 +54,44 @@ tracing-subscriber.workspace = true
|
|||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
async-trait.workspace = true
|
async-trait.workspace = true
|
||||||
clap.workspace = true
|
clap.workspace = true
|
||||||
|
thiserror.workspace = true
|
||||||
|
futures.workspace = true
|
||||||
|
tokio-stream.workspace = true
|
||||||
figment.workspace = true
|
figment.workspace = true
|
||||||
toml.workspace = true
|
toml.workspace = true
|
||||||
|
|
||||||
|
# candle for in-process inference. CUDA support is gated behind the
|
||||||
|
# crate's `cuda` feature (default off) so the workspace builds on
|
||||||
|
# non-CUDA hosts and CI runners.
|
||||||
|
candle-core = "0.10.2"
|
||||||
|
candle-nn = "0.10.2"
|
||||||
|
candle-transformers = "0.10.2"
|
||||||
|
# Direct dep on cudarc (matching candle's transitive version) so the
|
||||||
|
# TP worker pool can call cudarc::nccl::{Comm, Id} directly. Gated on
|
||||||
|
# the `cuda` feature; same toolchain requirement as candle's CUDA path.
|
||||||
|
cudarc = { version = "0.19", optional = true, default-features = false, features = ["nccl", "cuda-version-from-build-system"] }
|
||||||
|
# Used by the AllReduce CustomOp1 to type-dispatch on bf16/f16 candle
|
||||||
|
# storages. Matches candle-core's pinned major version to avoid double-
|
||||||
|
# compiling the `half` crate at conflicting versions.
|
||||||
|
half = { version = "2.5", optional = true }
|
||||||
|
tokenizers = { version = "0.22", default-features = false, features = ["onig"] }
|
||||||
|
hf-hub = { version = "0.4", features = ["tokio"] }
|
||||||
|
# Direct dep on `safetensors` (re-exported by candle but its `TensorView`
|
||||||
|
# / `slice::IndexOp` types are public-but-not-re-exported). Used by the
|
||||||
|
# tp `fused_load` module to read per-rank slices of fused QKV tensors
|
||||||
|
# without materialising the full tensor on device.
|
||||||
|
safetensors = "0.7"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tokio = { workspace = true, features = ["test-util"] }
|
tokio = { workspace = true, features = ["test-util"] }
|
||||||
reqwest.workspace = true
|
reqwest.workspace = true
|
||||||
|
|
||||||
|
[build-dependencies]
|
||||||
|
# Used by `build.rs` to compile `src/cuda/*.cu` into `libneuroncuda.a`
|
||||||
|
# under the `cuda` feature. Matches mistralrs's upstream build setup
|
||||||
|
# (their `mistralrs-core/build.rs` uses the same constructor).
|
||||||
|
cudaforge = { version = "0.1", optional = true }
|
||||||
|
|
||||||
|
[package.metadata.docs.rs]
|
||||||
|
# Skip the CUDA path on docs.rs (it lacks nvcc).
|
||||||
|
no-default-features = true
|
||||||
|
|||||||
66
crates/neuron/build.rs
Normal file
66
crates/neuron/build.rs
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
//! Build script: compile the CUDA kernels in `src/cuda/*.cu` into a
|
||||||
|
//! static library and link it under the `cuda` feature.
|
||||||
|
//!
|
||||||
|
//! Patterned on `EricLBuehler/mistral.rs::mistralrs-core/build.rs` —
|
||||||
|
//! same `cudaforge::KernelBuilder` invocation, same NVCC flag set.
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
{
|
||||||
|
use std::path::PathBuf;
|
||||||
|
println!("cargo:rerun-if-changed=build.rs");
|
||||||
|
println!("cargo:rerun-if-changed=src/cuda/");
|
||||||
|
|
||||||
|
let build_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap());
|
||||||
|
|
||||||
|
let mut builder = cudaforge::KernelBuilder::new()
|
||||||
|
.source_glob("src/cuda/*.cu")
|
||||||
|
.out_dir(&build_dir)
|
||||||
|
.arg("-std=c++17")
|
||||||
|
.arg("-O3")
|
||||||
|
.arg("-U__CUDA_NO_HALF_OPERATORS__")
|
||||||
|
.arg("-U__CUDA_NO_HALF_CONVERSIONS__")
|
||||||
|
.arg("-U__CUDA_NO_HALF2_OPERATORS__")
|
||||||
|
.arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
|
||||||
|
.arg("--expt-relaxed-constexpr")
|
||||||
|
.arg("--expt-extended-lambda")
|
||||||
|
.arg("--use_fast_math")
|
||||||
|
.arg("--compiler-options")
|
||||||
|
.arg("-fPIC");
|
||||||
|
|
||||||
|
// sm_<80 doesn't have bf16 intrinsics for WMMA — gate the
|
||||||
|
// bf16-only kernels off in that case. (Mirrors upstream.)
|
||||||
|
if let Some(compute_cap) = builder.get_compute_cap()
|
||||||
|
&& compute_cap < 80
|
||||||
|
{
|
||||||
|
builder = builder.arg("-DNO_BF16_KERNEL");
|
||||||
|
}
|
||||||
|
|
||||||
|
let target = std::env::var("TARGET").unwrap();
|
||||||
|
let out_file = if target.contains("msvc") {
|
||||||
|
build_dir.join("neuroncuda.lib")
|
||||||
|
} else {
|
||||||
|
build_dir.join("libneuroncuda.a")
|
||||||
|
};
|
||||||
|
|
||||||
|
builder
|
||||||
|
.build_lib(out_file)
|
||||||
|
.expect("neuron cuda build failed");
|
||||||
|
println!("cargo:rustc-link-search={}", build_dir.display());
|
||||||
|
println!("cargo:rustc-link-lib=neuroncuda");
|
||||||
|
println!("cargo:rustc-link-lib=dylib=cudart");
|
||||||
|
|
||||||
|
if target.contains("msvc") {
|
||||||
|
// No extra runtime library needed.
|
||||||
|
} else if target.contains("apple")
|
||||||
|
|| target.contains("freebsd")
|
||||||
|
|| target.contains("openbsd")
|
||||||
|
{
|
||||||
|
println!("cargo:rustc-link-lib=dylib=c++");
|
||||||
|
} else if target.contains("android") {
|
||||||
|
println!("cargo:rustc-link-lib=dylib=c++_shared");
|
||||||
|
} else {
|
||||||
|
println!("cargo:rustc-link-lib=dylib=stdc++");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,23 +1,33 @@
|
|||||||
//! HTTP API handlers for the neuron daemon.
|
//! HTTP API handlers for the neuron daemon.
|
||||||
|
|
||||||
use crate::harness::HarnessRegistry;
|
use crate::harness::HarnessRegistry;
|
||||||
|
use crate::harness::candle::{CandleHarness, InferenceError};
|
||||||
use crate::health::HealthCache;
|
use crate::health::HealthCache;
|
||||||
use axum::Router;
|
use axum::Router;
|
||||||
use axum::extract::{Path, State};
|
use axum::extract::{Path, State};
|
||||||
use axum::http::StatusCode;
|
use axum::http::StatusCode;
|
||||||
|
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||||
use axum::response::{IntoResponse, Json};
|
use axum::response::{IntoResponse, Json};
|
||||||
use axum::routing::{get, post};
|
use axum::routing::{get, post};
|
||||||
use cortex_core::discovery::{DiscoveryResponse, HealthResponse};
|
use cortex_core::discovery::{DiscoveryResponse, HealthResponse};
|
||||||
use cortex_core::harness::ModelSpec;
|
use cortex_core::harness::ModelSpec;
|
||||||
|
use cortex_core::openai::ChatCompletionRequest;
|
||||||
|
use futures::stream::{self, StreamExt};
|
||||||
use serde_json::{Value, json};
|
use serde_json::{Value, json};
|
||||||
|
use std::convert::Infallible;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
|
use tokio_stream::wrappers::ReceiverStream;
|
||||||
|
|
||||||
/// Shared state for the neuron HTTP server.
|
/// Shared state for the neuron HTTP server.
|
||||||
pub struct NeuronState {
|
pub struct NeuronState {
|
||||||
pub discovery: DiscoveryResponse,
|
pub discovery: DiscoveryResponse,
|
||||||
pub health_cache: Arc<HealthCache>,
|
pub health_cache: Arc<HealthCache>,
|
||||||
pub registry: RwLock<HarnessRegistry>,
|
pub registry: RwLock<HarnessRegistry>,
|
||||||
|
/// Typed handle to the candle harness for inference routes. Cached at
|
||||||
|
/// startup so `/v1/chat/completions` doesn't have to hold the registry
|
||||||
|
/// read lock or perform dyn-Trait dispatch per request.
|
||||||
|
pub candle: Option<Arc<CandleHarness>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Build the neuron API router.
|
/// Build the neuron API router.
|
||||||
@@ -29,6 +39,7 @@ pub fn neuron_routes() -> Router<Arc<NeuronState>> {
|
|||||||
.route("/models/load", post(load_model))
|
.route("/models/load", post(load_model))
|
||||||
.route("/models/unload", post(unload_model))
|
.route("/models/unload", post(unload_model))
|
||||||
.route("/models/{model_id}/endpoint", get(model_endpoint))
|
.route("/models/{model_id}/endpoint", get(model_endpoint))
|
||||||
|
.route("/v1/chat/completions", post(chat_completions))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn discovery_handler(State(state): State<Arc<NeuronState>>) -> Json<DiscoveryResponse> {
|
async fn discovery_handler(State(state): State<Arc<NeuronState>>) -> Json<DiscoveryResponse> {
|
||||||
@@ -45,7 +56,7 @@ async fn list_models(State(state): State<Arc<NeuronState>>) -> impl IntoResponse
|
|||||||
Ok(models) => Json(json!(models)).into_response(),
|
Ok(models) => Json(json!(models)).into_response(),
|
||||||
Err(e) => (
|
Err(e) => (
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
Json(json!({"error": e.to_string()})),
|
Json(json!({"error": format!("{e:#}")})),
|
||||||
)
|
)
|
||||||
.into_response(),
|
.into_response(),
|
||||||
}
|
}
|
||||||
@@ -58,11 +69,22 @@ async fn load_model(
|
|||||||
let registry = state.registry.read().await;
|
let registry = state.registry.read().await;
|
||||||
match registry.load_model(&spec).await {
|
match registry.load_model(&spec).await {
|
||||||
Ok(()) => Json(json!({"status": "loaded"})).into_response(),
|
Ok(()) => Json(json!({"status": "loaded"})).into_response(),
|
||||||
Err(e) => (
|
Err(e) => {
|
||||||
|
// Log the full anyhow chain server-side so journalctl shows
|
||||||
|
// the underlying failure (hf-hub timeout, permission denied,
|
||||||
|
// disk full, etc.) without needing to inspect the HTTP
|
||||||
|
// response body separately.
|
||||||
|
tracing::warn!(
|
||||||
|
model = %spec.model_id,
|
||||||
|
error = %format!("{e:#}"),
|
||||||
|
"load_model failed"
|
||||||
|
);
|
||||||
|
(
|
||||||
StatusCode::BAD_REQUEST,
|
StatusCode::BAD_REQUEST,
|
||||||
Json(json!({"error": e.to_string()})),
|
Json(json!({"error": format!("{e:#}")})),
|
||||||
)
|
)
|
||||||
.into_response(),
|
.into_response()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -84,7 +106,11 @@ async fn unload_model(
|
|||||||
let registry = state.registry.read().await;
|
let registry = state.registry.read().await;
|
||||||
match registry.unload_model(&model_id).await {
|
match registry.unload_model(&model_id).await {
|
||||||
Ok(()) => Json(json!({"status": "unloaded"})).into_response(),
|
Ok(()) => Json(json!({"status": "unloaded"})).into_response(),
|
||||||
Err(e) => (StatusCode::NOT_FOUND, Json(json!({"error": e.to_string()}))).into_response(),
|
Err(e) => (
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
Json(json!({"error": format!("{e:#}")})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -102,3 +128,61 @@ async fn model_endpoint(
|
|||||||
.into_response(),
|
.into_response(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// OpenAI-compatible chat completions. Dispatches to streaming SSE when
|
||||||
|
/// `stream: true` is set on the request; otherwise returns a single
|
||||||
|
/// `ChatCompletionResponse`.
|
||||||
|
async fn chat_completions(
|
||||||
|
State(state): State<Arc<NeuronState>>,
|
||||||
|
Json(req): Json<ChatCompletionRequest>,
|
||||||
|
) -> impl IntoResponse {
|
||||||
|
let Some(candle) = state.candle.as_ref().map(Arc::clone) else {
|
||||||
|
return (
|
||||||
|
StatusCode::SERVICE_UNAVAILABLE,
|
||||||
|
Json(json!({"error": "candle harness not enabled on this neuron"})),
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
|
};
|
||||||
|
|
||||||
|
if req.stream.unwrap_or(false) {
|
||||||
|
match candle.chat_completion_stream(req).await {
|
||||||
|
Ok(rx) => {
|
||||||
|
// Each chunk → one SSE `data: {json}` line. After the
|
||||||
|
// channel closes, append the OpenAI [DONE] terminator.
|
||||||
|
let body_stream = ReceiverStream::new(rx).map(|chunk| {
|
||||||
|
let body = serde_json::to_string(&chunk).unwrap_or_default();
|
||||||
|
Ok::<_, Infallible>(Event::default().data(body))
|
||||||
|
});
|
||||||
|
let done_stream =
|
||||||
|
stream::once(async { Ok::<_, Infallible>(Event::default().data("[DONE]")) });
|
||||||
|
Sse::new(body_stream.chain(done_stream))
|
||||||
|
.keep_alive(KeepAlive::default())
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
Err(InferenceError::ModelNotLoaded(id)) => (
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
Json(json!({"error": format!("model '{id}' not loaded on this neuron")})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
Err(InferenceError::Other(e)) => (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(json!({"error": format!("{e:#}")})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
match candle.chat_completion(req).await {
|
||||||
|
Ok(resp) => Json(resp).into_response(),
|
||||||
|
Err(InferenceError::ModelNotLoaded(id)) => (
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
Json(json!({"error": format!("model '{id}' not loaded on this neuron")})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
Err(InferenceError::Other(e)) => (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(json!({"error": format!("{e:#}")})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
//! Neuron configuration loaded from neuron.toml.
|
//! Neuron configuration loaded from neuron.toml.
|
||||||
|
|
||||||
use cortex_core::harness::HarnessConfig;
|
use cortex_core::harness::{HarnessConfig, ModelSpec};
|
||||||
use figment::{
|
use figment::{
|
||||||
Figment,
|
Figment,
|
||||||
providers::{Env, Format, Toml},
|
providers::{Env, Format, Toml},
|
||||||
};
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::path::Path;
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct NeuronConfig {
|
pub struct NeuronConfig {
|
||||||
@@ -14,10 +14,35 @@ pub struct NeuronConfig {
|
|||||||
pub port: u16,
|
pub port: u16,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub harnesses: Vec<HarnessConfig>,
|
pub harnesses: Vec<HarnessConfig>,
|
||||||
|
/// Per-harness configuration. Currently only `candle` is recognised.
|
||||||
|
#[serde(default)]
|
||||||
|
pub harness: HarnessSettings,
|
||||||
|
/// Models to auto-load when the neuron service activates. Each entry
|
||||||
|
/// is loaded sequentially before the HTTP listener binds. A failure
|
||||||
|
/// on any single entry logs a warning and proceeds — broken entries
|
||||||
|
/// don't prevent the rest of the fleet from starting.
|
||||||
|
#[serde(default)]
|
||||||
|
pub default_models: Vec<ModelSpec>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Settings for individual harness implementations. Each harness owns
|
||||||
|
/// its own sub-table so users only configure the harnesses they enable.
|
||||||
|
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||||
|
pub struct HarnessSettings {
|
||||||
|
#[serde(default)]
|
||||||
|
pub candle: CandleHarnessConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||||
|
pub struct CandleHarnessConfig {
|
||||||
|
/// HuggingFace cache directory for model weights.
|
||||||
|
/// When unset, defers to hf-hub's default (~/.cache/huggingface).
|
||||||
|
#[serde(default)]
|
||||||
|
pub hf_cache: Option<PathBuf>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_port() -> u16 {
|
fn default_port() -> u16 {
|
||||||
9090
|
13131
|
||||||
}
|
}
|
||||||
|
|
||||||
impl NeuronConfig {
|
impl NeuronConfig {
|
||||||
@@ -33,8 +58,10 @@ impl NeuronConfig {
|
|||||||
impl Default for NeuronConfig {
|
impl Default for NeuronConfig {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
port: 9090,
|
port: 13131,
|
||||||
harnesses: vec![],
|
harnesses: vec![],
|
||||||
|
harness: HarnessSettings::default(),
|
||||||
|
default_models: vec![],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
84
crates/neuron/src/cuda/ffi.rs
Normal file
84
crates/neuron/src/cuda/ffi.rs
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
//! FFI declarations for the CUDA kernels in `gdn.cu`.
|
||||||
|
//!
|
||||||
|
//! Subset of `EricLBuehler/mistral.rs::mistralrs-core/src/cuda/ffi.rs`
|
||||||
|
//! covering only the Gated DeltaNet kernels we currently use. Other
|
||||||
|
//! kernels in the upstream file (MoE GEMM, top-k, Mamba selective
|
||||||
|
//! scan, etc.) would land here too as we absorb them.
|
||||||
|
//!
|
||||||
|
//! All function declarations are MIT-licensed from upstream and
|
||||||
|
//! unchanged apart from this header.
|
||||||
|
|
||||||
|
use std::ffi::c_void;
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
unsafe extern "C" {
|
||||||
|
// GDN (Gated Delta Net) kernels for qwen3_5 / Qwen3-Next.
|
||||||
|
pub(crate) fn gated_delta_rule_recurrence(
|
||||||
|
q: *const f32,
|
||||||
|
k: *const f32,
|
||||||
|
v: *const f32,
|
||||||
|
g: *const f32,
|
||||||
|
beta: *const f32,
|
||||||
|
state: *mut f32,
|
||||||
|
output: *mut f32,
|
||||||
|
bh: i32,
|
||||||
|
seq_len: i32,
|
||||||
|
k_dim: i32,
|
||||||
|
v_dim: i32,
|
||||||
|
stream: i64,
|
||||||
|
);
|
||||||
|
|
||||||
|
/// Chunked GDN recurrence for prefill (processes tokens in BT=64 chunks).
|
||||||
|
pub(crate) fn chunked_gated_delta_rule_recurrence(
|
||||||
|
q: *const f32,
|
||||||
|
k: *const f32,
|
||||||
|
v: *const f32,
|
||||||
|
g: *const f32,
|
||||||
|
beta: *const f32,
|
||||||
|
state: *mut f32,
|
||||||
|
output: *mut f32,
|
||||||
|
bh: i32,
|
||||||
|
seq_len: i32,
|
||||||
|
k_dim: i32,
|
||||||
|
v_dim: i32,
|
||||||
|
stream: i64,
|
||||||
|
);
|
||||||
|
|
||||||
|
pub(crate) fn causal_conv1d_update(
|
||||||
|
x: *const c_void,
|
||||||
|
weight: *const c_void,
|
||||||
|
conv_state: *mut c_void,
|
||||||
|
output: *mut c_void,
|
||||||
|
batch_size: i32,
|
||||||
|
conv_dim: i32,
|
||||||
|
kernel_size: i32,
|
||||||
|
dtype: i32,
|
||||||
|
stream: i64,
|
||||||
|
);
|
||||||
|
|
||||||
|
pub(crate) fn causal_conv1d_full(
|
||||||
|
x: *const c_void,
|
||||||
|
weight: *const c_void,
|
||||||
|
conv_state_out: *mut c_void,
|
||||||
|
output: *mut c_void,
|
||||||
|
batch_size: i32,
|
||||||
|
conv_dim: i32,
|
||||||
|
seq_len: i32,
|
||||||
|
kernel_size: i32,
|
||||||
|
dtype: i32,
|
||||||
|
stream: i64,
|
||||||
|
);
|
||||||
|
|
||||||
|
pub(crate) fn fused_gdn_gating(
|
||||||
|
b: *const c_void,
|
||||||
|
a: *const c_void,
|
||||||
|
a_log: *const f32,
|
||||||
|
dt_bias: *const f32,
|
||||||
|
beta_out: *mut c_void,
|
||||||
|
g_out: *mut c_void,
|
||||||
|
total_elements: i32,
|
||||||
|
num_heads: i32,
|
||||||
|
dtype: i32,
|
||||||
|
stream: i64,
|
||||||
|
);
|
||||||
|
}
|
||||||
711
crates/neuron/src/cuda/gdn.cu
Normal file
711
crates/neuron/src/cuda/gdn.cu
Normal file
@@ -0,0 +1,711 @@
|
|||||||
|
// Gated DeltaNet CUDA kernels for Qwen3-Next (`model_type = "qwen3_5"`).
|
||||||
|
//
|
||||||
|
// Ported verbatim from `EricLBuehler/mistral.rs` under MIT terms.
|
||||||
|
// Upstream path: mistralrs-core/src/cuda/gdn.cu. Local edits in this
|
||||||
|
// file are limited to this banner; the kernels are unchanged so a
|
||||||
|
// diff against upstream stays minimal.
|
||||||
|
//
|
||||||
|
// Five kernels exposed via `extern "C"` shims at the bottom:
|
||||||
|
// - gated_delta_rule_recurrence (per-token decode)
|
||||||
|
// - chunked_gated_delta_rule_recurrence (BT=64 chunked prefill)
|
||||||
|
// - causal_conv1d_update (single-token conv decode)
|
||||||
|
// - causal_conv1d_full (multi-token conv prefill)
|
||||||
|
// - fused_gdn_gating (beta = sigmoid(b);
|
||||||
|
// g = -exp(A_log) * softplus(a + dt_bias))
|
||||||
|
|
||||||
|
#include "cuda_bf16.h"
|
||||||
|
#include "cuda_fp16.h"
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Kernel 1: gated_delta_rule_recurrence (optimized)
|
||||||
|
//
|
||||||
|
// V-tiled recurrence with compile-time K dimension for register residency.
|
||||||
|
// Grid: (ceil(V/BV), B*H), Block: (BV,). Each thread owns BK registers of
|
||||||
|
// state. Shared memory holds k_buf and q_buf (2*BK floats).
|
||||||
|
//
|
||||||
|
// Optimizations over naive version:
|
||||||
|
// - Template BK -> float s[BK] lives in true registers (1 cycle vs ~30)
|
||||||
|
// - #pragma unroll on all k-loops -> full ILP
|
||||||
|
// - Fused decay+kv_mem pass and fused state_update+output pass
|
||||||
|
// - __fmaf_rn intrinsics for guaranteed fused multiply-add
|
||||||
|
// - BV=64 threads -> 2 warps, 6 blocks/SM on Ampere
|
||||||
|
//
|
||||||
|
// q,k: [BH, S, K] v: [BH, S, V] g,beta: [BH, S]
|
||||||
|
// state: [BH, K, V] (in/out) output: [BH, S, V]
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
// Optimized kernel: BK known at compile time -> registers + full unrolling
|
||||||
|
template <int BK, int BV>
|
||||||
|
__global__ void gated_delta_rule_recurrence_kernel_tiled(
|
||||||
|
const float *__restrict__ q, // [BH, S, K]
|
||||||
|
const float *__restrict__ k, // [BH, S, K]
|
||||||
|
const float *__restrict__ v, // [BH, S, V]
|
||||||
|
const float *__restrict__ g, // [BH, S]
|
||||||
|
const float *__restrict__ beta, // [BH, S]
|
||||||
|
float *__restrict__ state, // [BH, K, V]
|
||||||
|
float *__restrict__ output, // [BH, S, V]
|
||||||
|
int seq_len, int v_dim) {
|
||||||
|
|
||||||
|
const int v_tile = blockIdx.x; // which V-tile
|
||||||
|
const int bh = blockIdx.y; // batch*head index
|
||||||
|
const int tid = threadIdx.x; // thread within tile [0, BV)
|
||||||
|
const int v_idx = v_tile * BV + tid; // global V index
|
||||||
|
|
||||||
|
if (v_idx >= v_dim)
|
||||||
|
return;
|
||||||
|
|
||||||
|
// Pointers for this (batch, head)
|
||||||
|
const float *q_bh = q + bh * seq_len * BK;
|
||||||
|
const float *k_bh = k + bh * seq_len * BK;
|
||||||
|
const float *v_bh = v + bh * seq_len * v_dim;
|
||||||
|
const float *g_bh = g + bh * seq_len;
|
||||||
|
const float *beta_bh = beta + bh * seq_len;
|
||||||
|
float *state_bh = state + bh * BK * v_dim;
|
||||||
|
float *out_bh = output + bh * seq_len * v_dim;
|
||||||
|
|
||||||
|
// Shared memory: k_buf[BK] + q_buf[BK]
|
||||||
|
__shared__ float k_buf[BK];
|
||||||
|
__shared__ float q_buf[BK];
|
||||||
|
|
||||||
|
// Load state column into registers — BK is compile-time, so this is
|
||||||
|
// a true register array (not spilled to local memory)
|
||||||
|
float s[BK];
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < BK; j++) {
|
||||||
|
s[j] = state_bh[j * v_dim + v_idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int t = 0; t < seq_len; t++) {
|
||||||
|
// Collaboratively load k_t into shared memory
|
||||||
|
// BK / BV loads per thread (e.g. 128/64 = 2)
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = tid; j < BK; j += BV) {
|
||||||
|
k_buf[j] = k_bh[t * BK + j];
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// Load scalars for this timestep
|
||||||
|
float decay = expf(g_bh[t]);
|
||||||
|
float beta_t = beta_bh[t];
|
||||||
|
float v_t = v_bh[t * v_dim + v_idx];
|
||||||
|
|
||||||
|
// Fused pass 1: decay state + compute kv_mem
|
||||||
|
float kv_mem = 0.0f;
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < BK; j++) {
|
||||||
|
s[j] *= decay;
|
||||||
|
kv_mem = __fmaf_rn(s[j], k_buf[j], kv_mem);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delta rule
|
||||||
|
float delta = (v_t - kv_mem) * beta_t;
|
||||||
|
|
||||||
|
// Collaboratively load q_t into shared memory
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = tid; j < BK; j += BV) {
|
||||||
|
q_buf[j] = q_bh[t * BK + j];
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// Fused pass 2: update state + compute output
|
||||||
|
float y_t = 0.0f;
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < BK; j++) {
|
||||||
|
s[j] = __fmaf_rn(k_buf[j], delta, s[j]);
|
||||||
|
y_t = __fmaf_rn(s[j], q_buf[j], y_t);
|
||||||
|
}
|
||||||
|
|
||||||
|
out_bh[t * v_dim + v_idx] = y_t;
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write state back
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < BK; j++) {
|
||||||
|
state_bh[j * v_dim + v_idx] = s[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback kernel: runtime k_dim, still V-tiled for occupancy
|
||||||
|
template <int BV, int MAX_K>
|
||||||
|
__global__ void gated_delta_rule_recurrence_kernel_fallback(
|
||||||
|
const float *__restrict__ q, const float *__restrict__ k,
|
||||||
|
const float *__restrict__ v, const float *__restrict__ g,
|
||||||
|
const float *__restrict__ beta, float *__restrict__ state,
|
||||||
|
float *__restrict__ output, int seq_len, int k_dim, int v_dim) {
|
||||||
|
|
||||||
|
const int v_tile = blockIdx.x;
|
||||||
|
const int bh = blockIdx.y;
|
||||||
|
const int tid = threadIdx.x;
|
||||||
|
const int v_idx = v_tile * BV + tid;
|
||||||
|
|
||||||
|
if (v_idx >= v_dim)
|
||||||
|
return;
|
||||||
|
|
||||||
|
const float *q_bh = q + bh * seq_len * k_dim;
|
||||||
|
const float *k_bh = k + bh * seq_len * k_dim;
|
||||||
|
const float *v_bh = v + bh * seq_len * v_dim;
|
||||||
|
const float *g_bh = g + bh * seq_len;
|
||||||
|
const float *beta_bh = beta + bh * seq_len;
|
||||||
|
float *state_bh = state + bh * k_dim * v_dim;
|
||||||
|
float *out_bh = output + bh * seq_len * v_dim;
|
||||||
|
|
||||||
|
extern __shared__ float shared[];
|
||||||
|
float *k_buf = shared;
|
||||||
|
float *q_buf = shared + k_dim;
|
||||||
|
|
||||||
|
float s[MAX_K];
|
||||||
|
for (int j = 0; j < k_dim; j++) {
|
||||||
|
s[j] = state_bh[j * v_dim + v_idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int t = 0; t < seq_len; t++) {
|
||||||
|
for (int j = tid; j < k_dim; j += BV) {
|
||||||
|
k_buf[j] = k_bh[t * k_dim + j];
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
float decay = expf(g_bh[t]);
|
||||||
|
float beta_t = beta_bh[t];
|
||||||
|
float v_t = v_bh[t * v_dim + v_idx];
|
||||||
|
|
||||||
|
float kv_mem = 0.0f;
|
||||||
|
for (int j = 0; j < k_dim; j++) {
|
||||||
|
s[j] *= decay;
|
||||||
|
kv_mem = __fmaf_rn(s[j], k_buf[j], kv_mem);
|
||||||
|
}
|
||||||
|
|
||||||
|
float delta = (v_t - kv_mem) * beta_t;
|
||||||
|
|
||||||
|
for (int j = tid; j < k_dim; j += BV) {
|
||||||
|
q_buf[j] = q_bh[t * k_dim + j];
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
float y_t = 0.0f;
|
||||||
|
for (int j = 0; j < k_dim; j++) {
|
||||||
|
s[j] = __fmaf_rn(k_buf[j], delta, s[j]);
|
||||||
|
y_t = __fmaf_rn(s[j], q_buf[j], y_t);
|
||||||
|
}
|
||||||
|
|
||||||
|
out_bh[t * v_dim + v_idx] = y_t;
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int j = 0; j < k_dim; j++) {
|
||||||
|
state_bh[j * v_dim + v_idx] = s[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" void gated_delta_rule_recurrence(const float *q, const float *k,
|
||||||
|
const float *v, const float *g,
|
||||||
|
const float *beta, float *state,
|
||||||
|
float *output, int bh, int seq_len,
|
||||||
|
int k_dim, int v_dim,
|
||||||
|
int64_t stream) {
|
||||||
|
|
||||||
|
const cudaStream_t custream = (cudaStream_t)stream;
|
||||||
|
|
||||||
|
if (k_dim == 128) {
|
||||||
|
// Fast path for Qwen3-Next (k_dim=128)
|
||||||
|
constexpr int BK = 128;
|
||||||
|
constexpr int BV = 64;
|
||||||
|
dim3 grid((v_dim + BV - 1) / BV, bh);
|
||||||
|
dim3 block(BV);
|
||||||
|
gated_delta_rule_recurrence_kernel_tiled<BK, BV>
|
||||||
|
<<<grid, block, 0, custream>>>(q, k, v, g, beta, state, output, seq_len,
|
||||||
|
v_dim);
|
||||||
|
} else if (k_dim == 64) {
|
||||||
|
// Fast path for models with k_dim=64
|
||||||
|
constexpr int BK = 64;
|
||||||
|
constexpr int BV = 64;
|
||||||
|
dim3 grid((v_dim + BV - 1) / BV, bh);
|
||||||
|
dim3 block(BV);
|
||||||
|
gated_delta_rule_recurrence_kernel_tiled<BK, BV>
|
||||||
|
<<<grid, block, 0, custream>>>(q, k, v, g, beta, state, output, seq_len,
|
||||||
|
v_dim);
|
||||||
|
} else {
|
||||||
|
// Fallback for other k_dim values (runtime loop, still V-tiled)
|
||||||
|
constexpr int BV = 64;
|
||||||
|
constexpr int MAX_K = 256;
|
||||||
|
dim3 grid((v_dim + BV - 1) / BV, bh);
|
||||||
|
dim3 block(BV);
|
||||||
|
size_t smem = 2 * k_dim * sizeof(float);
|
||||||
|
gated_delta_rule_recurrence_kernel_fallback<BV, MAX_K>
|
||||||
|
<<<grid, block, smem, custream>>>(q, k, v, g, beta, state, output,
|
||||||
|
seq_len, k_dim, v_dim);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Kernel 1b: chunked_gated_delta_rule_recurrence (prefill optimization)
|
||||||
|
//
|
||||||
|
// Processes prefill tokens in BT-token chunks instead of one at a time.
|
||||||
|
// Within each chunk: parallel prefix sum of g, cooperative kk_dot computation,
|
||||||
|
// forward substitution (triangular solve), output computation, and state
|
||||||
|
// update.
|
||||||
|
//
|
||||||
|
// Same thread model as Kernel 1: one block per (v_tile, batch*head),
|
||||||
|
// one thread per V-column. Each thread owns BK registers of state.
|
||||||
|
//
|
||||||
|
// Shared memory holds:
|
||||||
|
// k_chunk[BT * BK] -- key vectors for current chunk
|
||||||
|
// kk_dot[BT * BT] -- dot(k[i], k[j]) lower-triangular matrix
|
||||||
|
// gcum[BT] -- cumulative sum of g within chunk
|
||||||
|
// beta_s[BT] -- beta values for chunk
|
||||||
|
// q_buf[BK] -- q vector (loaded one row at a time)
|
||||||
|
//
|
||||||
|
// q,k: [BH, S, K] v: [BH, S, V] g,beta: [BH, S]
|
||||||
|
// state: [BH, K, V] (in/out) output: [BH, S, V]
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
template <int BT, int BK, int BV>
|
||||||
|
__global__ void
|
||||||
|
chunked_gated_delta_rule_kernel(const float *__restrict__ q, // [BH, S, K]
|
||||||
|
const float *__restrict__ k, // [BH, S, K]
|
||||||
|
const float *__restrict__ v, // [BH, S, V]
|
||||||
|
const float *__restrict__ g, // [BH, S]
|
||||||
|
const float *__restrict__ beta, // [BH, S]
|
||||||
|
float *__restrict__ state, // [BH, K, V]
|
||||||
|
float *__restrict__ output, // [BH, S, V]
|
||||||
|
int seq_len, int v_dim) {
|
||||||
|
|
||||||
|
const int v_tile = blockIdx.x;
|
||||||
|
const int bh = blockIdx.y;
|
||||||
|
const int tid = threadIdx.x;
|
||||||
|
const int v_idx = v_tile * BV + tid;
|
||||||
|
|
||||||
|
if (v_idx >= v_dim)
|
||||||
|
return;
|
||||||
|
|
||||||
|
const int num_chunks = (seq_len + BT - 1) / BT;
|
||||||
|
|
||||||
|
// Pointers for this (batch, head)
|
||||||
|
const float *q_bh = q + bh * seq_len * BK;
|
||||||
|
const float *k_bh = k + bh * seq_len * BK;
|
||||||
|
const float *v_bh = v + bh * seq_len * v_dim;
|
||||||
|
const float *g_bh = g + bh * seq_len;
|
||||||
|
const float *beta_bh = beta + bh * seq_len;
|
||||||
|
float *state_bh = state + bh * BK * v_dim;
|
||||||
|
float *out_bh = output + bh * seq_len * v_dim;
|
||||||
|
|
||||||
|
// Dynamic shared memory layout
|
||||||
|
extern __shared__ float smem[];
|
||||||
|
float *k_chunk = smem; // [BT * BK]
|
||||||
|
float *kk_dot = smem + BT * BK; // [BT * BT]
|
||||||
|
float *gcum = smem + BT * BK + BT * BT; // [BT]
|
||||||
|
float *beta_s = gcum + BT; // [BT]
|
||||||
|
float *q_buf = beta_s + BT; // [BK]
|
||||||
|
|
||||||
|
// Load state column into registers
|
||||||
|
float s[BK];
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < BK; j++) {
|
||||||
|
s[j] = state_bh[j * v_dim + v_idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Per-thread register array for corrected deltas
|
||||||
|
float delta[BT];
|
||||||
|
|
||||||
|
for (int c = 0; c < num_chunks; c++) {
|
||||||
|
const int chunk_start = c * BT;
|
||||||
|
const int chunk_len = min(BT, seq_len - chunk_start);
|
||||||
|
|
||||||
|
// === Phase 1: Cooperative load of k, beta, g into shared memory ===
|
||||||
|
for (int t = 0; t < chunk_len; t++) {
|
||||||
|
for (int j = tid; j < BK; j += BV) {
|
||||||
|
k_chunk[t * BK + j] = k_bh[(chunk_start + t) * BK + j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (tid < chunk_len) {
|
||||||
|
beta_s[tid] = beta_bh[chunk_start + tid];
|
||||||
|
gcum[tid] = g_bh[chunk_start + tid];
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// === Phase 1b: Parallel prefix sum of g (Hillis-Steele) ===
|
||||||
|
for (int stride = 1; stride < BT; stride <<= 1) {
|
||||||
|
float prev = 0.0f;
|
||||||
|
if (tid < chunk_len && (int)tid >= stride)
|
||||||
|
prev = gcum[tid - stride];
|
||||||
|
__syncthreads();
|
||||||
|
if (tid < chunk_len && (int)tid >= stride)
|
||||||
|
gcum[tid] += prev;
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
// === Phase 2: Compute kk_dot[i][j] = dot(k[i], k[j]) for j < i ===
|
||||||
|
// Only lower-triangular entries needed (strictly lower)
|
||||||
|
for (int idx = tid; idx < chunk_len * chunk_len; idx += BV) {
|
||||||
|
int i = idx / chunk_len;
|
||||||
|
int j = idx % chunk_len;
|
||||||
|
if (j < i) {
|
||||||
|
float dot = 0.0f;
|
||||||
|
for (int d = 0; d < BK; d++) {
|
||||||
|
dot = __fmaf_rn(k_chunk[i * BK + d], k_chunk[j * BK + d], dot);
|
||||||
|
}
|
||||||
|
kk_dot[i * BT + j] = dot;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// === Phase 3: Forward substitution (per V-column, in registers) ===
|
||||||
|
// Computes corrected delta values via triangular solve
|
||||||
|
for (int i = 0; i < chunk_len; i++) {
|
||||||
|
float v_i = v_bh[(chunk_start + i) * v_dim + v_idx];
|
||||||
|
float decay_i = expf(gcum[i]);
|
||||||
|
float beta_i = beta_s[i];
|
||||||
|
|
||||||
|
// Inter-chunk contribution: state @ k[i] with decay
|
||||||
|
float kv_mem = 0.0f;
|
||||||
|
#pragma unroll
|
||||||
|
for (int d = 0; d < BK; d++) {
|
||||||
|
kv_mem = __fmaf_rn(s[d] * decay_i, k_chunk[i * BK + d], kv_mem);
|
||||||
|
}
|
||||||
|
|
||||||
|
float rhs = beta_i * (v_i - kv_mem);
|
||||||
|
|
||||||
|
// Subtract lower-triangular contributions (intra-chunk)
|
||||||
|
for (int j = 0; j < i; j++) {
|
||||||
|
float a_ij = beta_i * kk_dot[i * BT + j] * expf(gcum[i] - gcum[j]);
|
||||||
|
rhs -= a_ij * delta[j];
|
||||||
|
}
|
||||||
|
delta[i] = rhs;
|
||||||
|
}
|
||||||
|
|
||||||
|
// === Phase 4: Output computation (per V-column) ===
|
||||||
|
for (int i = 0; i < chunk_len; i++) {
|
||||||
|
// Cooperatively load q[i] into shared
|
||||||
|
for (int j = tid; j < BK; j += BV) {
|
||||||
|
q_buf[j] = q_bh[(chunk_start + i) * BK + j];
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
float decay_i = expf(gcum[i]);
|
||||||
|
|
||||||
|
// Inter-chunk contribution: q[i] @ (state * decay)
|
||||||
|
float o_val = 0.0f;
|
||||||
|
#pragma unroll
|
||||||
|
for (int d = 0; d < BK; d++) {
|
||||||
|
o_val = __fmaf_rn(q_buf[d], s[d] * decay_i, o_val);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Intra-chunk contribution: sum_{j<=i} dot(q[i], k[j]) * delta[j] *
|
||||||
|
// exp(gcum[i] - gcum[j])
|
||||||
|
for (int j = 0; j <= i; j++) {
|
||||||
|
float qk_dot = 0.0f;
|
||||||
|
for (int d = 0; d < BK; d++) {
|
||||||
|
qk_dot = __fmaf_rn(q_buf[d], k_chunk[j * BK + d], qk_dot);
|
||||||
|
}
|
||||||
|
o_val += qk_dot * delta[j] * expf(gcum[i] - gcum[j]);
|
||||||
|
}
|
||||||
|
|
||||||
|
out_bh[(chunk_start + i) * v_dim + v_idx] = o_val;
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
// === Phase 5: State update for next chunk ===
|
||||||
|
float g_total = gcum[chunk_len - 1];
|
||||||
|
#pragma unroll
|
||||||
|
for (int d = 0; d < BK; d++) {
|
||||||
|
float s_new = s[d] * expf(g_total);
|
||||||
|
for (int t = 0; t < chunk_len; t++) {
|
||||||
|
s_new += k_chunk[t * BK + d] * delta[t] * expf(g_total - gcum[t]);
|
||||||
|
}
|
||||||
|
s[d] = s_new;
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write final state back
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < BK; j++) {
|
||||||
|
state_bh[j * v_dim + v_idx] = s[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" void chunked_gated_delta_rule_recurrence(
|
||||||
|
const float *q, const float *k, const float *v, const float *g,
|
||||||
|
const float *beta, float *state, float *output, int bh, int seq_len,
|
||||||
|
int k_dim, int v_dim, int64_t stream) {
|
||||||
|
|
||||||
|
const cudaStream_t custream = (cudaStream_t)stream;
|
||||||
|
|
||||||
|
if (k_dim == 128) {
|
||||||
|
constexpr int BT = 64;
|
||||||
|
constexpr int BK = 128;
|
||||||
|
constexpr int BV = 64;
|
||||||
|
// Shared memory: BT*BK + BT*BT + BT + BT + BK floats
|
||||||
|
size_t smem = (BT * BK + BT * BT + 2 * BT + BK) * sizeof(float);
|
||||||
|
|
||||||
|
// Request extended shared memory
|
||||||
|
auto kernel = chunked_gated_delta_rule_kernel<BT, BK, BV>;
|
||||||
|
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||||
|
smem);
|
||||||
|
|
||||||
|
dim3 grid((v_dim + BV - 1) / BV, bh);
|
||||||
|
dim3 block(BV);
|
||||||
|
kernel<<<grid, block, smem, custream>>>(q, k, v, g, beta, state, output,
|
||||||
|
seq_len, v_dim);
|
||||||
|
} else if (k_dim == 64) {
|
||||||
|
constexpr int BT = 64;
|
||||||
|
constexpr int BK = 64;
|
||||||
|
constexpr int BV = 64;
|
||||||
|
size_t smem = (BT * BK + BT * BT + 2 * BT + BK) * sizeof(float);
|
||||||
|
|
||||||
|
auto kernel = chunked_gated_delta_rule_kernel<BT, BK, BV>;
|
||||||
|
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||||
|
smem);
|
||||||
|
|
||||||
|
dim3 grid((v_dim + BV - 1) / BV, bh);
|
||||||
|
dim3 block(BV);
|
||||||
|
kernel<<<grid, block, smem, custream>>>(q, k, v, g, beta, state, output,
|
||||||
|
seq_len, v_dim);
|
||||||
|
} else {
|
||||||
|
// Fallback: use the sequential kernel for unsupported k_dim
|
||||||
|
gated_delta_rule_recurrence(q, k, v, g, beta, state, output, bh, seq_len,
|
||||||
|
k_dim, v_dim, stream);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Kernel 2a: causal_conv1d_update (decode path, single step)
|
||||||
|
//
|
||||||
|
// Each thread handles one channel: shift conv_state left by 1,
|
||||||
|
// insert new value, dot product with weight, apply SiLU.
|
||||||
|
//
|
||||||
|
// x: [B, conv_dim, 1] weight: [conv_dim, kernel_size]
|
||||||
|
// conv_state: [B, conv_dim, kernel_size] (in/out)
|
||||||
|
// output: [B, conv_dim, 1]
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void causal_conv1d_update_kernel(
|
||||||
|
const T *__restrict__ x, // [B, conv_dim, 1]
|
||||||
|
const T *__restrict__ weight, // [conv_dim, kernel_size]
|
||||||
|
T *__restrict__ conv_state, // [B, conv_dim, kernel_size]
|
||||||
|
T *__restrict__ output, // [B, conv_dim, 1]
|
||||||
|
int batch_size, int conv_dim, int kernel_size) {
|
||||||
|
|
||||||
|
const int ch = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
const int b = blockIdx.y;
|
||||||
|
|
||||||
|
if (ch >= conv_dim || b >= batch_size)
|
||||||
|
return;
|
||||||
|
|
||||||
|
// Pointer to this batch/channel's conv state
|
||||||
|
T *cs = conv_state + (b * conv_dim + ch) * kernel_size;
|
||||||
|
const T *w = weight + ch * kernel_size;
|
||||||
|
|
||||||
|
// Shift state left by 1
|
||||||
|
for (int i = 0; i < kernel_size - 1; i++) {
|
||||||
|
cs[i] = cs[i + 1];
|
||||||
|
}
|
||||||
|
// Insert new value
|
||||||
|
cs[kernel_size - 1] = x[b * conv_dim + ch];
|
||||||
|
|
||||||
|
// Dot product with weight
|
||||||
|
float acc = 0.0f;
|
||||||
|
for (int i = 0; i < kernel_size; i++) {
|
||||||
|
acc += (float)cs[i] * (float)w[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// SiLU activation: x * sigmoid(x)
|
||||||
|
float sig = 1.0f / (1.0f + expf(-acc));
|
||||||
|
float result = acc * sig;
|
||||||
|
|
||||||
|
output[b * conv_dim + ch] = (T)result;
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" void causal_conv1d_update(const void *x, const void *weight,
|
||||||
|
void *conv_state, void *output,
|
||||||
|
int batch_size, int conv_dim,
|
||||||
|
int kernel_size, int dtype,
|
||||||
|
int64_t stream) {
|
||||||
|
const cudaStream_t custream = (cudaStream_t)stream;
|
||||||
|
dim3 block(256);
|
||||||
|
dim3 grid((conv_dim + 255) / 256, batch_size);
|
||||||
|
|
||||||
|
if (dtype == 0) {
|
||||||
|
// f16
|
||||||
|
causal_conv1d_update_kernel<__half><<<grid, block, 0, custream>>>(
|
||||||
|
(const __half *)x, (const __half *)weight, (__half *)conv_state,
|
||||||
|
(__half *)output, batch_size, conv_dim, kernel_size);
|
||||||
|
} else {
|
||||||
|
// bf16
|
||||||
|
causal_conv1d_update_kernel<__nv_bfloat16><<<grid, block, 0, custream>>>(
|
||||||
|
(const __nv_bfloat16 *)x, (const __nv_bfloat16 *)weight,
|
||||||
|
(__nv_bfloat16 *)conv_state, (__nv_bfloat16 *)output, batch_size,
|
||||||
|
conv_dim, kernel_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Kernel 2b: causal_conv1d_full (prefill path)
|
||||||
|
//
|
||||||
|
// Each thread handles one (channel, position): causal window with
|
||||||
|
// zero-padding, dot product with weight, SiLU.
|
||||||
|
// A second pass writes the conv_state from the last kernel_size positions.
|
||||||
|
//
|
||||||
|
// x: [B, conv_dim, S] weight: [conv_dim, kernel_size]
|
||||||
|
// conv_state_out: [B, conv_dim, kernel_size] output: [B, conv_dim, S]
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void causal_conv1d_full_kernel(
|
||||||
|
const T *__restrict__ x, // [B, conv_dim, S]
|
||||||
|
const T *__restrict__ weight, // [conv_dim, kernel_size]
|
||||||
|
T *__restrict__ output, // [B, conv_dim, S]
|
||||||
|
int batch_size, int conv_dim, int seq_len, int kernel_size) {
|
||||||
|
|
||||||
|
const int ch = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
const int pos = blockIdx.y;
|
||||||
|
const int b = blockIdx.z;
|
||||||
|
|
||||||
|
if (ch >= conv_dim || pos >= seq_len || b >= batch_size)
|
||||||
|
return;
|
||||||
|
|
||||||
|
const T *x_bch = x + (b * conv_dim + ch) * seq_len;
|
||||||
|
const T *w = weight + ch * kernel_size;
|
||||||
|
|
||||||
|
// Causal convolution: sum over kernel_size window ending at pos
|
||||||
|
float acc = 0.0f;
|
||||||
|
for (int i = 0; i < kernel_size; i++) {
|
||||||
|
int src_pos = pos - (kernel_size - 1) + i;
|
||||||
|
float x_val = (src_pos >= 0) ? (float)x_bch[src_pos] : 0.0f;
|
||||||
|
acc += x_val * (float)w[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// SiLU
|
||||||
|
float sig = 1.0f / (1.0f + expf(-acc));
|
||||||
|
float result = acc * sig;
|
||||||
|
|
||||||
|
output[(b * conv_dim + ch) * seq_len + pos] = (T)result;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void save_conv_state_kernel(
|
||||||
|
const T *__restrict__ x, // [B, conv_dim, S]
|
||||||
|
T *__restrict__ conv_state_out, // [B, conv_dim, kernel_size]
|
||||||
|
int batch_size, int conv_dim, int seq_len, int kernel_size) {
|
||||||
|
|
||||||
|
const int ch = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
const int b = blockIdx.y;
|
||||||
|
|
||||||
|
if (ch >= conv_dim || b >= batch_size)
|
||||||
|
return;
|
||||||
|
|
||||||
|
const T *x_bch = x + (b * conv_dim + ch) * seq_len;
|
||||||
|
T *cs = conv_state_out + (b * conv_dim + ch) * kernel_size;
|
||||||
|
|
||||||
|
// Save last kernel_size positions (zero-pad if seq_len < kernel_size)
|
||||||
|
int pad = kernel_size - seq_len;
|
||||||
|
for (int i = 0; i < kernel_size; i++) {
|
||||||
|
if (i < pad) {
|
||||||
|
cs[i] = (T)0.0f;
|
||||||
|
} else {
|
||||||
|
cs[i] = x_bch[seq_len - kernel_size + i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" void causal_conv1d_full(const void *x, const void *weight,
|
||||||
|
void *conv_state_out, void *output,
|
||||||
|
int batch_size, int conv_dim, int seq_len,
|
||||||
|
int kernel_size, int dtype, int64_t stream) {
|
||||||
|
const cudaStream_t custream = (cudaStream_t)stream;
|
||||||
|
|
||||||
|
// Main convolution kernel
|
||||||
|
dim3 block(256);
|
||||||
|
dim3 grid((conv_dim + 255) / 256, seq_len, batch_size);
|
||||||
|
|
||||||
|
if (dtype == 0) {
|
||||||
|
causal_conv1d_full_kernel<__half><<<grid, block, 0, custream>>>(
|
||||||
|
(const __half *)x, (const __half *)weight, (__half *)output, batch_size,
|
||||||
|
conv_dim, seq_len, kernel_size);
|
||||||
|
// Save conv state
|
||||||
|
dim3 grid2((conv_dim + 255) / 256, batch_size);
|
||||||
|
save_conv_state_kernel<__half><<<grid2, block, 0, custream>>>(
|
||||||
|
(const __half *)x, (__half *)conv_state_out, batch_size, conv_dim,
|
||||||
|
seq_len, kernel_size);
|
||||||
|
} else {
|
||||||
|
causal_conv1d_full_kernel<__nv_bfloat16><<<grid, block, 0, custream>>>(
|
||||||
|
(const __nv_bfloat16 *)x, (const __nv_bfloat16 *)weight,
|
||||||
|
(__nv_bfloat16 *)output, batch_size, conv_dim, seq_len, kernel_size);
|
||||||
|
dim3 grid2((conv_dim + 255) / 256, batch_size);
|
||||||
|
save_conv_state_kernel<__nv_bfloat16><<<grid2, block, 0, custream>>>(
|
||||||
|
(const __nv_bfloat16 *)x, (__nv_bfloat16 *)conv_state_out, batch_size,
|
||||||
|
conv_dim, seq_len, kernel_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Kernel 3: fused_gdn_gating
|
||||||
|
//
|
||||||
|
// Fuses: beta = sigmoid(b), g = -exp(a_log) * softplus(a + dt_bias)
|
||||||
|
// a_log and dt_bias are per-head (broadcast over batch*seq).
|
||||||
|
//
|
||||||
|
// b, a: [total] a_log, dt_bias: [num_heads]
|
||||||
|
// beta_out, g_out: [total]
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void
|
||||||
|
fused_gdn_gating_kernel(const T *__restrict__ b, // [total]
|
||||||
|
const T *__restrict__ a, // [total]
|
||||||
|
const float *__restrict__ a_log, // [num_heads]
|
||||||
|
const float *__restrict__ dt_bias, // [num_heads]
|
||||||
|
T *__restrict__ beta_out, // [total]
|
||||||
|
T *__restrict__ g_out, // [total]
|
||||||
|
int total_elements, int num_heads) {
|
||||||
|
|
||||||
|
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
if (idx >= total_elements)
|
||||||
|
return;
|
||||||
|
|
||||||
|
// Head index: elements are laid out as [..., num_heads]
|
||||||
|
int head_idx = idx % num_heads;
|
||||||
|
|
||||||
|
// beta = sigmoid(b)
|
||||||
|
float b_val = (float)b[idx];
|
||||||
|
float beta = 1.0f / (1.0f + expf(-b_val));
|
||||||
|
|
||||||
|
// g = -exp(a_log) * softplus(a + dt_bias)
|
||||||
|
float a_val = (float)a[idx];
|
||||||
|
float a_log_val = a_log[head_idx];
|
||||||
|
float dt_bias_val = dt_bias[head_idx];
|
||||||
|
|
||||||
|
float sp_input = a_val + dt_bias_val;
|
||||||
|
float softplus_val = logf(1.0f + expf(sp_input));
|
||||||
|
float g_val = -expf(a_log_val) * softplus_val;
|
||||||
|
|
||||||
|
beta_out[idx] = (T)beta;
|
||||||
|
g_out[idx] = (T)g_val;
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" void fused_gdn_gating(const void *b, const void *a,
|
||||||
|
const float *a_log, const float *dt_bias,
|
||||||
|
void *beta_out, void *g_out,
|
||||||
|
int total_elements, int num_heads, int dtype,
|
||||||
|
int64_t stream) {
|
||||||
|
const cudaStream_t custream = (cudaStream_t)stream;
|
||||||
|
dim3 block(256);
|
||||||
|
dim3 grid((total_elements + 255) / 256);
|
||||||
|
|
||||||
|
if (dtype == 0) {
|
||||||
|
fused_gdn_gating_kernel<__half><<<grid, block, 0, custream>>>(
|
||||||
|
(const __half *)b, (const __half *)a, a_log, dt_bias,
|
||||||
|
(__half *)beta_out, (__half *)g_out, total_elements, num_heads);
|
||||||
|
} else {
|
||||||
|
fused_gdn_gating_kernel<__nv_bfloat16><<<grid, block, 0, custream>>>(
|
||||||
|
(const __nv_bfloat16 *)b, (const __nv_bfloat16 *)a, a_log, dt_bias,
|
||||||
|
(__nv_bfloat16 *)beta_out, (__nv_bfloat16 *)g_out, total_elements,
|
||||||
|
num_heads);
|
||||||
|
}
|
||||||
|
}
|
||||||
486
crates/neuron/src/cuda/gdn.rs
Normal file
486
crates/neuron/src/cuda/gdn.rs
Normal file
@@ -0,0 +1,486 @@
|
|||||||
|
//! Rust wrappers around the Gated DeltaNet CUDA kernels in `gdn.cu`.
|
||||||
|
//!
|
||||||
|
//! Ported verbatim from `EricLBuehler/mistral.rs` under MIT terms.
|
||||||
|
//! Upstream path: `mistralrs-core/src/cuda/gdn.rs`. The only edits in
|
||||||
|
//! this file are this header comment — the FFI path module name is
|
||||||
|
//! `crate::cuda::ffi`, identical to upstream's layout.
|
||||||
|
|
||||||
|
#![allow(clippy::cast_possible_truncation)]
|
||||||
|
|
||||||
|
use candle_core::{Result, Tensor};
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
use candle_core::DType;
|
||||||
|
|
||||||
|
/// CUDA-accelerated gated delta rule recurrence.
|
||||||
|
///
|
||||||
|
/// Inputs (all contiguous, f32):
|
||||||
|
/// q, k: [BH, S, K] v: [BH, S, V] g, beta: [BH, S]
|
||||||
|
/// state: [BH, K, V] (mutated in place)
|
||||||
|
///
|
||||||
|
/// Returns: output [BH, S, V]
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub fn gated_delta_rule_recurrence_cuda(
|
||||||
|
q: &Tensor,
|
||||||
|
k: &Tensor,
|
||||||
|
v: &Tensor,
|
||||||
|
g: &Tensor,
|
||||||
|
beta: &Tensor,
|
||||||
|
state: &mut Tensor,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
use candle::cuda_backend::cudarc::driver::DevicePtr;
|
||||||
|
use candle_core as candle;
|
||||||
|
|
||||||
|
let (bh, seq_len, k_dim) = q.dims3()?;
|
||||||
|
let v_dim = v.dim(2)?;
|
||||||
|
|
||||||
|
let dev = q.device().as_cuda_device()?;
|
||||||
|
|
||||||
|
let (q_s, q_l) = q.storage_and_layout();
|
||||||
|
let q_s = match &*q_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("q must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let q_offset = q_l.start_offset();
|
||||||
|
|
||||||
|
let (k_s, k_l) = k.storage_and_layout();
|
||||||
|
let k_s = match &*k_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("k must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let k_offset = k_l.start_offset();
|
||||||
|
|
||||||
|
let (v_s, v_l) = v.storage_and_layout();
|
||||||
|
let v_s = match &*v_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("v must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let v_offset = v_l.start_offset();
|
||||||
|
|
||||||
|
let (g_s, g_l) = g.storage_and_layout();
|
||||||
|
let g_s = match &*g_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("g must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let g_offset = g_l.start_offset();
|
||||||
|
|
||||||
|
let (beta_s, beta_l) = beta.storage_and_layout();
|
||||||
|
let beta_s = match &*beta_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("beta must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let beta_offset = beta_l.start_offset();
|
||||||
|
|
||||||
|
let (state_s, state_l) = state.storage_and_layout();
|
||||||
|
let state_s = match &*state_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("state must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let state_offset = state_l.start_offset();
|
||||||
|
|
||||||
|
let output_buf = unsafe { dev.alloc::<f32>(bh * seq_len * v_dim) }?;
|
||||||
|
|
||||||
|
let stream = dev.cuda_stream().cu_stream() as i64;
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
crate::cuda::ffi::gated_delta_rule_recurrence(
|
||||||
|
q_s.slice(q_offset..).device_ptr(q_s.stream()).0 as *const f32,
|
||||||
|
k_s.slice(k_offset..).device_ptr(k_s.stream()).0 as *const f32,
|
||||||
|
v_s.slice(v_offset..).device_ptr(v_s.stream()).0 as *const f32,
|
||||||
|
g_s.slice(g_offset..).device_ptr(g_s.stream()).0 as *const f32,
|
||||||
|
beta_s.slice(beta_offset..).device_ptr(beta_s.stream()).0 as *const f32,
|
||||||
|
state_s.slice(state_offset..).device_ptr(state_s.stream()).0 as *mut f32,
|
||||||
|
output_buf.device_ptr(output_buf.stream()).0 as *mut f32,
|
||||||
|
bh as i32,
|
||||||
|
seq_len as i32,
|
||||||
|
k_dim as i32,
|
||||||
|
v_dim as i32,
|
||||||
|
stream,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// The kernel wrote state in-place via the raw pointer; rewrap
|
||||||
|
// (state tensor's underlying CudaSlice was modified directly)
|
||||||
|
|
||||||
|
let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone());
|
||||||
|
Ok(Tensor::from((
|
||||||
|
candle::Storage::Cuda(output_storage),
|
||||||
|
(bh, seq_len, v_dim),
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
#[allow(unused)]
|
||||||
|
pub fn gated_delta_rule_recurrence_cuda(
|
||||||
|
_q: &Tensor,
|
||||||
|
_k: &Tensor,
|
||||||
|
_v: &Tensor,
|
||||||
|
_g: &Tensor,
|
||||||
|
_beta: &Tensor,
|
||||||
|
_state: &mut Tensor,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
candle_core::bail!("gated_delta_rule_recurrence_cuda requires the cuda feature")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// CUDA-accelerated chunked gated delta rule recurrence (prefill optimization).
|
||||||
|
///
|
||||||
|
/// Processes prefill tokens in 64-token chunks instead of one at a time.
|
||||||
|
/// Same interface as `gated_delta_rule_recurrence_cuda`.
|
||||||
|
///
|
||||||
|
/// Inputs (all contiguous, f32):
|
||||||
|
/// q, k: [BH, S, K] v: [BH, S, V] g, beta: [BH, S]
|
||||||
|
/// state: [BH, K, V] (mutated in place)
|
||||||
|
///
|
||||||
|
/// Returns: output [BH, S, V]
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub fn chunked_gated_delta_rule_recurrence_cuda(
|
||||||
|
q: &Tensor,
|
||||||
|
k: &Tensor,
|
||||||
|
v: &Tensor,
|
||||||
|
g: &Tensor,
|
||||||
|
beta: &Tensor,
|
||||||
|
state: &mut Tensor,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
use candle::cuda_backend::cudarc::driver::DevicePtr;
|
||||||
|
use candle_core as candle;
|
||||||
|
|
||||||
|
let (bh, seq_len, k_dim) = q.dims3()?;
|
||||||
|
let v_dim = v.dim(2)?;
|
||||||
|
|
||||||
|
let dev = q.device().as_cuda_device()?;
|
||||||
|
|
||||||
|
let (q_s, q_l) = q.storage_and_layout();
|
||||||
|
let q_s = match &*q_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("q must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let q_offset = q_l.start_offset();
|
||||||
|
|
||||||
|
let (k_s, k_l) = k.storage_and_layout();
|
||||||
|
let k_s = match &*k_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("k must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let k_offset = k_l.start_offset();
|
||||||
|
|
||||||
|
let (v_s, v_l) = v.storage_and_layout();
|
||||||
|
let v_s = match &*v_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("v must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let v_offset = v_l.start_offset();
|
||||||
|
|
||||||
|
let (g_s, g_l) = g.storage_and_layout();
|
||||||
|
let g_s = match &*g_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("g must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let g_offset = g_l.start_offset();
|
||||||
|
|
||||||
|
let (beta_s, beta_l) = beta.storage_and_layout();
|
||||||
|
let beta_s = match &*beta_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("beta must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let beta_offset = beta_l.start_offset();
|
||||||
|
|
||||||
|
let (state_s, state_l) = state.storage_and_layout();
|
||||||
|
let state_s = match &*state_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("state must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let state_offset = state_l.start_offset();
|
||||||
|
|
||||||
|
let output_buf = unsafe { dev.alloc::<f32>(bh * seq_len * v_dim) }?;
|
||||||
|
|
||||||
|
let stream = dev.cuda_stream().cu_stream() as i64;
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
crate::cuda::ffi::chunked_gated_delta_rule_recurrence(
|
||||||
|
q_s.slice(q_offset..).device_ptr(q_s.stream()).0 as *const f32,
|
||||||
|
k_s.slice(k_offset..).device_ptr(k_s.stream()).0 as *const f32,
|
||||||
|
v_s.slice(v_offset..).device_ptr(v_s.stream()).0 as *const f32,
|
||||||
|
g_s.slice(g_offset..).device_ptr(g_s.stream()).0 as *const f32,
|
||||||
|
beta_s.slice(beta_offset..).device_ptr(beta_s.stream()).0 as *const f32,
|
||||||
|
state_s.slice(state_offset..).device_ptr(state_s.stream()).0 as *mut f32,
|
||||||
|
output_buf.device_ptr(output_buf.stream()).0 as *mut f32,
|
||||||
|
bh as i32,
|
||||||
|
seq_len as i32,
|
||||||
|
k_dim as i32,
|
||||||
|
v_dim as i32,
|
||||||
|
stream,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone());
|
||||||
|
Ok(Tensor::from((
|
||||||
|
candle::Storage::Cuda(output_storage),
|
||||||
|
(bh, seq_len, v_dim),
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
#[allow(unused)]
|
||||||
|
pub fn chunked_gated_delta_rule_recurrence_cuda(
|
||||||
|
_q: &Tensor,
|
||||||
|
_k: &Tensor,
|
||||||
|
_v: &Tensor,
|
||||||
|
_g: &Tensor,
|
||||||
|
_beta: &Tensor,
|
||||||
|
_state: &mut Tensor,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
candle_core::bail!("chunked_gated_delta_rule_recurrence_cuda requires the cuda feature")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// CUDA-accelerated causal conv1d (both update and full paths).
|
||||||
|
///
|
||||||
|
/// For update (is_update=true):
|
||||||
|
/// x: [B, conv_dim, 1] weight: [conv_dim, kernel_size]
|
||||||
|
/// conv_state: [B, conv_dim, kernel_size] (mutated in place for update)
|
||||||
|
/// Returns: (output [B, conv_dim, 1], updated conv_state)
|
||||||
|
///
|
||||||
|
/// For full (is_update=false):
|
||||||
|
/// x: [B, conv_dim, S] weight: [conv_dim, kernel_size]
|
||||||
|
/// Returns: (output [B, conv_dim, S], new conv_state [B, conv_dim, kernel_size])
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub fn causal_conv1d_cuda(
|
||||||
|
x: &Tensor,
|
||||||
|
weight: &Tensor,
|
||||||
|
conv_state: &Tensor,
|
||||||
|
kernel_size: usize,
|
||||||
|
is_update: bool,
|
||||||
|
) -> Result<(Tensor, Tensor)> {
|
||||||
|
use candle::cuda_backend::cudarc::driver::DevicePtr;
|
||||||
|
use candle_core as candle;
|
||||||
|
use core::ffi::c_void;
|
||||||
|
fn cuda_fwd<
|
||||||
|
T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
|
||||||
|
>(
|
||||||
|
x: &Tensor,
|
||||||
|
weight: &Tensor,
|
||||||
|
conv_state: &Tensor,
|
||||||
|
kernel_size: usize,
|
||||||
|
is_update: bool,
|
||||||
|
dtype_code: i32,
|
||||||
|
) -> Result<(Tensor, Tensor)> {
|
||||||
|
let dev = x.device().as_cuda_device()?;
|
||||||
|
let (batch_size, conv_dim, seq_len) = x.dims3()?;
|
||||||
|
|
||||||
|
let (x_s, x_l) = x.storage_and_layout();
|
||||||
|
let x_s = match &*x_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
|
||||||
|
_ => candle::bail!("x must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let x_offset = x_l.start_offset();
|
||||||
|
|
||||||
|
let (w_s, w_l) = weight.storage_and_layout();
|
||||||
|
let w_s = match &*w_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
|
||||||
|
_ => candle::bail!("weight must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let w_offset = w_l.start_offset();
|
||||||
|
|
||||||
|
let stream = dev.cuda_stream().cu_stream() as i64;
|
||||||
|
|
||||||
|
if is_update {
|
||||||
|
// Clone conv_state so the kernel can mutate it in place
|
||||||
|
let conv_state_new = conv_state.clone();
|
||||||
|
|
||||||
|
let output_buf = unsafe { dev.alloc::<T>(batch_size * conv_dim) }?;
|
||||||
|
|
||||||
|
// Scope the borrow of conv_state_new so we can move it later
|
||||||
|
{
|
||||||
|
let (cs_s, cs_l) = conv_state_new.storage_and_layout();
|
||||||
|
let cs_s = match &*cs_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
|
||||||
|
_ => candle::bail!("conv_state must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let cs_offset = cs_l.start_offset();
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
crate::cuda::ffi::causal_conv1d_update(
|
||||||
|
x_s.slice(x_offset..).device_ptr(x_s.stream()).0 as *const c_void,
|
||||||
|
w_s.slice(w_offset..).device_ptr(w_s.stream()).0 as *const c_void,
|
||||||
|
cs_s.slice(cs_offset..).device_ptr(cs_s.stream()).0 as *mut c_void,
|
||||||
|
output_buf.device_ptr(output_buf.stream()).0 as *mut c_void,
|
||||||
|
batch_size as i32,
|
||||||
|
conv_dim as i32,
|
||||||
|
kernel_size as i32,
|
||||||
|
dtype_code,
|
||||||
|
stream,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone());
|
||||||
|
let output = Tensor::from((
|
||||||
|
candle::Storage::Cuda(output_storage),
|
||||||
|
(batch_size, conv_dim, 1usize),
|
||||||
|
));
|
||||||
|
|
||||||
|
Ok((output, conv_state_new))
|
||||||
|
} else {
|
||||||
|
// Full path: allocate new conv_state and output
|
||||||
|
let output_buf = unsafe { dev.alloc::<T>(batch_size * conv_dim * seq_len) }?;
|
||||||
|
let cs_buf = unsafe { dev.alloc::<T>(batch_size * conv_dim * kernel_size) }?;
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
crate::cuda::ffi::causal_conv1d_full(
|
||||||
|
x_s.slice(x_offset..).device_ptr(x_s.stream()).0 as *const c_void,
|
||||||
|
w_s.slice(w_offset..).device_ptr(w_s.stream()).0 as *const c_void,
|
||||||
|
cs_buf.device_ptr(cs_buf.stream()).0 as *mut c_void,
|
||||||
|
output_buf.device_ptr(output_buf.stream()).0 as *mut c_void,
|
||||||
|
batch_size as i32,
|
||||||
|
conv_dim as i32,
|
||||||
|
seq_len as i32,
|
||||||
|
kernel_size as i32,
|
||||||
|
dtype_code,
|
||||||
|
stream,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone());
|
||||||
|
let output = Tensor::from((
|
||||||
|
candle::Storage::Cuda(output_storage),
|
||||||
|
(batch_size, conv_dim, seq_len),
|
||||||
|
));
|
||||||
|
|
||||||
|
let cs_storage = candle::CudaStorage::wrap_cuda_slice(cs_buf, dev.clone());
|
||||||
|
let new_conv_state = Tensor::from((
|
||||||
|
candle::Storage::Cuda(cs_storage),
|
||||||
|
(batch_size, conv_dim, kernel_size),
|
||||||
|
));
|
||||||
|
|
||||||
|
Ok((output, new_conv_state))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
match x.dtype() {
|
||||||
|
DType::F16 => cuda_fwd::<half::f16>(x, weight, conv_state, kernel_size, is_update, 0),
|
||||||
|
DType::BF16 => cuda_fwd::<half::bf16>(x, weight, conv_state, kernel_size, is_update, 1),
|
||||||
|
other => candle_core::bail!("causal_conv1d_cuda only supports f16/bf16, got {:?}", other),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
#[allow(unused)]
|
||||||
|
pub fn causal_conv1d_cuda(
|
||||||
|
_x: &Tensor,
|
||||||
|
_weight: &Tensor,
|
||||||
|
_conv_state: &Tensor,
|
||||||
|
_kernel_size: usize,
|
||||||
|
_is_update: bool,
|
||||||
|
) -> Result<(Tensor, Tensor)> {
|
||||||
|
candle_core::bail!("causal_conv1d_cuda requires the cuda feature")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// CUDA-accelerated fused GDN gating computation.
|
||||||
|
///
|
||||||
|
/// Computes: beta = sigmoid(b), g = -exp(a_log) * softplus(a + dt_bias)
|
||||||
|
///
|
||||||
|
/// b, a: [total_elements] in f16/bf16
|
||||||
|
/// a_log, dt_bias: [num_heads] in f32
|
||||||
|
///
|
||||||
|
/// Returns: (beta, g) in original dtype
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub fn fused_gdn_gating_cuda(
|
||||||
|
b: &Tensor,
|
||||||
|
a: &Tensor,
|
||||||
|
a_log: &Tensor,
|
||||||
|
dt_bias: &Tensor,
|
||||||
|
) -> Result<(Tensor, Tensor)> {
|
||||||
|
use candle::cuda_backend::cudarc::driver::DevicePtr;
|
||||||
|
use candle_core as candle;
|
||||||
|
use core::ffi::c_void;
|
||||||
|
|
||||||
|
fn cuda_fwd<
|
||||||
|
T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
|
||||||
|
>(
|
||||||
|
b: &Tensor,
|
||||||
|
a: &Tensor,
|
||||||
|
a_log: &Tensor,
|
||||||
|
dt_bias: &Tensor,
|
||||||
|
dtype_code: i32,
|
||||||
|
) -> Result<(Tensor, Tensor)> {
|
||||||
|
let total_elements = b.elem_count();
|
||||||
|
let num_heads = a_log.elem_count();
|
||||||
|
let shape = b.shape().clone();
|
||||||
|
let dev = b.device().as_cuda_device()?;
|
||||||
|
|
||||||
|
let (b_s, b_l) = b.storage_and_layout();
|
||||||
|
let b_s = match &*b_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
|
||||||
|
_ => candle::bail!("b must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let b_offset = b_l.start_offset();
|
||||||
|
|
||||||
|
let (a_s, a_l) = a.storage_and_layout();
|
||||||
|
let a_s = match &*a_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
|
||||||
|
_ => candle::bail!("a must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let a_offset = a_l.start_offset();
|
||||||
|
|
||||||
|
let (alog_s, alog_l) = a_log.storage_and_layout();
|
||||||
|
let alog_s = match &*alog_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("a_log must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let alog_offset = alog_l.start_offset();
|
||||||
|
|
||||||
|
let (dtb_s, dtb_l) = dt_bias.storage_and_layout();
|
||||||
|
let dtb_s = match &*dtb_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("dt_bias must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let dtb_offset = dtb_l.start_offset();
|
||||||
|
|
||||||
|
let beta_buf = unsafe { dev.alloc::<T>(total_elements) }?;
|
||||||
|
let g_buf = unsafe { dev.alloc::<T>(total_elements) }?;
|
||||||
|
|
||||||
|
let stream = dev.cuda_stream().cu_stream() as i64;
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
crate::cuda::ffi::fused_gdn_gating(
|
||||||
|
b_s.slice(b_offset..).device_ptr(b_s.stream()).0 as *const c_void,
|
||||||
|
a_s.slice(a_offset..).device_ptr(a_s.stream()).0 as *const c_void,
|
||||||
|
alog_s.slice(alog_offset..).device_ptr(alog_s.stream()).0 as *const f32,
|
||||||
|
dtb_s.slice(dtb_offset..).device_ptr(dtb_s.stream()).0 as *const f32,
|
||||||
|
beta_buf.device_ptr(beta_buf.stream()).0 as *mut c_void,
|
||||||
|
g_buf.device_ptr(g_buf.stream()).0 as *mut c_void,
|
||||||
|
total_elements as i32,
|
||||||
|
num_heads as i32,
|
||||||
|
dtype_code,
|
||||||
|
stream,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let beta_storage = candle::CudaStorage::wrap_cuda_slice(beta_buf, dev.clone());
|
||||||
|
let beta = Tensor::from((candle::Storage::Cuda(beta_storage), shape.clone()));
|
||||||
|
|
||||||
|
let g_storage = candle::CudaStorage::wrap_cuda_slice(g_buf, dev.clone());
|
||||||
|
let g = Tensor::from((candle::Storage::Cuda(g_storage), shape));
|
||||||
|
|
||||||
|
Ok((beta, g))
|
||||||
|
}
|
||||||
|
|
||||||
|
match b.dtype() {
|
||||||
|
DType::F16 => cuda_fwd::<half::f16>(b, a, a_log, dt_bias, 0),
|
||||||
|
DType::BF16 => cuda_fwd::<half::bf16>(b, a, a_log, dt_bias, 1),
|
||||||
|
other => candle_core::bail!(
|
||||||
|
"fused_gdn_gating_cuda only supports f16/bf16, got {:?}",
|
||||||
|
other
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
#[allow(unused)]
|
||||||
|
pub fn fused_gdn_gating_cuda(
|
||||||
|
_b: &Tensor,
|
||||||
|
_a: &Tensor,
|
||||||
|
_a_log: &Tensor,
|
||||||
|
_dt_bias: &Tensor,
|
||||||
|
) -> Result<(Tensor, Tensor)> {
|
||||||
|
candle_core::bail!("fused_gdn_gating_cuda requires the cuda feature")
|
||||||
|
}
|
||||||
15
crates/neuron/src/cuda/mod.rs
Normal file
15
crates/neuron/src/cuda/mod.rs
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
//! CUDA kernels and their Rust wrappers.
|
||||||
|
//!
|
||||||
|
//! Currently scoped to what we need for Qwen3-Next (`qwen3_5`)
|
||||||
|
//! inference performance — the Gated DeltaNet kernels ported from
|
||||||
|
//! `EricLBuehler/mistral.rs` (MIT). Each kernel lives in a `.cu`
|
||||||
|
//! file alongside this module; `build.rs` compiles them all into a
|
||||||
|
//! static lib via `cudaforge` and links it under the `cuda` feature.
|
||||||
|
//!
|
||||||
|
//! When we absorb more upstream kernels (MoE GEMM, top-k, Mamba SSM,
|
||||||
|
//! etc.) they land here in their own `.cu` + `.rs` pairs.
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub mod ffi;
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub mod gdn;
|
||||||
23
crates/neuron/src/harness/arch/mod.rs
Normal file
23
crates/neuron/src/harness/arch/mod.rs
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
//! Custom architecture implementations.
|
||||||
|
//!
|
||||||
|
//! When candle-transformers ships a model family unchanged
|
||||||
|
//! (`models::llama`, `models::qwen3`, `models::qwen3_moe`, etc.), the
|
||||||
|
//! handler in `harness/candle.rs` just wraps the upstream type in a
|
||||||
|
//! `ModelArch` variant.
|
||||||
|
//!
|
||||||
|
//! When candle has nothing for the architecture and we have to write
|
||||||
|
//! it from scratch — Qwen3-Next / Qwen3.6 (`qwen3_5`) being the
|
||||||
|
//! motivating example — the implementation lands here, one file per
|
||||||
|
//! architecture.
|
||||||
|
//!
|
||||||
|
//! Each architecture module is expected to expose:
|
||||||
|
//! - A `Config` type deserialised from the model's `config.json`
|
||||||
|
//! (some architectures nest the real hyperparams under `text_config`,
|
||||||
|
//! in which case the module owns the unwrapping).
|
||||||
|
//! - A `ForCausalLM` struct with `new`, `forward(&mut self, x, offset)
|
||||||
|
//! -> Result<Tensor>`, and `clear_kv_cache(&mut self)`.
|
||||||
|
//!
|
||||||
|
//! TP-aware analogues live in `harness/tp/tp_<family>.rs` and follow
|
||||||
|
//! the pattern set by `tp_qwen3.rs`.
|
||||||
|
|
||||||
|
pub mod qwen3_5;
|
||||||
117
crates/neuron/src/harness/arch/qwen3_5/decoder.rs
Normal file
117
crates/neuron/src/harness/arch/qwen3_5/decoder.rs
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
//! Qwen3-Next decoder layer.
|
||||||
|
//!
|
||||||
|
//! Standard pre-norm transformer block (LN → attention → residual →
|
||||||
|
//! LN → MLP → residual) where the attention slot dispatches on the
|
||||||
|
//! per-layer `layer_types[i]` value in the config:
|
||||||
|
//!
|
||||||
|
//! - `"full_attention"` → [`Qwen3_5Attention`] (GQA causal + output
|
||||||
|
//! gate + RoPE + KV cache).
|
||||||
|
//! - `"linear_attention"` → [`GatedDeltaNet`] (recurrent delta rule +
|
||||||
|
//! causal conv + per-head state).
|
||||||
|
//!
|
||||||
|
//! In Qwen3.6-27B every 4th layer is full_attention; the rest are
|
||||||
|
//! linear_attention. `full_attention_interval` in the config is a
|
||||||
|
//! hint; `layer_types` is authoritative.
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use candle_core::{Module, Tensor};
|
||||||
|
use candle_nn::var_builder::ShardedVarBuilder;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use super::TextConfig;
|
||||||
|
use super::full_attn::Qwen3_5Attention;
|
||||||
|
use super::linear_attn::GatedDeltaNet;
|
||||||
|
use super::mlp::Qwen3_5MLP;
|
||||||
|
use super::rmsnorm::Qwen3_5RmsNorm;
|
||||||
|
use super::rope::RotaryEmbedding;
|
||||||
|
|
||||||
|
/// One of the two attention flavours sitting in a decoder layer's
|
||||||
|
/// attention slot. Full-attention layers need the rotary table and
|
||||||
|
/// take an attention mask; linear-attention layers carry their own
|
||||||
|
/// recurrent state and ignore the mask.
|
||||||
|
enum AttentionKind {
|
||||||
|
Full(Qwen3_5Attention),
|
||||||
|
Linear(GatedDeltaNet),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Qwen3_5DecoderLayer {
|
||||||
|
input_layernorm: Qwen3_5RmsNorm,
|
||||||
|
post_attention_layernorm: Qwen3_5RmsNorm,
|
||||||
|
mlp: Qwen3_5MLP,
|
||||||
|
attention: AttentionKind,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Qwen3_5DecoderLayer {
|
||||||
|
pub fn load(
|
||||||
|
cfg: &TextConfig,
|
||||||
|
rotary: Arc<RotaryEmbedding>,
|
||||||
|
layer_idx: usize,
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let layer_type = cfg
|
||||||
|
.layer_types
|
||||||
|
.get(layer_idx)
|
||||||
|
.map(String::as_str)
|
||||||
|
.ok_or_else(|| {
|
||||||
|
anyhow::anyhow!(
|
||||||
|
"layer_types[{layer_idx}] missing (have {} entries)",
|
||||||
|
cfg.layer_types.len()
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let attention = match layer_type {
|
||||||
|
"full_attention" => {
|
||||||
|
AttentionKind::Full(Qwen3_5Attention::load(cfg, rotary, &vb.pp("self_attn"))?)
|
||||||
|
}
|
||||||
|
"linear_attention" => {
|
||||||
|
AttentionKind::Linear(GatedDeltaNet::load(cfg, &vb.pp("linear_attn"))?)
|
||||||
|
}
|
||||||
|
other => anyhow::bail!(
|
||||||
|
"unknown layer_type '{other}' for layer {layer_idx} (expected \
|
||||||
|
'full_attention' or 'linear_attention')"
|
||||||
|
),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mlp = Qwen3_5MLP::load(cfg, &vb.pp("mlp"))?;
|
||||||
|
let input_layernorm =
|
||||||
|
Qwen3_5RmsNorm::load(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?;
|
||||||
|
let post_attention_layernorm = Qwen3_5RmsNorm::load(
|
||||||
|
&vb.pp("post_attention_layernorm"),
|
||||||
|
cfg.hidden_size,
|
||||||
|
cfg.rms_norm_eps,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
input_layernorm,
|
||||||
|
post_attention_layernorm,
|
||||||
|
mlp,
|
||||||
|
attention,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(
|
||||||
|
&mut self,
|
||||||
|
x: &Tensor,
|
||||||
|
attn_mask: Option<&Tensor>,
|
||||||
|
offset: usize,
|
||||||
|
) -> candle_core::Result<Tensor> {
|
||||||
|
let h = self.input_layernorm.forward(x)?;
|
||||||
|
let attn_out = match &mut self.attention {
|
||||||
|
AttentionKind::Full(attn) => attn.forward(&h, attn_mask, offset)?,
|
||||||
|
// Linear attention ignores attn_mask + offset; its causal
|
||||||
|
// structure is baked into the recurrent state lifecycle.
|
||||||
|
AttentionKind::Linear(net) => net.forward(&h)?,
|
||||||
|
};
|
||||||
|
let x = (x + attn_out)?;
|
||||||
|
let h2 = self.post_attention_layernorm.forward(&x)?;
|
||||||
|
let h2 = self.mlp.forward(&h2)?;
|
||||||
|
x + h2
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
match &mut self.attention {
|
||||||
|
AttentionKind::Full(attn) => attn.clear_kv_cache(),
|
||||||
|
AttentionKind::Linear(net) => net.clear_kv_cache(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
179
crates/neuron/src/harness/arch/qwen3_5/full_attn.rs
Normal file
179
crates/neuron/src/harness/arch/qwen3_5/full_attn.rs
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
//! Qwen3-Next's `full_attention` layer.
|
||||||
|
//!
|
||||||
|
//! Standard GQA causal attention with two Qwen3-Next-specific quirks:
|
||||||
|
//!
|
||||||
|
//! 1. **Output gate (`attn_output_gate=True`).** `q_proj` is widened
|
||||||
|
//! to `num_heads * head_dim * 2`. The second half is reshaped to
|
||||||
|
//! `(B, L, num_heads * head_dim)` and fed through a sigmoid; the
|
||||||
|
//! attention output is pointwise-multiplied by this gate before
|
||||||
|
//! `o_proj`. Effectively a per-head per-position attenuation on
|
||||||
|
//! the attention output.
|
||||||
|
//!
|
||||||
|
//! 2. **`(1 + w) * x` RmsNorm** on q and k (see `rmsnorm::Qwen3_5RmsNorm`).
|
||||||
|
//! candle_nn's RmsNorm applies `w * x`; the upstream Qwen3-Next
|
||||||
|
//! checkpoints expect the `(1 + w)` form.
|
||||||
|
//!
|
||||||
|
//! Otherwise: GQA with `num_attention_heads / num_key_value_heads`
|
||||||
|
//! repeat, q_norm + k_norm on the head dim, GLM-style rotary (see
|
||||||
|
//! `rope::RotaryEmbedding`), and the usual causal mask.
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use candle_core::{Module, Tensor};
|
||||||
|
use candle_nn::Linear;
|
||||||
|
use candle_nn::kv_cache::ConcatKvCache;
|
||||||
|
use candle_nn::var_builder::ShardedVarBuilder;
|
||||||
|
use candle_transformers::utils::repeat_kv;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use super::TextConfig;
|
||||||
|
use super::rmsnorm::Qwen3_5RmsNorm;
|
||||||
|
use super::rope::RotaryEmbedding;
|
||||||
|
|
||||||
|
pub struct Qwen3_5Attention {
|
||||||
|
q_proj: Linear,
|
||||||
|
k_proj: Linear,
|
||||||
|
v_proj: Linear,
|
||||||
|
o_proj: Linear,
|
||||||
|
q_norm: Qwen3_5RmsNorm,
|
||||||
|
k_norm: Qwen3_5RmsNorm,
|
||||||
|
num_heads: usize,
|
||||||
|
num_kv_heads: usize,
|
||||||
|
num_kv_groups: usize,
|
||||||
|
head_dim: usize,
|
||||||
|
hidden_size: usize,
|
||||||
|
rotary: Arc<RotaryEmbedding>,
|
||||||
|
kv_cache: ConcatKvCache,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Qwen3_5Attention {
|
||||||
|
pub fn load(
|
||||||
|
cfg: &TextConfig,
|
||||||
|
rotary: Arc<RotaryEmbedding>,
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let head_dim = cfg.head_dim;
|
||||||
|
let num_heads = cfg.num_attention_heads;
|
||||||
|
let num_kv_heads = cfg.num_key_value_heads;
|
||||||
|
if num_kv_heads == 0 || !num_heads.is_multiple_of(num_kv_heads) {
|
||||||
|
anyhow::bail!(
|
||||||
|
"num_attention_heads ({num_heads}) must be a positive multiple of \
|
||||||
|
num_key_value_heads ({num_kv_heads})"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let num_kv_groups = num_heads / num_kv_heads;
|
||||||
|
|
||||||
|
// q_proj is 2x wide: the extra `num_heads * head_dim` slice is
|
||||||
|
// the gate (see attn_output_gate notes above).
|
||||||
|
let q_proj = load_linear_no_bias(vb, "q_proj", cfg.hidden_size, num_heads * head_dim * 2)?;
|
||||||
|
let k_proj = load_linear_no_bias(vb, "k_proj", cfg.hidden_size, num_kv_heads * head_dim)?;
|
||||||
|
let v_proj = load_linear_no_bias(vb, "v_proj", cfg.hidden_size, num_kv_heads * head_dim)?;
|
||||||
|
let o_proj = load_linear_no_bias(vb, "o_proj", num_heads * head_dim, cfg.hidden_size)?;
|
||||||
|
|
||||||
|
let q_norm = Qwen3_5RmsNorm::load(&vb.pp("q_norm"), head_dim, cfg.rms_norm_eps)?;
|
||||||
|
let k_norm = Qwen3_5RmsNorm::load(&vb.pp("k_norm"), head_dim, cfg.rms_norm_eps)?;
|
||||||
|
|
||||||
|
let hidden_size = head_dim * num_heads;
|
||||||
|
let kv_cache = ConcatKvCache::new(2);
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
q_proj,
|
||||||
|
k_proj,
|
||||||
|
v_proj,
|
||||||
|
o_proj,
|
||||||
|
q_norm,
|
||||||
|
k_norm,
|
||||||
|
num_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
num_kv_groups,
|
||||||
|
head_dim,
|
||||||
|
hidden_size,
|
||||||
|
rotary,
|
||||||
|
kv_cache,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(
|
||||||
|
&mut self,
|
||||||
|
x: &Tensor,
|
||||||
|
attn_mask: Option<&Tensor>,
|
||||||
|
offset: usize,
|
||||||
|
) -> candle_core::Result<Tensor> {
|
||||||
|
let (b, l, _) = x.dims3()?;
|
||||||
|
|
||||||
|
// 1. q_proj — widened output, split into (query, gate).
|
||||||
|
let q_raw = self
|
||||||
|
.q_proj
|
||||||
|
.forward(x)?
|
||||||
|
.reshape((b, l, self.num_heads, self.head_dim * 2))?;
|
||||||
|
let q = q_raw.narrow(3, 0, self.head_dim)?;
|
||||||
|
let gate = q_raw.narrow(3, self.head_dim, self.head_dim)?;
|
||||||
|
// Flatten the gate's head dim back into hidden_size for the
|
||||||
|
// post-attention pointwise multiply.
|
||||||
|
let gate = gate
|
||||||
|
.contiguous()?
|
||||||
|
.reshape((b, l, self.num_heads * self.head_dim))?;
|
||||||
|
|
||||||
|
// 2. q_norm + k_norm + reshape to (B, H, L, D).
|
||||||
|
let q = self.q_norm.forward(&q.contiguous()?)?;
|
||||||
|
let q = q.transpose(1, 2)?.contiguous()?; // (B, H, L, D)
|
||||||
|
|
||||||
|
let k = self
|
||||||
|
.k_proj
|
||||||
|
.forward(x)?
|
||||||
|
.reshape((b, l, self.num_kv_heads, self.head_dim))?;
|
||||||
|
let k = self.k_norm.forward(&k.contiguous()?)?;
|
||||||
|
let k = k.transpose(1, 2)?.contiguous()?;
|
||||||
|
|
||||||
|
let v = self
|
||||||
|
.v_proj
|
||||||
|
.forward(x)?
|
||||||
|
.reshape((b, l, self.num_kv_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.contiguous()?;
|
||||||
|
|
||||||
|
// 3. RoPE on q, k.
|
||||||
|
let (q, k) = self.rotary.apply(&q, &k, offset)?;
|
||||||
|
|
||||||
|
// 4. KV cache.
|
||||||
|
let (k, v) = self.kv_cache.append(&k, &v)?;
|
||||||
|
|
||||||
|
// 5. GQA repeat (cheap shape op).
|
||||||
|
let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
|
||||||
|
let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;
|
||||||
|
|
||||||
|
// 6. Scaled dot-product + causal mask.
|
||||||
|
let scale = 1.0_f64 / (self.head_dim as f64).sqrt();
|
||||||
|
let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
|
||||||
|
if let Some(m) = attn_mask {
|
||||||
|
scores = scores.broadcast_add(m)?;
|
||||||
|
}
|
||||||
|
let probs = candle_nn::ops::softmax_last_dim(&scores)?;
|
||||||
|
let ctx = probs.matmul(&v)?; // (B, H, L, D)
|
||||||
|
|
||||||
|
// 7. Reshape back, apply the output gate, project.
|
||||||
|
let ctx = ctx
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.contiguous()?
|
||||||
|
.reshape((b, l, self.hidden_size))?;
|
||||||
|
let gate_sig = candle_nn::ops::sigmoid(&gate)?;
|
||||||
|
let gated = (ctx * gate_sig)?;
|
||||||
|
self.o_proj.forward(&gated)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
self.kv_cache.reset();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_linear_no_bias(
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
name: &str,
|
||||||
|
in_dim: usize,
|
||||||
|
out_dim: usize,
|
||||||
|
) -> Result<Linear> {
|
||||||
|
let weight = vb
|
||||||
|
.pp(name)
|
||||||
|
.get((out_dim, in_dim), "weight")
|
||||||
|
.with_context(|| format!("load '{}/{name}/weight'", vb.prefix()))?;
|
||||||
|
Ok(Linear::new(weight, None))
|
||||||
|
}
|
||||||
793
crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs
Normal file
793
crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs
Normal file
@@ -0,0 +1,793 @@
|
|||||||
|
//! Qwen3-Next's `linear_attention` layer: Gated DeltaNet.
|
||||||
|
//!
|
||||||
|
//! The recurrent linear-attention block that occupies 3 out of every 4
|
||||||
|
//! decoder layers in Qwen3.6 (`layer_types[i] == "linear_attention"`).
|
||||||
|
//! Implemented against the reference Python in
|
||||||
|
//! `huggingface/transformers/src/transformers/models/qwen3_5/modeling_qwen3_5.py`
|
||||||
|
//! (class `Qwen3_5GatedDeltaNet`).
|
||||||
|
//!
|
||||||
|
//! ## Block structure
|
||||||
|
//!
|
||||||
|
//! ```text
|
||||||
|
//! x ── in_proj_qkv ── transpose ─► (B, conv_dim, L)
|
||||||
|
//! │
|
||||||
|
//! ┌──────────────── conv_state ──┤ prepend cached state (decode)
|
||||||
|
//! ▼
|
||||||
|
//! depthwise causal Conv1d (k=4) → SiLU
|
||||||
|
//! │
|
||||||
|
//! └─ split → q (k_dim), k (k_dim), v (v_dim) ─► per-head reshape
|
||||||
|
//!
|
||||||
|
//! x ── in_proj_z ────────────────► z (gate for the output RMSNorm)
|
||||||
|
//! x ── in_proj_b ── sigmoid ─────► beta (per-head per-token update rate)
|
||||||
|
//! x ── in_proj_a ── softplus ────► g (decay; see eqn below)
|
||||||
|
//!
|
||||||
|
//! g = -exp(A_log) * softplus(a + dt_bias) # discretisation
|
||||||
|
//! beta = sigmoid(b)
|
||||||
|
//!
|
||||||
|
//! (q, k) ─── L2norm ─── delta rule loop ──── core_attn_out
|
||||||
|
//! (per-token, per-head):
|
||||||
|
//! state *= exp(g_t)
|
||||||
|
//! mem = state^T · k_t
|
||||||
|
//! delta = (v_t - mem) * beta_t
|
||||||
|
//! state += outer(k_t, delta)
|
||||||
|
//! out_t = state^T · q_t
|
||||||
|
//!
|
||||||
|
//! core_attn_out ── RMSNormGated(z) ── reshape ── out_proj ── y
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! ## State
|
||||||
|
//!
|
||||||
|
//! Two tensors persist across decode steps:
|
||||||
|
//! - `conv_state`: `(B, conv_dim, conv_kernel_size)` — left-padded
|
||||||
|
//! tail of the input to the depthwise conv, so the next causal
|
||||||
|
//! window has the right left-context.
|
||||||
|
//! - `recurrent_state`: `(B, num_v_heads, head_k_dim, head_v_dim)` —
|
||||||
|
//! the delta-rule outer-product memory.
|
||||||
|
//!
|
||||||
|
//! Both are cleared via [`GatedDeltaNet::clear_kv_cache`] at the start
|
||||||
|
//! of every new request.
|
||||||
|
//!
|
||||||
|
//! ## Performance note
|
||||||
|
//!
|
||||||
|
//! This impl is the **recurrent** delta-rule for both prefill and
|
||||||
|
//! decode — i.e. the algorithm in `torch_recurrent_gated_delta_rule`.
|
||||||
|
//! Correctness-first. The chunked algorithm (chunk_size=64) in
|
||||||
|
//! `torch_chunk_gated_delta_rule` is a perf optimisation for long
|
||||||
|
//! prefill; can be added later without changing the surface.
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use candle_core::{Module, Tensor};
|
||||||
|
use candle_nn::Linear;
|
||||||
|
use candle_nn::var_builder::ShardedVarBuilder;
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
use super::RopeParameters;
|
||||||
|
use super::TextConfig;
|
||||||
|
use super::rmsnorm::{Qwen3_5RmsNormGated, l2norm};
|
||||||
|
|
||||||
|
/// Per-rank, per-layer state for the linear-attention block.
|
||||||
|
///
|
||||||
|
/// `conv_state` is left-padded with zeros on first use; `recurrent_state`
|
||||||
|
/// is initialised lazily to zeros once we know the batch size.
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct GatedDeltaNetState {
|
||||||
|
pub conv_state: Option<Tensor>,
|
||||||
|
pub recurrent_state: Option<Tensor>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct GatedDeltaNet {
|
||||||
|
// Projections.
|
||||||
|
in_proj_qkv: Linear,
|
||||||
|
in_proj_z: Linear,
|
||||||
|
in_proj_b: Linear,
|
||||||
|
in_proj_a: Linear,
|
||||||
|
out_proj: Linear,
|
||||||
|
|
||||||
|
// Depthwise causal Conv1d weight; shape (conv_dim, 1, kernel_size).
|
||||||
|
// No bias (Python sets bias=False).
|
||||||
|
conv1d_weight: Tensor,
|
||||||
|
|
||||||
|
// Per-head discretisation params.
|
||||||
|
dt_bias: Tensor,
|
||||||
|
a_log: Tensor,
|
||||||
|
|
||||||
|
// Output norm + gate.
|
||||||
|
norm: Qwen3_5RmsNormGated,
|
||||||
|
|
||||||
|
// Shape hyperparams (cached for forward).
|
||||||
|
num_v_heads: usize,
|
||||||
|
num_k_heads: usize,
|
||||||
|
head_k_dim: usize,
|
||||||
|
head_v_dim: usize,
|
||||||
|
key_dim: usize,
|
||||||
|
value_dim: usize,
|
||||||
|
conv_dim: usize,
|
||||||
|
conv_kernel_size: usize,
|
||||||
|
|
||||||
|
// Recurrent state held inline. Each request resets via
|
||||||
|
// `clear_kv_cache`; otherwise the state persists across forwards
|
||||||
|
// and the per-token offset advances naturally.
|
||||||
|
state: GatedDeltaNetState,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GatedDeltaNet {
|
||||||
|
pub fn load(cfg: &TextConfig, vb: &ShardedVarBuilder) -> Result<Self> {
|
||||||
|
let num_v_heads = cfg.linear_num_value_heads;
|
||||||
|
let num_k_heads = cfg.linear_num_key_heads;
|
||||||
|
let head_k_dim = cfg.linear_key_head_dim;
|
||||||
|
let head_v_dim = cfg.linear_value_head_dim;
|
||||||
|
let conv_kernel_size = cfg.linear_conv_kernel_dim;
|
||||||
|
|
||||||
|
if num_v_heads == 0 || num_k_heads == 0 {
|
||||||
|
anyhow::bail!(
|
||||||
|
"Qwen3-Next linear_num_*_heads must be set; got v={num_v_heads}, k={num_k_heads}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if !num_v_heads.is_multiple_of(num_k_heads) {
|
||||||
|
anyhow::bail!(
|
||||||
|
"linear_num_value_heads ({num_v_heads}) must be a multiple of \
|
||||||
|
linear_num_key_heads ({num_k_heads}) for GQA-style head expansion"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let key_dim = head_k_dim * num_k_heads;
|
||||||
|
let value_dim = head_v_dim * num_v_heads;
|
||||||
|
let conv_dim = key_dim * 2 + value_dim;
|
||||||
|
|
||||||
|
// ----- Linear projections (all `bias=False` in the reference). -----
|
||||||
|
let in_proj_qkv = load_linear_no_bias(vb, "in_proj_qkv", cfg.hidden_size, conv_dim)?;
|
||||||
|
let in_proj_z = load_linear_no_bias(vb, "in_proj_z", cfg.hidden_size, value_dim)?;
|
||||||
|
let in_proj_b = load_linear_no_bias(vb, "in_proj_b", cfg.hidden_size, num_v_heads)?;
|
||||||
|
let in_proj_a = load_linear_no_bias(vb, "in_proj_a", cfg.hidden_size, num_v_heads)?;
|
||||||
|
let out_proj = load_linear_no_bias(vb, "out_proj", value_dim, cfg.hidden_size)?;
|
||||||
|
|
||||||
|
// ----- Conv1d weight (depthwise, bias=False). -----
|
||||||
|
let conv1d_weight = vb
|
||||||
|
.pp("conv1d")
|
||||||
|
.get((conv_dim, 1, conv_kernel_size), "weight")
|
||||||
|
.with_context(|| format!("load '{}/conv1d/weight'", vb.prefix()))?;
|
||||||
|
|
||||||
|
// ----- dt_bias + A_log: per-head 1D params. -----
|
||||||
|
let dt_bias = vb
|
||||||
|
.get(num_v_heads, "dt_bias")
|
||||||
|
.with_context(|| format!("load '{}/dt_bias'", vb.prefix()))?;
|
||||||
|
let a_log = vb
|
||||||
|
.get(num_v_heads, "A_log")
|
||||||
|
.with_context(|| format!("load '{}/A_log'", vb.prefix()))?;
|
||||||
|
|
||||||
|
// ----- Output gated RMSNorm (per-head_v_dim). -----
|
||||||
|
let norm = Qwen3_5RmsNormGated::load(&vb.pp("norm"), head_v_dim, cfg.rms_norm_eps)?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
in_proj_qkv,
|
||||||
|
in_proj_z,
|
||||||
|
in_proj_b,
|
||||||
|
in_proj_a,
|
||||||
|
out_proj,
|
||||||
|
conv1d_weight,
|
||||||
|
dt_bias,
|
||||||
|
a_log,
|
||||||
|
norm,
|
||||||
|
num_v_heads,
|
||||||
|
num_k_heads,
|
||||||
|
head_k_dim,
|
||||||
|
head_v_dim,
|
||||||
|
key_dim,
|
||||||
|
value_dim,
|
||||||
|
conv_dim,
|
||||||
|
conv_kernel_size,
|
||||||
|
state: GatedDeltaNetState::default(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
self.state = GatedDeltaNetState::default();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `x` shape: `(B, L, hidden_size)`. Returns the same shape.
|
||||||
|
pub fn forward(&mut self, x: &Tensor) -> candle_core::Result<Tensor> {
|
||||||
|
let (batch_size, seq_len, _) = x.dims3()?;
|
||||||
|
let dtype = x.dtype();
|
||||||
|
let device = x.device().clone();
|
||||||
|
|
||||||
|
// ----- Projections. -----
|
||||||
|
// mixed_qkv: (B, L, conv_dim)
|
||||||
|
let mixed_qkv = self.in_proj_qkv.forward(x)?;
|
||||||
|
// (B, conv_dim, L) for the conv1d.
|
||||||
|
let mixed_qkv_chw = mixed_qkv.transpose(1, 2)?.contiguous()?;
|
||||||
|
|
||||||
|
// z: (B, L, value_dim) → (B, L, num_v_heads, head_v_dim)
|
||||||
|
let z = self.in_proj_z.forward(x)?.reshape((
|
||||||
|
batch_size,
|
||||||
|
seq_len,
|
||||||
|
self.num_v_heads,
|
||||||
|
self.head_v_dim,
|
||||||
|
))?;
|
||||||
|
|
||||||
|
// b, a: (B, L, num_v_heads)
|
||||||
|
let b = self.in_proj_b.forward(x)?;
|
||||||
|
let a = self.in_proj_a.forward(x)?;
|
||||||
|
|
||||||
|
// ----- Depthwise causal Conv1d + SiLU (with state continuation). -----
|
||||||
|
// Dispatches to a cuda kernel that fuses conv1d + silu when
|
||||||
|
// available; falls back to candle's `conv1d` + `silu` on cpu.
|
||||||
|
let (conv_out, new_state) = run_causal_conv1d(
|
||||||
|
&mixed_qkv_chw,
|
||||||
|
&self.conv1d_weight,
|
||||||
|
self.state.conv_state.take(),
|
||||||
|
batch_size,
|
||||||
|
self.conv_dim,
|
||||||
|
seq_len,
|
||||||
|
self.conv_kernel_size,
|
||||||
|
)?;
|
||||||
|
self.state.conv_state = Some(new_state);
|
||||||
|
// Back to (B, L, conv_dim).
|
||||||
|
let mixed_qkv = conv_out.transpose(1, 2)?.contiguous()?;
|
||||||
|
|
||||||
|
// ----- Split into q, k, v. -----
|
||||||
|
let q = mixed_qkv.narrow(2, 0, self.key_dim)?;
|
||||||
|
let k = mixed_qkv.narrow(2, self.key_dim, self.key_dim)?;
|
||||||
|
let v = mixed_qkv.narrow(2, 2 * self.key_dim, self.value_dim)?;
|
||||||
|
|
||||||
|
let q = q.reshape((batch_size, seq_len, self.num_k_heads, self.head_k_dim))?;
|
||||||
|
let k = k.reshape((batch_size, seq_len, self.num_k_heads, self.head_k_dim))?;
|
||||||
|
let v = v.reshape((batch_size, seq_len, self.num_v_heads, self.head_v_dim))?;
|
||||||
|
|
||||||
|
// ----- beta + g (per-head, per-token gates). -----
|
||||||
|
// Fused on cuda; per-op Rust on cpu. Both paths produce:
|
||||||
|
// beta = sigmoid(b)
|
||||||
|
// g = -exp(A_log) * softplus(a + dt_bias)
|
||||||
|
let (beta, g) = run_fused_gating(&b, &a, &self.a_log, &self.dt_bias)?;
|
||||||
|
|
||||||
|
// ----- GQA-style key expansion if num_v_heads > num_k_heads. -----
|
||||||
|
let (q, k) = if self.num_v_heads > self.num_k_heads {
|
||||||
|
let rep = self.num_v_heads / self.num_k_heads;
|
||||||
|
(
|
||||||
|
repeat_interleave(&q, rep, 2)?,
|
||||||
|
repeat_interleave(&k, rep, 2)?,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
(q, k)
|
||||||
|
};
|
||||||
|
|
||||||
|
// ----- L2-norm on q, k (use_qk_l2norm_in_kernel=True in ref). -----
|
||||||
|
let q = l2norm(&q, 1e-6)?;
|
||||||
|
let k = l2norm(&k, 1e-6)?;
|
||||||
|
|
||||||
|
// ----- Recurrent delta rule. -----
|
||||||
|
// Inputs: q, k (B, L, H, D_k); v (B, L, H, D_v); g (B, L, H); beta (B, L, H).
|
||||||
|
// The reference transposes to (B, H, L, D) before the loop. We
|
||||||
|
// do the same — it makes per-token indexing trivial.
|
||||||
|
let q = q.transpose(1, 2)?.contiguous()?; // (B, H, L, D_k)
|
||||||
|
let k = k.transpose(1, 2)?.contiguous()?;
|
||||||
|
let v = v.transpose(1, 2)?.contiguous()?; // (B, H, L, D_v)
|
||||||
|
let g = g.transpose(1, 2)?.contiguous()?; // (B, H, L)
|
||||||
|
let beta = beta.transpose(1, 2)?.contiguous()?; // (B, H, L)
|
||||||
|
|
||||||
|
// Pre-scale q by 1/sqrt(D_k) once. Everything goes to f32 here
|
||||||
|
// since the delta rule mixes broadcast_mul ops that candle won't
|
||||||
|
// accept across mixed dtypes. On the cuda gating path both beta
|
||||||
|
// and g come back in model dtype; on the cpu path g is already
|
||||||
|
// f32 — both casts are cheap idempotent ops.
|
||||||
|
let scale = 1.0_f64 / (self.head_k_dim as f64).sqrt();
|
||||||
|
let q = (q.to_dtype(candle_core::DType::F32)? * scale)?;
|
||||||
|
let k = k.to_dtype(candle_core::DType::F32)?;
|
||||||
|
let v = v.to_dtype(candle_core::DType::F32)?;
|
||||||
|
let g = g.to_dtype(candle_core::DType::F32)?;
|
||||||
|
let beta = beta.to_dtype(candle_core::DType::F32)?;
|
||||||
|
|
||||||
|
// Initialise the recurrent state from cache or zeros.
|
||||||
|
let state_init = match self.state.recurrent_state.take() {
|
||||||
|
Some(s) => s.to_dtype(candle_core::DType::F32)?,
|
||||||
|
None => Tensor::zeros(
|
||||||
|
(
|
||||||
|
batch_size,
|
||||||
|
self.num_v_heads,
|
||||||
|
self.head_k_dim,
|
||||||
|
self.head_v_dim,
|
||||||
|
),
|
||||||
|
candle_core::DType::F32,
|
||||||
|
&device,
|
||||||
|
)?,
|
||||||
|
};
|
||||||
|
|
||||||
|
// The delta-rule body: cuda-accelerated `gated_delta_rule_recurrence`
|
||||||
|
// kernel when we have a cuda device + the kernels are linked in,
|
||||||
|
// pure-Rust per-token fallback otherwise.
|
||||||
|
let (core_attn_out, new_state) = run_delta_rule(
|
||||||
|
&q,
|
||||||
|
&k,
|
||||||
|
&v,
|
||||||
|
&g,
|
||||||
|
&beta,
|
||||||
|
state_init,
|
||||||
|
batch_size,
|
||||||
|
self.num_v_heads,
|
||||||
|
seq_len,
|
||||||
|
self.head_k_dim,
|
||||||
|
self.head_v_dim,
|
||||||
|
)?;
|
||||||
|
// Stash the updated recurrent state for the next call.
|
||||||
|
self.state.recurrent_state = Some(new_state.to_dtype(dtype)?);
|
||||||
|
|
||||||
|
// core_attn_out: (B, H, L, D_v) → (B, L, H, D_v) → (B*L*H, D_v).
|
||||||
|
let core_attn_out = core_attn_out.transpose(1, 2)?.contiguous()?; // (B, L, H, D_v)
|
||||||
|
let core_attn_out = core_attn_out.to_dtype(dtype)?;
|
||||||
|
let core_attn_flat =
|
||||||
|
core_attn_out.reshape((batch_size * seq_len * self.num_v_heads, self.head_v_dim))?;
|
||||||
|
let z_flat = z.reshape((batch_size * seq_len * self.num_v_heads, self.head_v_dim))?;
|
||||||
|
|
||||||
|
// RMSNormGated: (out * silu(z) * weight) with the norm.
|
||||||
|
let normed = self.norm.forward(&core_attn_flat, &z_flat)?;
|
||||||
|
let normed = normed.reshape((batch_size, seq_len, self.num_v_heads * self.head_v_dim))?;
|
||||||
|
|
||||||
|
// Output projection: (B, L, value_dim) → (B, L, hidden_size).
|
||||||
|
self.out_proj.forward(&normed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run the per-token delta-rule recurrence.
|
||||||
|
///
|
||||||
|
/// `q`, `k`: `(B, H, L, D_k)` (F32). `v`: `(B, H, L, D_v)`. `g`,
|
||||||
|
/// `beta`: `(B, H, L)`. `state`: `(B, H, D_k, D_v)`.
|
||||||
|
///
|
||||||
|
/// Returns `(core_attn_out: (B, H, L, D_v), state: (B, H, D_k, D_v))`,
|
||||||
|
/// both F32. Caller is responsible for cast back to model dtype.
|
||||||
|
///
|
||||||
|
/// Cuda path: dispatches to the `gated_delta_rule_recurrence` kernel
|
||||||
|
/// ported from `EricLBuehler/mistral.rs::mistralrs-core/src/cuda/gdn.cu`.
|
||||||
|
/// All five inputs must be cuda f32 tensors. The kernel is V-tiled
|
||||||
|
/// with compile-time BK; one block per (V-tile, batch*head) and one
|
||||||
|
/// thread per V-column. Each thread holds BK state floats in
|
||||||
|
/// registers — eliminates the launch-overhead floor we hit with
|
||||||
|
/// candle's per-op dispatch (was ~12s/token on Qwen3.6-27B).
|
||||||
|
///
|
||||||
|
/// CPU path: pure-Rust per-token loop. Correct, slow.
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub(crate) fn run_delta_rule(
|
||||||
|
q: &Tensor,
|
||||||
|
k: &Tensor,
|
||||||
|
v: &Tensor,
|
||||||
|
g: &Tensor,
|
||||||
|
beta: &Tensor,
|
||||||
|
state: Tensor,
|
||||||
|
batch_size: usize,
|
||||||
|
num_heads: usize,
|
||||||
|
seq_len: usize,
|
||||||
|
head_k_dim: usize,
|
||||||
|
head_v_dim: usize,
|
||||||
|
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
{
|
||||||
|
// Only dispatch to the kernel if the inputs are on a CUDA
|
||||||
|
// device — CPU tests fall back to the Rust loop below.
|
||||||
|
if q.device().is_cuda() {
|
||||||
|
return run_delta_rule_cuda(
|
||||||
|
q, k, v, g, beta, state, batch_size, num_heads, seq_len, head_k_dim, head_v_dim,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let _ = (batch_size, num_heads, head_k_dim, head_v_dim);
|
||||||
|
run_delta_rule_rust(q, k, v, g, beta, state, seq_len)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// CUDA path. Flattens (B, H, ...) → (BH, ...) at the kernel boundary
|
||||||
|
/// (the kernel uses BH = batch*heads as its outer batch axis) and
|
||||||
|
/// reshapes the kernel's outputs back to (B, H, ...) for the caller.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
fn run_delta_rule_cuda(
|
||||||
|
q: &Tensor,
|
||||||
|
k: &Tensor,
|
||||||
|
v: &Tensor,
|
||||||
|
g: &Tensor,
|
||||||
|
beta: &Tensor,
|
||||||
|
state: Tensor,
|
||||||
|
batch_size: usize,
|
||||||
|
num_heads: usize,
|
||||||
|
seq_len: usize,
|
||||||
|
head_k_dim: usize,
|
||||||
|
head_v_dim: usize,
|
||||||
|
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
|
let q_bh = q.flatten(0, 1)?.contiguous()?;
|
||||||
|
let k_bh = k.flatten(0, 1)?.contiguous()?;
|
||||||
|
let v_bh = v.flatten(0, 1)?.contiguous()?;
|
||||||
|
let g_bh = g.flatten(0, 1)?.contiguous()?;
|
||||||
|
let beta_bh = beta.flatten(0, 1)?.contiguous()?;
|
||||||
|
let mut state_bh = state.flatten(0, 1)?.contiguous()?;
|
||||||
|
// For long prefills, the chunked kernel (BT=64) processes a chunk
|
||||||
|
// of tokens at a time instead of one-by-one — same delta-rule math,
|
||||||
|
// far fewer block launches. Threshold matches mistralrs.
|
||||||
|
const CHUNK_THRESHOLD: usize = 64;
|
||||||
|
let output_bh = if seq_len >= CHUNK_THRESHOLD {
|
||||||
|
crate::cuda::gdn::chunked_gated_delta_rule_recurrence_cuda(
|
||||||
|
&q_bh,
|
||||||
|
&k_bh,
|
||||||
|
&v_bh,
|
||||||
|
&g_bh,
|
||||||
|
&beta_bh,
|
||||||
|
&mut state_bh,
|
||||||
|
)?
|
||||||
|
} else {
|
||||||
|
crate::cuda::gdn::gated_delta_rule_recurrence_cuda(
|
||||||
|
&q_bh,
|
||||||
|
&k_bh,
|
||||||
|
&v_bh,
|
||||||
|
&g_bh,
|
||||||
|
&beta_bh,
|
||||||
|
&mut state_bh,
|
||||||
|
)?
|
||||||
|
};
|
||||||
|
let core_attn_out = output_bh.reshape((batch_size, num_heads, seq_len, head_v_dim))?;
|
||||||
|
let new_state = state_bh.reshape((batch_size, num_heads, head_k_dim, head_v_dim))?;
|
||||||
|
Ok((core_attn_out, new_state))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
fn run_delta_rule_rust(
|
||||||
|
q: &Tensor,
|
||||||
|
k: &Tensor,
|
||||||
|
v: &Tensor,
|
||||||
|
g: &Tensor,
|
||||||
|
beta: &Tensor,
|
||||||
|
mut state: Tensor,
|
||||||
|
seq_len: usize,
|
||||||
|
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
|
use candle_core::IndexOp;
|
||||||
|
let mut outputs: Vec<Tensor> = Vec::with_capacity(seq_len);
|
||||||
|
for t in 0..seq_len {
|
||||||
|
let q_t = q.i((.., .., t, ..))?;
|
||||||
|
let k_t = k.i((.., .., t, ..))?;
|
||||||
|
let v_t = v.i((.., .., t, ..))?;
|
||||||
|
let g_t = g.i((.., .., t))?;
|
||||||
|
let beta_t = beta.i((.., .., t))?;
|
||||||
|
let decay = g_t
|
||||||
|
.exp()?
|
||||||
|
.unsqueeze(candle_core::D::Minus1)?
|
||||||
|
.unsqueeze(candle_core::D::Minus1)?;
|
||||||
|
state = state.broadcast_mul(&decay)?;
|
||||||
|
let k_col = k_t.unsqueeze(candle_core::D::Minus1)?;
|
||||||
|
let kv_mem = state.broadcast_mul(&k_col)?.sum(2)?;
|
||||||
|
let beta_col = beta_t.unsqueeze(candle_core::D::Minus1)?;
|
||||||
|
let delta = (v_t - kv_mem)?.broadcast_mul(&beta_col)?;
|
||||||
|
let delta_row = delta.unsqueeze(2)?;
|
||||||
|
let outer = k_col.broadcast_mul(&delta_row)?;
|
||||||
|
state = (state + outer)?;
|
||||||
|
let q_col = q_t.unsqueeze(candle_core::D::Minus1)?;
|
||||||
|
let out_t = state.broadcast_mul(&q_col)?.sum(2)?;
|
||||||
|
outputs.push(out_t.unsqueeze(2)?);
|
||||||
|
}
|
||||||
|
let core_attn_out = Tensor::cat(&outputs, 2)?; // (B, H, L, D_v)
|
||||||
|
Ok((core_attn_out, state))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Depthwise causal conv1d + SiLU, with rolling `conv_state`.
|
||||||
|
///
|
||||||
|
/// `x`: `(B, conv_dim, L)` model dtype (f16/bf16 on cuda, anything on cpu).
|
||||||
|
/// `weight`: `(conv_dim, 1, kernel_size)` model dtype.
|
||||||
|
/// `conv_state`: `Some((B, conv_dim, kernel_size))` for decode continuation,
|
||||||
|
/// or `None` for fresh prefill.
|
||||||
|
///
|
||||||
|
/// Returns `(conv_out: (B, conv_dim, L), new_conv_state: (B, conv_dim, kernel_size))`.
|
||||||
|
/// SiLU is baked in.
|
||||||
|
///
|
||||||
|
/// Cuda path: dispatches to `causal_conv1d_update` (decode, seq_len=1 with
|
||||||
|
/// existing state) or `causal_conv1d_full` (prefill / first call), both
|
||||||
|
/// ported from mistralrs `gdn.cu`. Each kernel fuses the depthwise conv
|
||||||
|
/// and SiLU activation in one launch — that's ~4× fewer cuda launches per
|
||||||
|
/// linear-attention layer than the candle `conv1d` + `silu` combo.
|
||||||
|
///
|
||||||
|
/// CPU path: the original prepend-narrow-conv1d-silu sequence.
|
||||||
|
pub(crate) fn run_causal_conv1d(
|
||||||
|
x: &Tensor,
|
||||||
|
weight: &Tensor,
|
||||||
|
conv_state: Option<Tensor>,
|
||||||
|
batch_size: usize,
|
||||||
|
conv_dim: usize,
|
||||||
|
seq_len: usize,
|
||||||
|
conv_kernel_size: usize,
|
||||||
|
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
{
|
||||||
|
if x.device().is_cuda() {
|
||||||
|
return run_causal_conv1d_cuda(
|
||||||
|
x,
|
||||||
|
weight,
|
||||||
|
conv_state,
|
||||||
|
batch_size,
|
||||||
|
conv_dim,
|
||||||
|
seq_len,
|
||||||
|
conv_kernel_size,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
run_causal_conv1d_rust(
|
||||||
|
x,
|
||||||
|
weight,
|
||||||
|
conv_state,
|
||||||
|
batch_size,
|
||||||
|
conv_dim,
|
||||||
|
seq_len,
|
||||||
|
conv_kernel_size,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
fn run_causal_conv1d_cuda(
|
||||||
|
x: &Tensor,
|
||||||
|
weight: &Tensor,
|
||||||
|
conv_state: Option<Tensor>,
|
||||||
|
batch_size: usize,
|
||||||
|
conv_dim: usize,
|
||||||
|
seq_len: usize,
|
||||||
|
conv_kernel_size: usize,
|
||||||
|
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
|
// Kernel expects weight as (conv_dim, kernel_size) — squeeze the
|
||||||
|
// depthwise channel-multiplier dim.
|
||||||
|
let w = weight.squeeze(1)?.to_dtype(x.dtype())?.contiguous()?;
|
||||||
|
|
||||||
|
// Decode path: seq_len == 1 AND we have an existing conv_state.
|
||||||
|
// Otherwise (prefill or fresh-start decode), use the full path which
|
||||||
|
// zero-pads on the left internally.
|
||||||
|
if let Some(cs) = conv_state
|
||||||
|
&& seq_len == 1
|
||||||
|
{
|
||||||
|
let cs = cs.contiguous()?;
|
||||||
|
let (output, new_conv_state) =
|
||||||
|
crate::cuda::gdn::causal_conv1d_cuda(x, &w, &cs, conv_kernel_size, true)?;
|
||||||
|
return Ok((output, new_conv_state));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prefill / fresh-start: the kernel ignores any prior conv_state and
|
||||||
|
// zero-pads. If we had a non-zero prior state and >1 input tokens
|
||||||
|
// (multi-turn continuation), we'd need to fall back to Rust. Match
|
||||||
|
// mistralrs's behaviour: fresh prefill always.
|
||||||
|
let device = x.device().clone();
|
||||||
|
let zeros_cs = Tensor::zeros((batch_size, conv_dim, conv_kernel_size), x.dtype(), &device)?;
|
||||||
|
let (output, new_conv_state) =
|
||||||
|
crate::cuda::gdn::causal_conv1d_cuda(x, &w, &zeros_cs, conv_kernel_size, false)?;
|
||||||
|
Ok((output, new_conv_state))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Fused GDN gating: computes `beta = sigmoid(b)` and
|
||||||
|
/// `g = -exp(a_log) * softplus(a + dt_bias)` together.
|
||||||
|
///
|
||||||
|
/// `b`, `a`: `(B, L, num_heads)` model dtype.
|
||||||
|
/// `a_log`, `dt_bias`: `(num_heads,)` model dtype (cast to f32 internally).
|
||||||
|
///
|
||||||
|
/// Returns `(beta, g)` both in model dtype on the cuda path, both in f32
|
||||||
|
/// on the cpu fallback. The caller casts to f32 before the delta rule.
|
||||||
|
///
|
||||||
|
/// Cuda path: dispatches to `fused_gdn_gating_cuda` — one kernel
|
||||||
|
/// replaces sigmoid + neg(exp) + softplus + broadcast_mul (≈10 candle
|
||||||
|
/// launches per layer).
|
||||||
|
pub(crate) fn run_fused_gating(
|
||||||
|
b: &Tensor,
|
||||||
|
a: &Tensor,
|
||||||
|
a_log: &Tensor,
|
||||||
|
dt_bias: &Tensor,
|
||||||
|
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
{
|
||||||
|
if b.device().is_cuda() {
|
||||||
|
let a_log_f32 = a_log.to_dtype(candle_core::DType::F32)?.contiguous()?;
|
||||||
|
let dt_bias_f32 = dt_bias.to_dtype(candle_core::DType::F32)?.contiguous()?;
|
||||||
|
return crate::cuda::gdn::fused_gdn_gating_cuda(b, a, &a_log_f32, &dt_bias_f32);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
run_fused_gating_rust(b, a, a_log, dt_bias)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_fused_gating_rust(
|
||||||
|
b: &Tensor,
|
||||||
|
a: &Tensor,
|
||||||
|
a_log: &Tensor,
|
||||||
|
dt_bias: &Tensor,
|
||||||
|
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
|
let beta = candle_nn::ops::sigmoid(b)?;
|
||||||
|
let a_log_f32 = a_log.to_dtype(candle_core::DType::F32)?;
|
||||||
|
let neg_a_exp = a_log_f32.exp()?.neg()?;
|
||||||
|
let dt_b_f32 = dt_bias.to_dtype(candle_core::DType::F32)?;
|
||||||
|
let a_f32 = a.to_dtype(candle_core::DType::F32)?;
|
||||||
|
let a_plus_dt = a_f32.broadcast_add(&dt_b_f32)?;
|
||||||
|
let softplus_val = softplus(&a_plus_dt)?;
|
||||||
|
let neg_a_exp_b = neg_a_exp.unsqueeze(0)?.unsqueeze(0)?;
|
||||||
|
let g = neg_a_exp_b.broadcast_mul(&softplus_val)?;
|
||||||
|
Ok((beta, g))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_causal_conv1d_rust(
|
||||||
|
x: &Tensor,
|
||||||
|
weight: &Tensor,
|
||||||
|
conv_state: Option<Tensor>,
|
||||||
|
batch_size: usize,
|
||||||
|
conv_dim: usize,
|
||||||
|
seq_len: usize,
|
||||||
|
conv_kernel_size: usize,
|
||||||
|
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
|
let dtype = x.dtype();
|
||||||
|
let device = x.device().clone();
|
||||||
|
|
||||||
|
let prepended = match &conv_state {
|
||||||
|
Some(prev) => Tensor::cat(&[prev, x], 2)?,
|
||||||
|
None => x.clone(),
|
||||||
|
};
|
||||||
|
let prep_len = prepended.dims()[2];
|
||||||
|
|
||||||
|
let new_state = if prep_len >= conv_kernel_size {
|
||||||
|
prepended.narrow(2, prep_len - conv_kernel_size, conv_kernel_size)?
|
||||||
|
} else {
|
||||||
|
let pad = Tensor::zeros(
|
||||||
|
(batch_size, conv_dim, conv_kernel_size - prep_len),
|
||||||
|
dtype,
|
||||||
|
&device,
|
||||||
|
)?;
|
||||||
|
Tensor::cat(&[&pad, &prepended], 2)?
|
||||||
|
};
|
||||||
|
|
||||||
|
let conv_out = prepended.conv1d(weight, conv_kernel_size - 1, 1, 1, conv_dim)?;
|
||||||
|
let conv_out = conv_out.narrow(2, 0, prep_len)?;
|
||||||
|
let conv_out = candle_nn::ops::silu(&conv_out)?;
|
||||||
|
let conv_out = conv_out.narrow(2, prep_len - seq_len, seq_len)?;
|
||||||
|
Ok((conv_out, new_state))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load a no-bias linear from the ShardedVarBuilder. Weight shape is
|
||||||
|
/// the standard `[out, in]` order.
|
||||||
|
fn load_linear_no_bias(
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
name: &str,
|
||||||
|
in_dim: usize,
|
||||||
|
out_dim: usize,
|
||||||
|
) -> Result<Linear> {
|
||||||
|
let weight = vb
|
||||||
|
.pp(name)
|
||||||
|
.get((out_dim, in_dim), "weight")
|
||||||
|
.with_context(|| format!("load '{}/{name}/weight'", vb.prefix()))?;
|
||||||
|
Ok(Linear::new(weight, None))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Numerically-stable `softplus(x) = ln(1 + exp(x))`. Matches PyTorch's
|
||||||
|
/// `F.softplus` default (beta=1, threshold=20: for large positive x,
|
||||||
|
/// returns x as-is to avoid overflow in the exp).
|
||||||
|
pub(crate) fn softplus(x: &Tensor) -> candle_core::Result<Tensor> {
|
||||||
|
let threshold = 20.0_f64;
|
||||||
|
let big = x.ge(threshold)?; // Tensor<u8> mask
|
||||||
|
let safe = x.minimum(&x.affine(0.0, 0.0)?.affine(0.0, threshold)?)?; // min(x, threshold)
|
||||||
|
let small = ((safe.exp()? + 1.0_f64)?).log()?;
|
||||||
|
// Select x where big, else small.
|
||||||
|
big.where_cond(x, &small)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `repeat_interleave` along a single dim. Candle has no built-in for
|
||||||
|
/// this; emulate with unsqueeze + expand + reshape.
|
||||||
|
pub(crate) fn repeat_interleave(
|
||||||
|
x: &Tensor,
|
||||||
|
repeats: usize,
|
||||||
|
dim: usize,
|
||||||
|
) -> candle_core::Result<Tensor> {
|
||||||
|
if repeats == 1 {
|
||||||
|
return Ok(x.clone());
|
||||||
|
}
|
||||||
|
let mut shape = x.dims().to_vec();
|
||||||
|
let orig = shape[dim];
|
||||||
|
shape.insert(dim + 1, repeats);
|
||||||
|
let mut expanded_shape = shape.clone();
|
||||||
|
expanded_shape[dim + 1] = repeats;
|
||||||
|
let x = x.unsqueeze(dim + 1)?;
|
||||||
|
let x = x.expand(expanded_shape)?;
|
||||||
|
let mut out_shape = x.dims().to_vec();
|
||||||
|
out_shape.remove(dim + 1);
|
||||||
|
out_shape[dim] = orig * repeats;
|
||||||
|
x.contiguous()?.reshape(out_shape)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use candle_core::{DType, Device};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn softplus_small_x() {
|
||||||
|
// softplus(0) = ln(2) ≈ 0.6931
|
||||||
|
let x = Tensor::new(&[0.0_f32], &Device::Cpu).unwrap();
|
||||||
|
let out: Vec<f32> = softplus(&x).unwrap().to_vec1().unwrap();
|
||||||
|
assert!((out[0] - 2.0_f32.ln()).abs() < 1e-4);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn softplus_large_x_returns_x() {
|
||||||
|
// For x = 30, softplus(x) ≈ x (the threshold branch).
|
||||||
|
let x = Tensor::new(&[30.0_f32], &Device::Cpu).unwrap();
|
||||||
|
let out: Vec<f32> = softplus(&x).unwrap().to_vec1().unwrap();
|
||||||
|
assert!((out[0] - 30.0).abs() < 1e-4);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn repeat_interleave_doubles_dim() {
|
||||||
|
let x = Tensor::new(&[[1.0_f32, 2.0], [3.0, 4.0]], &Device::Cpu).unwrap(); // shape (2, 2)
|
||||||
|
let out = repeat_interleave(&x, 2, 1).unwrap(); // each col duplicated
|
||||||
|
let v: Vec<Vec<f32>> = out.to_vec2().unwrap();
|
||||||
|
// Row 0: 1, 1, 2, 2
|
||||||
|
// Row 1: 3, 3, 4, 4
|
||||||
|
assert_eq!(v[0], vec![1.0, 1.0, 2.0, 2.0]);
|
||||||
|
assert_eq!(v[1], vec![3.0, 3.0, 4.0, 4.0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sanity: the recurrent path produces a finite tensor of the right
|
||||||
|
/// shape on tiny dimensions. Doesn't validate numerical correctness
|
||||||
|
/// against the Python reference — that would need a fixed-weight
|
||||||
|
/// fixture to compare against. Catches structural mistakes
|
||||||
|
/// (broadcasting shapes, off-by-one slices) early.
|
||||||
|
#[test]
|
||||||
|
fn forward_smoke_with_tiny_dimensions() {
|
||||||
|
let dev = Device::Cpu;
|
||||||
|
let dtype = DType::F32;
|
||||||
|
let (b, l) = (1, 3);
|
||||||
|
let cfg = TextConfig {
|
||||||
|
vocab_size: 100,
|
||||||
|
hidden_size: 16,
|
||||||
|
intermediate_size: 32,
|
||||||
|
num_hidden_layers: 1,
|
||||||
|
num_attention_heads: 4,
|
||||||
|
num_key_value_heads: 1,
|
||||||
|
head_dim: 4,
|
||||||
|
max_position_embeddings: 32,
|
||||||
|
rope_parameters: RopeParameters {
|
||||||
|
rope_theta: 10000.0,
|
||||||
|
partial_rotary_factor: 1.0,
|
||||||
|
rope_type: None,
|
||||||
|
},
|
||||||
|
rms_norm_eps: 1e-6,
|
||||||
|
tie_word_embeddings: false,
|
||||||
|
attn_output_gate: true,
|
||||||
|
layer_types: vec!["linear_attention".into()],
|
||||||
|
full_attention_interval: Some(4),
|
||||||
|
hidden_act: "silu".into(),
|
||||||
|
linear_num_value_heads: 4,
|
||||||
|
linear_num_key_heads: 2,
|
||||||
|
linear_key_head_dim: 4,
|
||||||
|
linear_value_head_dim: 4,
|
||||||
|
linear_conv_kernel_dim: 4,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Build a synthetic VarBuilder with all-zeros weights.
|
||||||
|
// Easier path: skip the load and construct GatedDeltaNet
|
||||||
|
// manually by hand-rolling the Linear/Tensor inputs.
|
||||||
|
let zeros = |shape: &[usize]| Tensor::zeros(shape, dtype, &dev).unwrap();
|
||||||
|
let key_dim = cfg.linear_key_head_dim * cfg.linear_num_key_heads;
|
||||||
|
let value_dim = cfg.linear_value_head_dim * cfg.linear_num_value_heads;
|
||||||
|
let conv_dim = key_dim * 2 + value_dim;
|
||||||
|
let mut net = GatedDeltaNet {
|
||||||
|
in_proj_qkv: Linear::new(zeros(&[conv_dim, cfg.hidden_size]), None),
|
||||||
|
in_proj_z: Linear::new(zeros(&[value_dim, cfg.hidden_size]), None),
|
||||||
|
in_proj_b: Linear::new(zeros(&[cfg.linear_num_value_heads, cfg.hidden_size]), None),
|
||||||
|
in_proj_a: Linear::new(zeros(&[cfg.linear_num_value_heads, cfg.hidden_size]), None),
|
||||||
|
out_proj: Linear::new(zeros(&[cfg.hidden_size, value_dim]), None),
|
||||||
|
conv1d_weight: zeros(&[conv_dim, 1, cfg.linear_conv_kernel_dim]),
|
||||||
|
dt_bias: zeros(&[cfg.linear_num_value_heads]),
|
||||||
|
a_log: zeros(&[cfg.linear_num_value_heads]),
|
||||||
|
norm: {
|
||||||
|
let weight = Tensor::ones(&[cfg.linear_value_head_dim], dtype, &dev).unwrap();
|
||||||
|
Qwen3_5RmsNormGated::from_weight(weight, cfg.rms_norm_eps)
|
||||||
|
},
|
||||||
|
num_v_heads: cfg.linear_num_value_heads,
|
||||||
|
num_k_heads: cfg.linear_num_key_heads,
|
||||||
|
head_k_dim: cfg.linear_key_head_dim,
|
||||||
|
head_v_dim: cfg.linear_value_head_dim,
|
||||||
|
key_dim,
|
||||||
|
value_dim,
|
||||||
|
conv_dim,
|
||||||
|
conv_kernel_size: cfg.linear_conv_kernel_dim,
|
||||||
|
state: GatedDeltaNetState::default(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let x = Tensor::ones(&[b, l, cfg.hidden_size], dtype, &dev).unwrap();
|
||||||
|
let y = net.forward(&x).unwrap();
|
||||||
|
assert_eq!(y.dims(), &[b, l, cfg.hidden_size]);
|
||||||
|
// All zero weights → output should be zero. Confirms no NaN/Inf
|
||||||
|
// poisoning from the f32 promotions.
|
||||||
|
let v: Vec<f32> = y.flatten_all().unwrap().to_vec1().unwrap();
|
||||||
|
assert!(v.iter().all(|x| x.is_finite()));
|
||||||
|
}
|
||||||
|
}
|
||||||
53
crates/neuron/src/harness/arch/qwen3_5/mlp.rs
Normal file
53
crates/neuron/src/harness/arch/qwen3_5/mlp.rs
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
//! SwiGLU MLP block for Qwen3-Next.
|
||||||
|
//!
|
||||||
|
//! Identical to plain Qwen3's MLP: `down(silu(gate(x)) * up(x))` with
|
||||||
|
//! no bias on any of the three projections.
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use candle_core::{Module, Tensor};
|
||||||
|
use candle_nn::Linear;
|
||||||
|
use candle_nn::var_builder::ShardedVarBuilder;
|
||||||
|
|
||||||
|
use super::TextConfig;
|
||||||
|
|
||||||
|
pub struct Qwen3_5MLP {
|
||||||
|
gate_proj: Linear,
|
||||||
|
up_proj: Linear,
|
||||||
|
down_proj: Linear,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Qwen3_5MLP {
|
||||||
|
pub fn load(cfg: &TextConfig, vb: &ShardedVarBuilder) -> Result<Self> {
|
||||||
|
let h = cfg.hidden_size;
|
||||||
|
let i = cfg.intermediate_size;
|
||||||
|
let gate_proj = load_linear_no_bias(vb, "gate_proj", h, i)?;
|
||||||
|
let up_proj = load_linear_no_bias(vb, "up_proj", h, i)?;
|
||||||
|
let down_proj = load_linear_no_bias(vb, "down_proj", i, h)?;
|
||||||
|
Ok(Self {
|
||||||
|
gate_proj,
|
||||||
|
up_proj,
|
||||||
|
down_proj,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for Qwen3_5MLP {
|
||||||
|
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
|
||||||
|
let lhs = candle_nn::ops::silu(&self.gate_proj.forward(x)?)?;
|
||||||
|
let rhs = self.up_proj.forward(x)?;
|
||||||
|
self.down_proj.forward(&(lhs * rhs)?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_linear_no_bias(
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
name: &str,
|
||||||
|
in_dim: usize,
|
||||||
|
out_dim: usize,
|
||||||
|
) -> Result<Linear> {
|
||||||
|
let weight = vb
|
||||||
|
.pp(name)
|
||||||
|
.get((out_dim, in_dim), "weight")
|
||||||
|
.with_context(|| format!("load '{}/{name}/weight'", vb.prefix()))?;
|
||||||
|
Ok(Linear::new(weight, None))
|
||||||
|
}
|
||||||
397
crates/neuron/src/harness/arch/qwen3_5/mod.rs
Normal file
397
crates/neuron/src/harness/arch/qwen3_5/mod.rs
Normal file
@@ -0,0 +1,397 @@
|
|||||||
|
//! Qwen3-Next (`model_type = "qwen3_5"`) architecture — Qwen3.6's
|
||||||
|
//! upstream architecture revision.
|
||||||
|
//!
|
||||||
|
//! ## Naming
|
||||||
|
//!
|
||||||
|
//! The model release this targets is `Qwen/Qwen3.6-*` but the
|
||||||
|
//! architecture name in HuggingFace's `config.json` is `qwen3_5`.
|
||||||
|
//! mistralrs calls the same architecture `qwen3_next`; that label
|
||||||
|
//! ages poorly the next time Qwen ship a new arch, so we key on the
|
||||||
|
//! canonical `qwen3_5` from the model's own config.
|
||||||
|
//!
|
||||||
|
//! ## Status
|
||||||
|
//!
|
||||||
|
//! **Single-GPU dense path is real**. Both attention flavours
|
||||||
|
//! (`full_attention` with the output-gated GQA causal attention and
|
||||||
|
//! `linear_attention` with the Gated DeltaNet recurrent block) are
|
||||||
|
//! implemented. The model loads from upstream safetensors via the
|
||||||
|
//! existing `load_arch_dense` dispatch and runs forward end to end.
|
||||||
|
//!
|
||||||
|
//! Numerical correctness vs the reference Python is **not yet
|
||||||
|
//! validated** — the structural code path is right, weight tensor
|
||||||
|
//! names match the upstream layout, shapes flow through cleanly, but
|
||||||
|
//! the Tbilisi probe (and any other downstream test) is the next
|
||||||
|
//! step. Likely places a bug would surface:
|
||||||
|
//! - Per-rank vs per-token-position offsets in the recurrent delta
|
||||||
|
//! rule (`linear_attn.rs`).
|
||||||
|
//! - Off-by-one in the conv state continuation across decode steps.
|
||||||
|
//! - RoPE phase mismatch from MRoPE simplification (we treat the
|
||||||
|
//! three position grids as collapsed, which is correct only for
|
||||||
|
//! text-only inference).
|
||||||
|
//!
|
||||||
|
//! ## Submodules
|
||||||
|
//!
|
||||||
|
//! - [`rmsnorm`] — `Qwen3_5RmsNorm` (`(1+w)*x` variant), the
|
||||||
|
//! `Qwen3_5RmsNormGated` used after the delta rule, and the
|
||||||
|
//! `l2norm` helper.
|
||||||
|
//! - [`rope`] — text-side rotary embedding (mrope simplified, GLM
|
||||||
|
//! rotate-half).
|
||||||
|
//! - [`mlp`] — SwiGLU MLP (gate/up/down, no bias).
|
||||||
|
//! - [`full_attn`] — `Qwen3_5Attention` with the output-gate
|
||||||
|
//! widening on `q_proj`.
|
||||||
|
//! - [`linear_attn`] — `GatedDeltaNet` recurrent delta-rule block
|
||||||
|
//! (causal depthwise Conv1d → silu → split → L2norm → per-token
|
||||||
|
//! delta rule → RMSNormGated → out_proj).
|
||||||
|
//! - [`decoder`] — `Qwen3_5DecoderLayer` dispatching to one of the
|
||||||
|
//! two attention flavours per layer index.
|
||||||
|
//!
|
||||||
|
//! ## Open work
|
||||||
|
//!
|
||||||
|
//! - **TP variant.** `harness/tp/tp_qwen3_5.rs` is the next step.
|
||||||
|
//! Sharding strategy diverges by layer type:
|
||||||
|
//! - Full-attention layers: column-parallel q/k/v (including the
|
||||||
|
//! gate half of `q_proj`) + row-parallel `o_proj`, mirroring
|
||||||
|
//! `tp_qwen3.rs`.
|
||||||
|
//! - Linear-attention layers: the recurrent state is per-V-head, so
|
||||||
|
//! V-head-dimension sharding works cleanly — split `num_v_heads`
|
||||||
|
//! across ranks (`num_v_heads / world_size` per rank), shard
|
||||||
|
//! `in_proj_qkv` / `in_proj_z` / `in_proj_b` / `in_proj_a` along
|
||||||
|
//! the V-head dim, and row-parallel `out_proj`. The `A_log` /
|
||||||
|
//! `dt_bias` per-head params shard with the heads.
|
||||||
|
//!
|
||||||
|
//! - **Chunked delta-rule prefill.** `linear_attn.rs` runs the
|
||||||
|
//! per-token recurrent path for prefill too — correct but O(L).
|
||||||
|
//! Porting `torch_chunk_gated_delta_rule` (chunk_size=64) speeds
|
||||||
|
//! prefill substantially with no surface change.
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use candle_core::{DType, Device, IndexOp, Module, Tensor};
|
||||||
|
use candle_nn::Embedding;
|
||||||
|
use candle_nn::Linear;
|
||||||
|
use candle_nn::var_builder::ShardedVarBuilder;
|
||||||
|
use serde::Deserialize;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
pub mod decoder;
|
||||||
|
pub mod full_attn;
|
||||||
|
pub mod linear_attn;
|
||||||
|
pub mod mlp;
|
||||||
|
pub mod rmsnorm;
|
||||||
|
pub mod rope;
|
||||||
|
|
||||||
|
use decoder::Qwen3_5DecoderLayer;
|
||||||
|
use rmsnorm::Qwen3_5RmsNorm;
|
||||||
|
use rope::RotaryEmbedding;
|
||||||
|
|
||||||
|
/// `model_type` we deserialise from `config.json`. Const so the
|
||||||
|
/// dispatch in `candle.rs::load_arch_dense` can pattern-match without
|
||||||
|
/// magic strings.
|
||||||
|
pub const MODEL_TYPE: &str = "qwen3_5";
|
||||||
|
|
||||||
|
/// Top-level shape of Qwen3-Next's `config.json`. The real
|
||||||
|
/// hyperparameters live in `text_config`; the rest is multimodal /
|
||||||
|
/// tokeniser glue we don't need for the language-model forward.
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct Config {
|
||||||
|
/// Always `"qwen3_5"` for this architecture. Kept on the struct
|
||||||
|
/// so the (eventual) dispatch / logging code can show it without
|
||||||
|
/// re-parsing the JSON.
|
||||||
|
pub model_type: String,
|
||||||
|
/// The text-side hyperparameters. Everything we actually need.
|
||||||
|
pub text_config: TextConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Inner config (the `text_config` block). Mirrors the Qwen3 layout
|
||||||
|
/// but with the extras Qwen3-Next adds (`attn_output_gate`,
|
||||||
|
/// `layer_types`, `full_attention_interval`, larger `head_dim`).
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct TextConfig {
|
||||||
|
pub vocab_size: usize,
|
||||||
|
pub hidden_size: usize,
|
||||||
|
pub intermediate_size: usize,
|
||||||
|
pub num_hidden_layers: usize,
|
||||||
|
pub num_attention_heads: usize,
|
||||||
|
pub num_key_value_heads: usize,
|
||||||
|
pub head_dim: usize,
|
||||||
|
pub max_position_embeddings: usize,
|
||||||
|
/// Nested RoPE settings. Qwen3-Next puts `rope_theta` and
|
||||||
|
/// `partial_rotary_factor` inside this block rather than at the
|
||||||
|
/// top level — important because the partial rotary means only
|
||||||
|
/// `head_dim * partial_rotary_factor` dims get RoPE applied (the
|
||||||
|
/// rest pass through unchanged).
|
||||||
|
pub rope_parameters: RopeParameters,
|
||||||
|
pub rms_norm_eps: f64,
|
||||||
|
#[serde(default)]
|
||||||
|
pub tie_word_embeddings: bool,
|
||||||
|
|
||||||
|
/// New in Qwen3-Next: a sigmoid gate multiplied into the attention
|
||||||
|
/// output before the o_proj. The Python reference applies it
|
||||||
|
/// pointwise after softmax+matmul.
|
||||||
|
#[serde(default)]
|
||||||
|
pub attn_output_gate: bool,
|
||||||
|
|
||||||
|
/// One entry per decoder layer; values are `"full_attention"` or
|
||||||
|
/// `"linear_attention"`. Length must equal `num_hidden_layers`.
|
||||||
|
/// `full_attention_interval` is a derived hint (every 4th layer
|
||||||
|
/// by default) — `layer_types` is authoritative.
|
||||||
|
#[serde(default)]
|
||||||
|
pub layer_types: Vec<String>,
|
||||||
|
|
||||||
|
/// Hint for the layer-type pattern (defaults to 4). Kept for
|
||||||
|
/// logging / validation; the forward dispatches on `layer_types`.
|
||||||
|
#[serde(default)]
|
||||||
|
pub full_attention_interval: Option<usize>,
|
||||||
|
|
||||||
|
/// Hidden activation (`"silu"` for Qwen3-Next). Used by the MLP
|
||||||
|
/// and the linear-attention conv1d.
|
||||||
|
#[serde(default = "default_hidden_act")]
|
||||||
|
pub hidden_act: String,
|
||||||
|
|
||||||
|
// --- Gated DeltaNet (linear-attention) hyperparams -----------------
|
||||||
|
/// Per-layer linear-attention V-head count (Qwen3.6-27B: 48).
|
||||||
|
/// More V-heads than K-heads is fine — query/key get
|
||||||
|
/// `repeat_interleave`'d to match before the delta rule.
|
||||||
|
#[serde(default)]
|
||||||
|
pub linear_num_value_heads: usize,
|
||||||
|
/// Per-layer linear-attention K-head count (Qwen3.6-27B: 16).
|
||||||
|
#[serde(default)]
|
||||||
|
pub linear_num_key_heads: usize,
|
||||||
|
/// Per-head key dimension for the linear-attention path
|
||||||
|
/// (Qwen3.6-27B: 128). Separate from `head_dim` which the
|
||||||
|
/// full-attention layers use.
|
||||||
|
#[serde(default)]
|
||||||
|
pub linear_key_head_dim: usize,
|
||||||
|
/// Per-head value dimension for the linear-attention path
|
||||||
|
/// (Qwen3.6-27B: 128).
|
||||||
|
#[serde(default)]
|
||||||
|
pub linear_value_head_dim: usize,
|
||||||
|
/// Causal Conv1d kernel size used before the delta rule
|
||||||
|
/// (Qwen3.6-27B: 4).
|
||||||
|
#[serde(default)]
|
||||||
|
pub linear_conv_kernel_dim: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_hidden_act() -> String {
|
||||||
|
"silu".into()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Nested `rope_parameters` block from a Qwen3-Next `config.json`.
|
||||||
|
/// `mrope_section` and `mrope_interleaved` are accepted via the
|
||||||
|
/// `#[serde(default)]` flatten-tolerance below but ignored — we treat
|
||||||
|
/// MRoPE as plain RoPE for text-only inference (the three position
|
||||||
|
/// grids carry identical ids when there's no vision input, so the
|
||||||
|
/// interleaving is a no-op).
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct RopeParameters {
|
||||||
|
/// Base for the inverse-frequency computation. Qwen3.6: 10_000_000.
|
||||||
|
#[serde(default = "default_rope_theta")]
|
||||||
|
pub rope_theta: f64,
|
||||||
|
/// Fraction of `head_dim` that gets the rotation applied. The
|
||||||
|
/// remaining `head_dim * (1 - partial_rotary_factor)` dims pass
|
||||||
|
/// through unchanged. Qwen3.6 / Qwen3.5: 0.25.
|
||||||
|
#[serde(default = "default_partial_rotary_factor")]
|
||||||
|
pub partial_rotary_factor: f32,
|
||||||
|
/// `"default"` for the standard inv_freq RoPE; other values (e.g.
|
||||||
|
/// `"linear"`, `"dynamic"`) are upstream-supported but not yet
|
||||||
|
/// implemented here.
|
||||||
|
#[serde(default)]
|
||||||
|
pub rope_type: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_rope_theta() -> f64 {
|
||||||
|
10_000.0
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_partial_rotary_factor() -> f32 {
|
||||||
|
1.0
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Qwen3-Next base transformer (embedding + decoder stack + final
|
||||||
|
/// norm). Public so a TP variant in `harness/tp/tp_qwen3_5.rs` can
|
||||||
|
/// also build on it later — for now only `Qwen3_5ForCausalLM` is the
|
||||||
|
/// loaded handle.
|
||||||
|
pub struct Qwen3_5Model {
|
||||||
|
embed_tokens: Embedding,
|
||||||
|
layers: Vec<Qwen3_5DecoderLayer>,
|
||||||
|
norm: Qwen3_5RmsNorm,
|
||||||
|
device: Device,
|
||||||
|
dtype: DType,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Qwen3_5Model {
|
||||||
|
pub fn load(cfg: &TextConfig, vb: &ShardedVarBuilder) -> Result<Self> {
|
||||||
|
let dtype = vb.dtype();
|
||||||
|
let device = vb.device().clone();
|
||||||
|
|
||||||
|
// Qwen3-Next is a multimodal architecture whose text core lives
|
||||||
|
// under `model.language_model.*` — sibling to `model.visual.*`
|
||||||
|
// (the vision tower) and to top-level `lm_head` / `mtp.*`.
|
||||||
|
// Every text-side tensor in the safetensors files is under
|
||||||
|
// this prefix; we ignore the vision and MTP weights for
|
||||||
|
// language-model inference.
|
||||||
|
let text_vb = vb.pp("model.language_model");
|
||||||
|
|
||||||
|
let embed_vb = text_vb.pp("embed_tokens");
|
||||||
|
let embed_weight = embed_vb
|
||||||
|
.get((cfg.vocab_size, cfg.hidden_size), "weight")
|
||||||
|
.with_context(|| format!("load '{}/weight'", embed_vb.prefix()))?;
|
||||||
|
let embed_tokens = Embedding::new(embed_weight, cfg.hidden_size);
|
||||||
|
|
||||||
|
let rotary = Arc::new(RotaryEmbedding::new(dtype, cfg, &device)?);
|
||||||
|
|
||||||
|
if cfg.layer_types.len() != cfg.num_hidden_layers {
|
||||||
|
anyhow::bail!(
|
||||||
|
"config.text_config.layer_types must have num_hidden_layers ({}) entries; \
|
||||||
|
got {}",
|
||||||
|
cfg.num_hidden_layers,
|
||||||
|
cfg.layer_types.len()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let vb_l = text_vb.pp("layers");
|
||||||
|
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||||
|
for i in 0..cfg.num_hidden_layers {
|
||||||
|
layers.push(Qwen3_5DecoderLayer::load(
|
||||||
|
cfg,
|
||||||
|
rotary.clone(),
|
||||||
|
i,
|
||||||
|
&vb_l.pp(i),
|
||||||
|
)?);
|
||||||
|
}
|
||||||
|
|
||||||
|
let norm = Qwen3_5RmsNorm::load(&text_vb.pp("norm"), cfg.hidden_size, cfg.rms_norm_eps)?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
embed_tokens,
|
||||||
|
layers,
|
||||||
|
norm,
|
||||||
|
device,
|
||||||
|
dtype,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn embed_weight(&self) -> &Tensor {
|
||||||
|
self.embed_tokens.embeddings()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
for l in &mut self.layers {
|
||||||
|
l.clear_kv_cache();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> candle_core::Result<Tensor> {
|
||||||
|
let minf = f32::NEG_INFINITY;
|
||||||
|
let mask: Vec<_> = (0..tgt)
|
||||||
|
.flat_map(|i| (0..(tgt + offset)).map(move |j| if j <= i + offset { 0. } else { minf }))
|
||||||
|
.collect();
|
||||||
|
Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
|
||||||
|
let (b, l) = input.dims2()?;
|
||||||
|
let mut h = self.embed_tokens.forward(input)?;
|
||||||
|
// Causal mask only needed for L > 1 prefill; full-attention
|
||||||
|
// layers consume it via broadcast_add. Linear-attention layers
|
||||||
|
// ignore the mask.
|
||||||
|
let causal = if l == 1 {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(self.causal_mask(b, l, offset)?)
|
||||||
|
};
|
||||||
|
for layer in &mut self.layers {
|
||||||
|
h = layer.forward(&h, causal.as_ref(), offset)?;
|
||||||
|
}
|
||||||
|
self.norm.forward(&h)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Qwen3_5ForCausalLM {
|
||||||
|
base: Qwen3_5Model,
|
||||||
|
lm_head: Linear,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Qwen3_5ForCausalLM {
|
||||||
|
pub fn new(config: Config, vb: ShardedVarBuilder) -> Result<Self> {
|
||||||
|
let cfg = &config.text_config;
|
||||||
|
let base = Qwen3_5Model::load(cfg, &vb)?;
|
||||||
|
let lm_head = if cfg.tie_word_embeddings {
|
||||||
|
Linear::new(base.embed_weight().clone(), None)
|
||||||
|
} else {
|
||||||
|
let weight = vb
|
||||||
|
.pp("lm_head")
|
||||||
|
.get((cfg.vocab_size, cfg.hidden_size), "weight")
|
||||||
|
.with_context(|| format!("load '{}/lm_head/weight'", vb.prefix()))?;
|
||||||
|
Linear::new(weight, None)
|
||||||
|
};
|
||||||
|
Ok(Self { base, lm_head })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `input`: token-id tensor of shape `(B, L)`. Returns logits at
|
||||||
|
/// the last position, shape `(B, 1, vocab_size)` — same contract
|
||||||
|
/// as `qwen3::ModelForCausalLM::forward` so the harness's
|
||||||
|
/// `squeeze_to_vocab` helper handles both uniformly.
|
||||||
|
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
|
||||||
|
let (_, l) = input.dims2()?;
|
||||||
|
let hidden = self.base.forward(input, offset)?;
|
||||||
|
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
self.base.clear_kv_cache();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
/// Confirms we can deserialise the real upstream config shape.
|
||||||
|
/// Sample taken from `Qwen/Qwen3.6-27B/config.json`, trimmed to
|
||||||
|
/// the fields the architecture cares about. Note `rope_theta` and
|
||||||
|
/// `partial_rotary_factor` are nested under `rope_parameters` —
|
||||||
|
/// Qwen3-Next does NOT have a top-level `rope_theta`.
|
||||||
|
#[test]
|
||||||
|
fn config_deserialises_the_real_qwen3_6_shape() {
|
||||||
|
let raw = r#"{
|
||||||
|
"architectures": ["Qwen3_5ForConditionalGeneration"],
|
||||||
|
"model_type": "qwen3_5",
|
||||||
|
"image_token_id": 248056,
|
||||||
|
"language_model_only": false,
|
||||||
|
"text_config": {
|
||||||
|
"vocab_size": 248064,
|
||||||
|
"hidden_size": 5120,
|
||||||
|
"intermediate_size": 17408,
|
||||||
|
"num_hidden_layers": 64,
|
||||||
|
"num_attention_heads": 64,
|
||||||
|
"num_key_value_heads": 8,
|
||||||
|
"head_dim": 256,
|
||||||
|
"max_position_embeddings": 32768,
|
||||||
|
"rope_parameters": {
|
||||||
|
"mrope_interleaved": true,
|
||||||
|
"mrope_section": [11, 11, 10],
|
||||||
|
"partial_rotary_factor": 0.25,
|
||||||
|
"rope_theta": 10000000,
|
||||||
|
"rope_type": "default"
|
||||||
|
},
|
||||||
|
"rms_norm_eps": 1e-6,
|
||||||
|
"tie_word_embeddings": false,
|
||||||
|
"attn_output_gate": true,
|
||||||
|
"full_attention_interval": 4,
|
||||||
|
"layer_types": [
|
||||||
|
"linear_attention", "linear_attention",
|
||||||
|
"linear_attention", "full_attention"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}"#;
|
||||||
|
let cfg: Config = serde_json::from_str(raw).expect("parse Qwen3.6 config");
|
||||||
|
assert_eq!(cfg.model_type, "qwen3_5");
|
||||||
|
assert_eq!(cfg.text_config.hidden_size, 5120);
|
||||||
|
assert_eq!(cfg.text_config.head_dim, 256);
|
||||||
|
assert!(cfg.text_config.attn_output_gate);
|
||||||
|
assert_eq!(cfg.text_config.full_attention_interval, Some(4));
|
||||||
|
assert_eq!(cfg.text_config.layer_types.len(), 4);
|
||||||
|
assert_eq!(cfg.text_config.rope_parameters.rope_theta, 10_000_000.0);
|
||||||
|
assert!((cfg.text_config.rope_parameters.partial_rotary_factor - 0.25).abs() < 1e-6);
|
||||||
|
}
|
||||||
|
}
|
||||||
161
crates/neuron/src/harness/arch/qwen3_5/rmsnorm.rs
Normal file
161
crates/neuron/src/harness/arch/qwen3_5/rmsnorm.rs
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
//! Norm primitives for Qwen3-Next.
|
||||||
|
//!
|
||||||
|
//! Two reasons we can't reuse `candle_nn::RmsNorm` directly:
|
||||||
|
//!
|
||||||
|
//! 1. **`(1.0 + weight)` scaling.** Qwen3-Next's `Qwen3_5RMSNorm`
|
||||||
|
//! initialises `weight` to zeros and applies `(1.0 + weight)` to
|
||||||
|
//! the normalised vector. `candle_nn::RmsNorm` applies `weight`
|
||||||
|
//! directly. The two are equivalent only when the operator has
|
||||||
|
//! pre-shifted the weights — the upstream checkpoints have not. See
|
||||||
|
//! `huggingface/transformers#29402` for the upstream PR that
|
||||||
|
//! introduced the `(1 + w)` form to recover from the zero-init.
|
||||||
|
//!
|
||||||
|
//! 2. **Gated variant.** The linear-attention layer post-normalises
|
||||||
|
//! its output by an RMSNorm *gated* with a per-element SiLU on
|
||||||
|
//! a sibling `z` projection — fused for numerical reasons (the
|
||||||
|
//! norm's float32 promotion has to happen before the SiLU
|
||||||
|
//! multiply). Not a single existing candle op.
|
||||||
|
//!
|
||||||
|
//! Both ops accept inputs in any compute dtype; promotion to f32 for
|
||||||
|
//! the variance calculation matches the Python reference.
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use candle_core::{D, Module, Tensor};
|
||||||
|
use candle_nn::var_builder::ShardedVarBuilder;
|
||||||
|
|
||||||
|
/// L2-normalise along the last dim with a small epsilon. Matches the
|
||||||
|
/// `l2norm` helper in `transformers/models/qwen3_5/modeling_qwen3_5.py`
|
||||||
|
/// — `x * rsqrt(sum(x*x) + eps)`. The linear-attention path uses this
|
||||||
|
/// on Q and K before the delta rule when
|
||||||
|
/// `use_qk_l2norm_in_kernel=True` (which Qwen3-Next always sets).
|
||||||
|
pub fn l2norm(x: &Tensor, eps: f32) -> candle_core::Result<Tensor> {
|
||||||
|
let dtype = x.dtype();
|
||||||
|
let x_f32 = x.to_dtype(candle_core::DType::F32)?;
|
||||||
|
let sq = x_f32.sqr()?;
|
||||||
|
let sum = sq.sum_keepdim(D::Minus1)?;
|
||||||
|
let inv = (sum + eps as f64)?.sqrt()?.recip()?;
|
||||||
|
x_f32.broadcast_mul(&inv)?.to_dtype(dtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Qwen3-Next's RMSNorm. Stores the raw weight tensor; forward applies
|
||||||
|
/// `(1.0 + weight) * x_normed`.
|
||||||
|
pub struct Qwen3_5RmsNorm {
|
||||||
|
weight: Tensor,
|
||||||
|
eps: f32,
|
||||||
|
size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Qwen3_5RmsNorm {
|
||||||
|
/// Load `weight` from the ShardedVarBuilder. `vb` should already be
|
||||||
|
/// `.pp(...)`-ed to the norm's tensor prefix.
|
||||||
|
pub fn load(vb: &ShardedVarBuilder, size: usize, eps: f64) -> Result<Self> {
|
||||||
|
let weight = vb
|
||||||
|
.get(size, "weight")
|
||||||
|
.with_context(|| format!("load '{}/weight'", vb.prefix()))?;
|
||||||
|
Ok(Self {
|
||||||
|
weight,
|
||||||
|
eps: eps as f32,
|
||||||
|
size,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn size(&self) -> usize {
|
||||||
|
self.size
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for Qwen3_5RmsNorm {
|
||||||
|
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
|
||||||
|
let dtype = x.dtype();
|
||||||
|
let x_f32 = x.to_dtype(candle_core::DType::F32)?;
|
||||||
|
let var = x_f32.sqr()?.mean_keepdim(D::Minus1)?;
|
||||||
|
let normed = x_f32.broadcast_mul(&(var + self.eps as f64)?.sqrt()?.recip()?)?;
|
||||||
|
// Promote weight to f32 and shift by 1.0 *before* multiplying.
|
||||||
|
// Doing the (1 + w) operation in fp16 lands at -inf for the
|
||||||
|
// bottom-of-range weights at load time.
|
||||||
|
let w_f32 = self.weight.to_dtype(candle_core::DType::F32)?;
|
||||||
|
let scale = (w_f32 + 1.0_f64)?;
|
||||||
|
normed.broadcast_mul(&scale)?.to_dtype(dtype)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Gated RMSNorm used at the tail of `Qwen3_5GatedDeltaNet`. Equivalent
|
||||||
|
/// to `x_normed * weight * silu(gate)` but with both the norm and the
|
||||||
|
/// gate evaluated in float32 to avoid mid-pipeline underflow.
|
||||||
|
///
|
||||||
|
/// Note: unlike `Qwen3_5RmsNorm`, this variant matches the Python
|
||||||
|
/// reference's `Qwen3_5RMSNormGated` which uses `weight` directly (not
|
||||||
|
/// `1.0 + weight`).
|
||||||
|
pub struct Qwen3_5RmsNormGated {
|
||||||
|
weight: Tensor,
|
||||||
|
eps: f32,
|
||||||
|
size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Qwen3_5RmsNormGated {
|
||||||
|
pub fn load(vb: &ShardedVarBuilder, size: usize, eps: f64) -> Result<Self> {
|
||||||
|
let weight = vb
|
||||||
|
.get(size, "weight")
|
||||||
|
.with_context(|| format!("load '{}/weight'", vb.prefix()))?;
|
||||||
|
Ok(Self {
|
||||||
|
weight,
|
||||||
|
eps: eps as f32,
|
||||||
|
size,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Direct constructor — used by unit tests that build a layer
|
||||||
|
/// without going through a VarBuilder.
|
||||||
|
#[cfg(test)]
|
||||||
|
pub(crate) fn from_weight(weight: Tensor, eps: f64) -> Self {
|
||||||
|
let size = weight.dims()[0];
|
||||||
|
Self {
|
||||||
|
weight,
|
||||||
|
eps: eps as f32,
|
||||||
|
size,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn size(&self) -> usize {
|
||||||
|
self.size
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `x` and `gate` share the same last-dim shape (`size`).
|
||||||
|
pub fn forward(&self, x: &Tensor, gate: &Tensor) -> candle_core::Result<Tensor> {
|
||||||
|
let dtype = x.dtype();
|
||||||
|
let x_f32 = x.to_dtype(candle_core::DType::F32)?;
|
||||||
|
let var = x_f32.sqr()?.mean_keepdim(D::Minus1)?;
|
||||||
|
let normed = x_f32.broadcast_mul(&(var + self.eps as f64)?.sqrt()?.recip()?)?;
|
||||||
|
let w = self.weight.to_dtype(candle_core::DType::F32)?;
|
||||||
|
let out = normed.broadcast_mul(&w)?;
|
||||||
|
// SiLU on the float32 gate, multiply back into the normed
|
||||||
|
// tensor, then cast to the model dtype.
|
||||||
|
let g = gate.to_dtype(candle_core::DType::F32)?;
|
||||||
|
let silu_gate = candle_nn::ops::silu(&g)?;
|
||||||
|
(out * silu_gate)?.to_dtype(dtype)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use candle_core::Device;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn l2norm_matches_hand_calc() {
|
||||||
|
let x = Tensor::new(&[3.0_f32, 4.0_f32], &Device::Cpu).unwrap();
|
||||||
|
let out = l2norm(&x, 1e-6).unwrap();
|
||||||
|
let v: Vec<f32> = out.to_vec1().unwrap();
|
||||||
|
// |x| = 5, so x/|x| = [0.6, 0.8] (eps is tiny).
|
||||||
|
assert!((v[0] - 0.6).abs() < 1e-4);
|
||||||
|
assert!((v[1] - 0.8).abs() < 1e-4);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn l2norm_zero_vector_is_safe_via_epsilon() {
|
||||||
|
let x = Tensor::new(&[0.0_f32, 0.0_f32], &Device::Cpu).unwrap();
|
||||||
|
let out = l2norm(&x, 1e-6).unwrap();
|
||||||
|
let v: Vec<f32> = out.to_vec1().unwrap();
|
||||||
|
assert!(v.iter().all(|x| x.is_finite()));
|
||||||
|
}
|
||||||
|
}
|
||||||
114
crates/neuron/src/harness/arch/qwen3_5/rope.rs
Normal file
114
crates/neuron/src/harness/arch/qwen3_5/rope.rs
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
//! Rotary position embedding for Qwen3-Next's full-attention layers.
|
||||||
|
//!
|
||||||
|
//! Qwen3.6 ships with MRoPE (multimodal RoPE) machinery in the
|
||||||
|
//! reference Python — three position grids interleaved per
|
||||||
|
//! `mrope_section`. For text-only inference all three grids carry the
|
||||||
|
//! same position ids and the interleave is a no-op, so this module
|
||||||
|
//! implements the plain (non-mrope) flavour: the standard inv_freq
|
||||||
|
//! cosine/sine tables driven by `rope_theta` and `head_dim`.
|
||||||
|
//!
|
||||||
|
//! Rotation flavour: **GLM-style** rotate-half (the second half of the
|
||||||
|
//! head dim is negated and swapped into the first). The reference
|
||||||
|
//! Python uses `apply_rotary_pos_emb` with `rotate_half`; candle's
|
||||||
|
//! `rope_slow` is the matching helper.
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use candle_core::{DType, Device, Tensor};
|
||||||
|
|
||||||
|
use super::TextConfig;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct RotaryEmbedding {
|
||||||
|
sin: Tensor,
|
||||||
|
cos: Tensor,
|
||||||
|
/// Number of dims at the head's leading edge that the rotation
|
||||||
|
/// covers. The remaining `head_dim - rotary_dim` dims pass through
|
||||||
|
/// unchanged. Qwen3-Next uses `partial_rotary_factor = 0.25`, so
|
||||||
|
/// for `head_dim = 256` only 64 dims rotate.
|
||||||
|
rotary_dim: usize,
|
||||||
|
head_dim: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RotaryEmbedding {
|
||||||
|
pub fn new(dtype: DType, cfg: &TextConfig, dev: &Device) -> Result<Self> {
|
||||||
|
let head_dim = cfg.head_dim;
|
||||||
|
let rope = &cfg.rope_parameters;
|
||||||
|
let rotary_dim = (head_dim as f32 * rope.partial_rotary_factor) as usize;
|
||||||
|
if !rotary_dim.is_multiple_of(2) {
|
||||||
|
anyhow::bail!(
|
||||||
|
"rotary_dim = head_dim * partial_rotary_factor = {head_dim} * {} = {rotary_dim} \
|
||||||
|
must be even (cos/sin are paired)",
|
||||||
|
rope.partial_rotary_factor
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if rotary_dim == 0 {
|
||||||
|
anyhow::bail!(
|
||||||
|
"rotary_dim = 0 (partial_rotary_factor = {} too small)",
|
||||||
|
rope.partial_rotary_factor
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let max_seq_len = cfg.max_position_embeddings;
|
||||||
|
let inv_freq: Vec<f32> = (0..rotary_dim)
|
||||||
|
.step_by(2)
|
||||||
|
.map(|i| 1f32 / rope.rope_theta.powf(i as f64 / rotary_dim as f64) as f32)
|
||||||
|
.collect();
|
||||||
|
let n = inv_freq.len();
|
||||||
|
let inv_freq = Tensor::from_vec(inv_freq, (1, n), dev)?.to_dtype(DType::F32)?;
|
||||||
|
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.reshape((max_seq_len, 1))?;
|
||||||
|
let freqs = t.matmul(&inv_freq)?;
|
||||||
|
Ok(Self {
|
||||||
|
sin: freqs.sin()?.to_dtype(dtype)?,
|
||||||
|
cos: freqs.cos()?.to_dtype(dtype)?,
|
||||||
|
rotary_dim,
|
||||||
|
head_dim,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Apply RoPE to q, k.
|
||||||
|
///
|
||||||
|
/// `q`, `k` shape: `(B, H, L, head_dim)`. `offset` is the index
|
||||||
|
/// into the cached cos/sin table — the position of the first token
|
||||||
|
/// in the current step.
|
||||||
|
///
|
||||||
|
/// When `rotary_dim < head_dim` the rotation is applied only to the
|
||||||
|
/// first `rotary_dim` dims of each head; the tail passes through
|
||||||
|
/// unchanged (matches the reference Python's
|
||||||
|
/// `apply_rotary_pos_emb` with non-trivial `partial_rotary_factor`).
|
||||||
|
pub fn apply(
|
||||||
|
&self,
|
||||||
|
q: &Tensor,
|
||||||
|
k: &Tensor,
|
||||||
|
offset: usize,
|
||||||
|
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
|
let (_, _, seq_len, head_dim_in) = q.dims4()?;
|
||||||
|
debug_assert_eq!(head_dim_in, self.head_dim, "q head_dim mismatch");
|
||||||
|
let cos = self.cos.narrow(0, offset, seq_len)?;
|
||||||
|
let sin = self.sin.narrow(0, offset, seq_len)?;
|
||||||
|
if self.rotary_dim == self.head_dim {
|
||||||
|
// Full rotation.
|
||||||
|
let q_embed = candle_nn::rotary_emb::rope_slow(&q.contiguous()?, &cos, &sin)?;
|
||||||
|
let k_embed = candle_nn::rotary_emb::rope_slow(&k.contiguous()?, &cos, &sin)?;
|
||||||
|
Ok((q_embed, k_embed))
|
||||||
|
} else {
|
||||||
|
// Partial rotation: narrow → rotate → cat the untouched tail.
|
||||||
|
let tail = self.head_dim - self.rotary_dim;
|
||||||
|
let q_rot = q
|
||||||
|
.narrow(candle_core::D::Minus1, 0, self.rotary_dim)?
|
||||||
|
.contiguous()?;
|
||||||
|
let q_pass = q.narrow(candle_core::D::Minus1, self.rotary_dim, tail)?;
|
||||||
|
let k_rot = k
|
||||||
|
.narrow(candle_core::D::Minus1, 0, self.rotary_dim)?
|
||||||
|
.contiguous()?;
|
||||||
|
let k_pass = k.narrow(candle_core::D::Minus1, self.rotary_dim, tail)?;
|
||||||
|
let q_rotated = candle_nn::rotary_emb::rope_slow(&q_rot, &cos, &sin)?;
|
||||||
|
let k_rotated = candle_nn::rotary_emb::rope_slow(&k_rot, &cos, &sin)?;
|
||||||
|
let q_embed =
|
||||||
|
Tensor::cat(&[&q_rotated, &q_pass.contiguous()?], candle_core::D::Minus1)?;
|
||||||
|
let k_embed =
|
||||||
|
Tensor::cat(&[&k_rotated, &k_pass.contiguous()?], candle_core::D::Minus1)?;
|
||||||
|
Ok((q_embed, k_embed))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
1928
crates/neuron/src/harness/candle.rs
Normal file
1928
crates/neuron/src/harness/candle.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1 +0,0 @@
|
|||||||
// llama.cpp harness implementation — Phase 11.
|
|
||||||
@@ -1,163 +0,0 @@
|
|||||||
//! mistral.rs harness implementation.
|
|
||||||
//!
|
|
||||||
//! Wraps the mistral.rs HTTP API for model lifecycle management
|
|
||||||
//! and optionally manages the process via systemd.
|
|
||||||
|
|
||||||
use anyhow::Result;
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use cortex_core::harness::{Harness, HarnessConfig, HarnessHealth, ModelInfo, ModelSpec};
|
|
||||||
use reqwest::Client;
|
|
||||||
use serde::Deserialize;
|
|
||||||
|
|
||||||
pub struct MistralRsHarness {
|
|
||||||
endpoint: String,
|
|
||||||
systemd_unit: Option<String>,
|
|
||||||
client: Client,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl MistralRsHarness {
|
|
||||||
pub fn new(endpoint: String, systemd_unit: Option<String>) -> Self {
|
|
||||||
Self {
|
|
||||||
endpoint,
|
|
||||||
systemd_unit,
|
|
||||||
client: Client::builder()
|
|
||||||
.timeout(std::time::Duration::from_secs(30))
|
|
||||||
.build()
|
|
||||||
.expect("failed to build HTTP client"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Response from mistral.rs `GET /v1/models`.
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
struct ModelsResponse {
|
|
||||||
data: Vec<ModelEntry>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
struct ModelEntry {
|
|
||||||
id: String,
|
|
||||||
#[serde(default)]
|
|
||||||
status: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl Harness for MistralRsHarness {
|
|
||||||
fn name(&self) -> &str {
|
|
||||||
"mistralrs"
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn start(&self, _config: &HarnessConfig) -> Result<()> {
|
|
||||||
let Some(unit) = &self.systemd_unit else {
|
|
||||||
anyhow::bail!("no systemd unit configured for mistralrs harness");
|
|
||||||
};
|
|
||||||
|
|
||||||
let output = tokio::process::Command::new("systemctl")
|
|
||||||
.args(["start", unit])
|
|
||||||
.output()
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
if !output.status.success() {
|
|
||||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
|
||||||
anyhow::bail!("systemctl start {unit} failed: {stderr}");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for the health endpoint to respond (up to 30s).
|
|
||||||
let url = format!("{}/health", self.endpoint);
|
|
||||||
for _ in 0..30 {
|
|
||||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
|
||||||
if self.client.get(&url).send().await.is_ok() {
|
|
||||||
tracing::info!(unit, "mistralrs started and healthy");
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
anyhow::bail!("mistralrs started but health endpoint did not respond within 30s");
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn stop(&self) -> Result<()> {
|
|
||||||
let Some(unit) = &self.systemd_unit else {
|
|
||||||
anyhow::bail!("no systemd unit configured for mistralrs harness");
|
|
||||||
};
|
|
||||||
|
|
||||||
let output = tokio::process::Command::new("systemctl")
|
|
||||||
.args(["stop", unit])
|
|
||||||
.output()
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
if !output.status.success() {
|
|
||||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
|
||||||
anyhow::bail!("systemctl stop {unit} failed: {stderr}");
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn health(&self) -> HarnessHealth {
|
|
||||||
let url = format!("{}/health", self.endpoint);
|
|
||||||
let running = self.client.get(&url).send().await.is_ok();
|
|
||||||
HarnessHealth {
|
|
||||||
name: "mistralrs".into(),
|
|
||||||
running,
|
|
||||||
uptime_secs: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn list_models(&self) -> Result<Vec<ModelInfo>> {
|
|
||||||
let url = format!("{}/v1/models", self.endpoint);
|
|
||||||
let resp = self.client.get(&url).send().await?;
|
|
||||||
|
|
||||||
if !resp.status().is_success() {
|
|
||||||
anyhow::bail!("GET /v1/models returned {}", resp.status());
|
|
||||||
}
|
|
||||||
|
|
||||||
let models_resp: ModelsResponse = resp.json().await?;
|
|
||||||
Ok(models_resp
|
|
||||||
.data
|
|
||||||
.into_iter()
|
|
||||||
.map(|m| ModelInfo {
|
|
||||||
id: m.id,
|
|
||||||
harness: "mistralrs".into(),
|
|
||||||
status: m.status.unwrap_or_else(|| "loaded".into()),
|
|
||||||
devices: vec![],
|
|
||||||
vram_used_mb: None,
|
|
||||||
})
|
|
||||||
.collect())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn load_model(&self, spec: &ModelSpec) -> Result<()> {
|
|
||||||
let url = format!("{}/v1/models/reload", self.endpoint);
|
|
||||||
let resp = self
|
|
||||||
.client
|
|
||||||
.post(&url)
|
|
||||||
.json(&serde_json::json!({ "model_id": spec.model_id }))
|
|
||||||
.send()
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
if !resp.status().is_success() {
|
|
||||||
let body = resp.text().await.unwrap_or_default();
|
|
||||||
anyhow::bail!("POST /v1/models/reload failed: {body}");
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn unload_model(&self, model_id: &str) -> Result<()> {
|
|
||||||
let url = format!("{}/v1/models/unload", self.endpoint);
|
|
||||||
let resp = self
|
|
||||||
.client
|
|
||||||
.post(&url)
|
|
||||||
.json(&serde_json::json!({ "model_id": model_id }))
|
|
||||||
.send()
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
if !resp.status().is_success() {
|
|
||||||
let body = resp.text().await.unwrap_or_default();
|
|
||||||
anyhow::bail!("POST /v1/models/unload failed: {body}");
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn inference_endpoint(&self, _model_id: &str) -> Option<String> {
|
|
||||||
// mistral.rs routes internally by model name in the request body,
|
|
||||||
// so the inference endpoint is always the base URL.
|
|
||||||
Some(self.endpoint.clone())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,15 +1,24 @@
|
|||||||
//! Harness registry — maps harness names to trait implementations.
|
//! Harness registry — maps harness names to trait implementations.
|
||||||
|
|
||||||
pub mod llamacpp;
|
pub mod arch;
|
||||||
pub mod mistralrs;
|
pub mod candle;
|
||||||
|
pub mod tp;
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use cortex_core::harness::{Harness, HarnessConfig, ModelInfo, ModelSpec};
|
use cortex_core::harness::{Harness, HarnessConfig, ModelInfo, ModelSpec};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
/// Registry of available harness implementations.
|
/// Registry of available harness implementations.
|
||||||
|
///
|
||||||
|
/// Holds an `Arc<dyn Harness>` per harness for generic lifecycle dispatch
|
||||||
|
/// (load/unload/list_models). When a candle harness is registered, a typed
|
||||||
|
/// `Arc<CandleHarness>` is also cached so inference routes can bypass the
|
||||||
|
/// dyn-Trait dispatch and reach harness-specific methods (chat completion,
|
||||||
|
/// streaming, etc.).
|
||||||
pub struct HarnessRegistry {
|
pub struct HarnessRegistry {
|
||||||
harnesses: HashMap<String, Box<dyn Harness>>,
|
harnesses: HashMap<String, Arc<dyn Harness>>,
|
||||||
|
candle: Option<Arc<candle::CandleHarness>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for HarnessRegistry {
|
impl Default for HarnessRegistry {
|
||||||
@@ -22,10 +31,11 @@ impl HarnessRegistry {
|
|||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
harnesses: HashMap::new(),
|
harnesses: HashMap::new(),
|
||||||
|
candle: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn register(&mut self, harness: Box<dyn Harness>) {
|
pub fn register(&mut self, harness: Arc<dyn Harness>) {
|
||||||
self.harnesses.insert(harness.name().to_string(), harness);
|
self.harnesses.insert(harness.name().to_string(), harness);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -34,6 +44,12 @@ impl HarnessRegistry {
|
|||||||
self.harnesses.keys().cloned().collect()
|
self.harnesses.keys().cloned().collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Typed handle to the candle harness, if registered. Used by inference
|
||||||
|
/// routes that need methods beyond the `Harness` trait surface.
|
||||||
|
pub fn candle(&self) -> Option<Arc<candle::CandleHarness>> {
|
||||||
|
self.candle.clone()
|
||||||
|
}
|
||||||
|
|
||||||
/// List models from all registered harnesses.
|
/// List models from all registered harnesses.
|
||||||
pub async fn list_all_models(&self) -> Result<Vec<ModelInfo>> {
|
pub async fn list_all_models(&self) -> Result<Vec<ModelInfo>> {
|
||||||
let mut all = Vec::new();
|
let mut all = Vec::new();
|
||||||
@@ -81,19 +97,25 @@ impl HarnessRegistry {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Build a registry from harness configs.
|
/// Build a registry from harness configs.
|
||||||
pub fn from_configs(configs: &[HarnessConfig]) -> Self {
|
///
|
||||||
|
/// `bind_url` is the URL where this neuron serves inference (its own
|
||||||
|
/// listen address). In-process harnesses (currently the only kind)
|
||||||
|
/// return this URL from `inference_endpoint`.
|
||||||
|
pub fn from_configs(
|
||||||
|
configs: &[HarnessConfig],
|
||||||
|
bind_url: &str,
|
||||||
|
settings: &crate::config::HarnessSettings,
|
||||||
|
) -> Self {
|
||||||
let mut registry = Self::new();
|
let mut registry = Self::new();
|
||||||
for config in configs {
|
for config in configs {
|
||||||
match config.name.as_str() {
|
match config.name.as_str() {
|
||||||
"mistralrs" => {
|
"candle" => {
|
||||||
if let Some(endpoint) = &config.endpoint {
|
let harness = Arc::new(candle::CandleHarness::new(
|
||||||
registry.register(Box::new(mistralrs::MistralRsHarness::new(
|
bind_url.to_string(),
|
||||||
endpoint.clone(),
|
settings.candle.hf_cache.clone(),
|
||||||
config.systemd_unit.clone(),
|
));
|
||||||
)));
|
registry.candle = Some(Arc::clone(&harness));
|
||||||
} else {
|
registry.harnesses.insert("candle".into(), harness);
|
||||||
tracing::warn!("mistralrs harness missing endpoint, skipping");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
other => {
|
other => {
|
||||||
tracing::warn!(harness = other, "unknown harness type, skipping");
|
tracing::warn!(harness = other, "unknown harness type, skipping");
|
||||||
|
|||||||
119
crates/neuron/src/harness/tp/all_reduce.rs
Normal file
119
crates/neuron/src/harness/tp/all_reduce.rs
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
//! `AllReduce` as a candle `CustomOp1` — the bridge between candle's
|
||||||
|
//! `Tensor` graph and `cudarc::nccl::Comm::all_reduce`.
|
||||||
|
//!
|
||||||
|
//! Ported from the canonical
|
||||||
|
//! `candle-examples/examples/llama_multiprocess/model.rs` pattern.
|
||||||
|
//! Row-parallel layers apply this op after their local matmul to sum
|
||||||
|
//! partial outputs across NCCL ranks.
|
||||||
|
//!
|
||||||
|
//! Available only under `--features cuda`; on CPU builds this module
|
||||||
|
//! is empty and row-parallel layers degenerate to local matmul only
|
||||||
|
//! (useful for compile-checking the model code; correctness requires
|
||||||
|
//! cuda).
|
||||||
|
//!
|
||||||
|
//! Thread-safety caveat: NCCL communicators are technically only
|
||||||
|
//! safe to use from a single thread at a time
|
||||||
|
//! (https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html).
|
||||||
|
//! We hold the `AllReduce` behind an `Arc<Comm>` and only issue ops
|
||||||
|
//! against it from the dedicated `spawn_blocking` thread the inference
|
||||||
|
//! pipeline already uses for candle's forward passes.
|
||||||
|
|
||||||
|
#![cfg(feature = "cuda")]
|
||||||
|
|
||||||
|
use candle_core::backend::BackendStorage;
|
||||||
|
use candle_core::{CpuStorage, CudaStorage, CustomOp1, DType, Layout, Result, Shape};
|
||||||
|
use cudarc::nccl::{Comm, ReduceOp};
|
||||||
|
use half::{bf16, f16};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
/// Wraps an NCCL `Comm` so it can be plugged into a candle forward
|
||||||
|
/// graph as a custom op. Each row-parallel layer holds one of these.
|
||||||
|
pub struct AllReduce {
|
||||||
|
comm: Arc<Comm>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// SAFETY: `Comm` contains a raw `ncclComm_t` pointer; NCCL's docs note
|
||||||
|
// that issuing ops against one comm from multiple threads concurrently
|
||||||
|
// is unsafe. We serialise via the single spawn_blocking thread that
|
||||||
|
// drives the model's forward pass. The Send/Sync impl is necessary
|
||||||
|
// because candle's CustomOp1 trait bounds require it; the correctness
|
||||||
|
// invariant is enforced at the call site, not the type level.
|
||||||
|
unsafe impl Send for AllReduce {}
|
||||||
|
unsafe impl Sync for AllReduce {}
|
||||||
|
|
||||||
|
impl AllReduce {
|
||||||
|
pub fn new(comm: Arc<Comm>) -> Self {
|
||||||
|
Self { comm }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn comm(&self) -> &Arc<Comm> {
|
||||||
|
&self.comm
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CustomOp1 for AllReduce {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
"neuron.tp.all_reduce"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
|
||||||
|
candle_core::bail!("AllReduce custom-op invoked on CPU storage; TP requires CUDA")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cuda_fwd(&self, s: &CudaStorage, l: &Layout) -> Result<(CudaStorage, Shape)> {
|
||||||
|
// Reject non-contiguous inputs explicitly — copying them
|
||||||
|
// server-side would mask shape bugs (a TP layer feeding a
|
||||||
|
// strided activation into all_reduce is almost certainly a
|
||||||
|
// model construction error).
|
||||||
|
fn require_contiguous<T: cudarc::driver::DeviceRepr>(
|
||||||
|
slice: &cudarc::driver::CudaSlice<T>,
|
||||||
|
l: &Layout,
|
||||||
|
) -> Result<()> {
|
||||||
|
match l.contiguous_offsets() {
|
||||||
|
Some((0, n)) if n == slice.len() => Ok(()),
|
||||||
|
_ => candle_core::bail!(
|
||||||
|
"AllReduce input is non-contiguous: layout={:?}, slice_len={}",
|
||||||
|
l,
|
||||||
|
slice.len()
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let elem_count = l.shape().elem_count();
|
||||||
|
let dev = s.device().clone();
|
||||||
|
|
||||||
|
let out = match s.dtype() {
|
||||||
|
DType::BF16 => {
|
||||||
|
let src = s.as_cuda_slice::<bf16>()?;
|
||||||
|
require_contiguous(src, l)?;
|
||||||
|
let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }?;
|
||||||
|
self.comm
|
||||||
|
.all_reduce(src, &mut dst, &ReduceOp::Sum)
|
||||||
|
.map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce bf16: {e:?}")))?;
|
||||||
|
CudaStorage::wrap_cuda_slice(dst, dev)
|
||||||
|
}
|
||||||
|
DType::F16 => {
|
||||||
|
let src = s.as_cuda_slice::<f16>()?;
|
||||||
|
require_contiguous(src, l)?;
|
||||||
|
let mut dst = unsafe { dev.alloc::<f16>(elem_count) }?;
|
||||||
|
self.comm
|
||||||
|
.all_reduce(src, &mut dst, &ReduceOp::Sum)
|
||||||
|
.map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce f16: {e:?}")))?;
|
||||||
|
CudaStorage::wrap_cuda_slice(dst, dev)
|
||||||
|
}
|
||||||
|
DType::F32 => {
|
||||||
|
let src = s.as_cuda_slice::<f32>()?;
|
||||||
|
require_contiguous(src, l)?;
|
||||||
|
let mut dst = unsafe { dev.alloc::<f32>(elem_count) }?;
|
||||||
|
self.comm
|
||||||
|
.all_reduce(src, &mut dst, &ReduceOp::Sum)
|
||||||
|
.map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce f32: {e:?}")))?;
|
||||||
|
CudaStorage::wrap_cuda_slice(dst, dev)
|
||||||
|
}
|
||||||
|
dtype => candle_core::bail!(
|
||||||
|
"AllReduce: unsupported dtype {dtype:?}; TP path expects bf16/f16/f32"
|
||||||
|
),
|
||||||
|
};
|
||||||
|
Ok((out, l.shape().clone()))
|
||||||
|
}
|
||||||
|
}
|
||||||
213
crates/neuron/src/harness/tp/fused_load.rs
Normal file
213
crates/neuron/src/harness/tp/fused_load.rs
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
//! Direct safetensors readers for fused-region weight tensors.
|
||||||
|
//!
|
||||||
|
//! Qwen3-Next's `in_proj_qkv` and `conv1d` weights are *fused* —
|
||||||
|
//! three regions stored sequentially along dim 0 (`[key_q, key_k,
|
||||||
|
//! value]`). The per-rank shard for each region has unequal size
|
||||||
|
//! (`key_dim/ws` vs `value_dim/ws`), so candle's `ShardedSafeTensors`
|
||||||
|
//! built-in `Shard { dim, rank, world_size }` (uniform split) doesn't
|
||||||
|
//! map to the right slices.
|
||||||
|
//!
|
||||||
|
//! The previous approach loaded the full fused tensor onto the device,
|
||||||
|
//! `narrow`ed the three regions, and `Tensor::cat(...).contiguous()`'d
|
||||||
|
//! the per-rank slice. That left ~100 MB of transient device memory
|
||||||
|
//! per linear-attention layer — 48 layers × 100 MB = ~4.8 GB of
|
||||||
|
//! allocator pressure during load, enough to trigger fragmentation
|
||||||
|
//! OOM on tight-VRAM consumer GPUs.
|
||||||
|
//!
|
||||||
|
//! This module reads the three per-rank byte ranges *directly from
|
||||||
|
//! the safetensors mmap* (host-side), concatenates them into a single
|
||||||
|
//! contiguous byte buffer, and uploads as one device allocation. No
|
||||||
|
//! full-tensor device materialisation.
|
||||||
|
|
||||||
|
use anyhow::{Context, Result, bail};
|
||||||
|
use candle_core::safetensors::MmapedSafetensors;
|
||||||
|
use candle_core::{DType, Device, Tensor};
|
||||||
|
|
||||||
|
/// Read a 2D fused-QKV tensor `[conv_dim, hidden_size]` and return
|
||||||
|
/// this rank's per-region slice as a `[per_rank_conv_dim, hidden_size]`
|
||||||
|
/// device tensor.
|
||||||
|
///
|
||||||
|
/// `tensor_name` must be the fully-qualified safetensors key (e.g.
|
||||||
|
/// `"model.language_model.layers.5.linear_attn.in_proj_qkv.weight"`).
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub fn load_fused_qkv_2d(
|
||||||
|
mmap: &MmapedSafetensors,
|
||||||
|
tensor_name: &str,
|
||||||
|
hidden_size: usize,
|
||||||
|
key_dim: usize,
|
||||||
|
value_dim: usize,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
target_dtype: DType,
|
||||||
|
device: &Device,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let ws = world_size as usize;
|
||||||
|
let r = rank as usize;
|
||||||
|
if !key_dim.is_multiple_of(ws) || !value_dim.is_multiple_of(ws) {
|
||||||
|
bail!(
|
||||||
|
"fused qkv shard: key_dim ({key_dim}) and value_dim ({value_dim}) \
|
||||||
|
must each be divisible by world_size ({ws})"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let per_rank_key = key_dim / ws;
|
||||||
|
let per_rank_value = value_dim / ws;
|
||||||
|
let per_rank_conv_dim = per_rank_key * 2 + per_rank_value;
|
||||||
|
|
||||||
|
let view = mmap
|
||||||
|
.get(tensor_name)
|
||||||
|
.with_context(|| format!("mmap.get('{tensor_name}') for fused qkv 2D"))?;
|
||||||
|
let view_dtype: DType = view
|
||||||
|
.dtype()
|
||||||
|
.try_into()
|
||||||
|
.with_context(|| format!("safetensors dtype unsupported for '{tensor_name}'"))?;
|
||||||
|
|
||||||
|
let shape = view.shape();
|
||||||
|
if shape.len() != 2 {
|
||||||
|
bail!(
|
||||||
|
"fused qkv tensor '{tensor_name}' has shape {shape:?}, expected 2D \
|
||||||
|
[conv_dim, hidden_size]"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let conv_dim = key_dim * 2 + value_dim;
|
||||||
|
if shape[0] != conv_dim || shape[1] != hidden_size {
|
||||||
|
bail!(
|
||||||
|
"fused qkv tensor '{tensor_name}' shape {shape:?} \
|
||||||
|
doesn't match expected [{conv_dim}, {hidden_size}]"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let q_bytes = slice_dim0_bytes(&view, r * per_rank_key, per_rank_key, tensor_name, "q")?;
|
||||||
|
let k_bytes = slice_dim0_bytes(
|
||||||
|
&view,
|
||||||
|
key_dim + r * per_rank_key,
|
||||||
|
per_rank_key,
|
||||||
|
tensor_name,
|
||||||
|
"k",
|
||||||
|
)?;
|
||||||
|
let v_bytes = slice_dim0_bytes(
|
||||||
|
&view,
|
||||||
|
2 * key_dim + r * per_rank_value,
|
||||||
|
per_rank_value,
|
||||||
|
tensor_name,
|
||||||
|
"v",
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let mut bytes = Vec::with_capacity(q_bytes.len() + k_bytes.len() + v_bytes.len());
|
||||||
|
bytes.extend_from_slice(&q_bytes);
|
||||||
|
bytes.extend_from_slice(&k_bytes);
|
||||||
|
bytes.extend_from_slice(&v_bytes);
|
||||||
|
|
||||||
|
let tensor = Tensor::from_raw_buffer(
|
||||||
|
&bytes,
|
||||||
|
view_dtype,
|
||||||
|
&[per_rank_conv_dim, hidden_size],
|
||||||
|
device,
|
||||||
|
)
|
||||||
|
.with_context(|| format!("Tensor::from_raw_buffer for per-rank fused qkv '{tensor_name}'"))?;
|
||||||
|
tensor
|
||||||
|
.to_dtype(target_dtype)
|
||||||
|
.with_context(|| format!("cast '{tensor_name}' to {target_dtype:?}"))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Read a 3D fused-QKV tensor `[conv_dim, 1, kernel_size]` (the
|
||||||
|
/// depthwise conv1d weight) and return this rank's per-region slice
|
||||||
|
/// as a `[per_rank_conv_dim, 1, kernel_size]` device tensor.
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub fn load_fused_qkv_3d(
|
||||||
|
mmap: &MmapedSafetensors,
|
||||||
|
tensor_name: &str,
|
||||||
|
mid: usize,
|
||||||
|
kernel_size: usize,
|
||||||
|
key_dim: usize,
|
||||||
|
value_dim: usize,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
target_dtype: DType,
|
||||||
|
device: &Device,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let ws = world_size as usize;
|
||||||
|
let r = rank as usize;
|
||||||
|
if !key_dim.is_multiple_of(ws) || !value_dim.is_multiple_of(ws) {
|
||||||
|
bail!(
|
||||||
|
"fused conv shard: key_dim ({key_dim}) and value_dim ({value_dim}) \
|
||||||
|
must each be divisible by world_size ({ws})"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let per_rank_key = key_dim / ws;
|
||||||
|
let per_rank_value = value_dim / ws;
|
||||||
|
let per_rank_conv_dim = per_rank_key * 2 + per_rank_value;
|
||||||
|
|
||||||
|
let view = mmap
|
||||||
|
.get(tensor_name)
|
||||||
|
.with_context(|| format!("mmap.get('{tensor_name}') for fused qkv 3D"))?;
|
||||||
|
let view_dtype: DType = view
|
||||||
|
.dtype()
|
||||||
|
.try_into()
|
||||||
|
.with_context(|| format!("safetensors dtype unsupported for '{tensor_name}'"))?;
|
||||||
|
|
||||||
|
let shape = view.shape();
|
||||||
|
if shape.len() != 3 {
|
||||||
|
bail!(
|
||||||
|
"fused conv tensor '{tensor_name}' has shape {shape:?}, expected 3D \
|
||||||
|
[conv_dim, mid, kernel_size]"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let conv_dim = key_dim * 2 + value_dim;
|
||||||
|
if shape[0] != conv_dim || shape[1] != mid || shape[2] != kernel_size {
|
||||||
|
bail!(
|
||||||
|
"fused conv tensor '{tensor_name}' shape {shape:?} \
|
||||||
|
doesn't match expected [{conv_dim}, {mid}, {kernel_size}]"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let q_bytes = slice_dim0_bytes(&view, r * per_rank_key, per_rank_key, tensor_name, "q")?;
|
||||||
|
let k_bytes = slice_dim0_bytes(
|
||||||
|
&view,
|
||||||
|
key_dim + r * per_rank_key,
|
||||||
|
per_rank_key,
|
||||||
|
tensor_name,
|
||||||
|
"k",
|
||||||
|
)?;
|
||||||
|
let v_bytes = slice_dim0_bytes(
|
||||||
|
&view,
|
||||||
|
2 * key_dim + r * per_rank_value,
|
||||||
|
per_rank_value,
|
||||||
|
tensor_name,
|
||||||
|
"v",
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let mut bytes = Vec::with_capacity(q_bytes.len() + k_bytes.len() + v_bytes.len());
|
||||||
|
bytes.extend_from_slice(&q_bytes);
|
||||||
|
bytes.extend_from_slice(&k_bytes);
|
||||||
|
bytes.extend_from_slice(&v_bytes);
|
||||||
|
|
||||||
|
let tensor = Tensor::from_raw_buffer(
|
||||||
|
&bytes,
|
||||||
|
view_dtype,
|
||||||
|
&[per_rank_conv_dim, mid, kernel_size],
|
||||||
|
device,
|
||||||
|
)
|
||||||
|
.with_context(|| format!("Tensor::from_raw_buffer for per-rank fused conv '{tensor_name}'"))?;
|
||||||
|
tensor
|
||||||
|
.to_dtype(target_dtype)
|
||||||
|
.with_context(|| format!("cast '{tensor_name}' to {target_dtype:?}"))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Read `len` consecutive rows along dim 0 starting at `start` from
|
||||||
|
/// the safetensors view, returning the raw bytes. Wraps the same
|
||||||
|
/// `view.slice(start..stop)` machinery that candle's
|
||||||
|
/// `ShardedSafeTensors::get` uses internally.
|
||||||
|
fn slice_dim0_bytes(
|
||||||
|
view: &safetensors::tensor::TensorView<'_>,
|
||||||
|
start: usize,
|
||||||
|
len: usize,
|
||||||
|
tensor_name: &str,
|
||||||
|
region: &str,
|
||||||
|
) -> Result<Vec<u8>> {
|
||||||
|
use safetensors::slice::IndexOp;
|
||||||
|
let stop = start + len;
|
||||||
|
let iter = view.slice(start..stop).map_err(|e| {
|
||||||
|
anyhow::anyhow!("slice '{tensor_name}' region {region} ({start}..{stop}): {e:?}")
|
||||||
|
})?;
|
||||||
|
Ok(iter.into_iter().flatten().copied().collect())
|
||||||
|
}
|
||||||
791
crates/neuron/src/harness/tp/mod.rs
Normal file
791
crates/neuron/src/harness/tp/mod.rs
Normal file
@@ -0,0 +1,791 @@
|
|||||||
|
//! Tensor-parallel inference plumbing.
|
||||||
|
//!
|
||||||
|
//! The leader process (the neuron daemon proper) drives one
|
||||||
|
//! subprocess per non-zero NCCL rank — `tokio::process::Command` on
|
||||||
|
//! `/proc/self/exe --worker --rank N --tp-size N --cuda-device N` —
|
||||||
|
//! and talks to each over a newline-delimited JSON RPC channel on
|
||||||
|
//! the worker's stdin/stdout (see `rpc.rs`).
|
||||||
|
//!
|
||||||
|
//! Sub-staging:
|
||||||
|
//!
|
||||||
|
//! - **7a-i (this commit):** process lifecycle. `WorkerPool::spawn`
|
||||||
|
//! forks N workers; `ping` round-trips every worker to confirm
|
||||||
|
//! they're alive; `shutdown` cleanly drains and reaps. `Init` /
|
||||||
|
//! `NcclSanityCheck` are stubbed.
|
||||||
|
//! - **7a-ii:** real NCCL `Comm` setup via `Init`, sanity check via
|
||||||
|
//! `NcclSanityCheck`. CUDA-gated.
|
||||||
|
//! - **7b:** TP-aware Qwen3 inference dispatched through the pool.
|
||||||
|
//! - **7c:** crash detection, streaming SSE, graceful unload.
|
||||||
|
|
||||||
|
pub mod all_reduce;
|
||||||
|
pub mod fused_load;
|
||||||
|
pub mod nccl_state;
|
||||||
|
pub mod rpc;
|
||||||
|
pub mod tp_linear;
|
||||||
|
pub mod tp_qwen3;
|
||||||
|
pub mod tp_qwen3_5;
|
||||||
|
pub mod worker;
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
use std::process::Stdio;
|
||||||
|
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, Lines};
|
||||||
|
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
|
||||||
|
|
||||||
|
use rpc::{WorkerRequest, WorkerResponse};
|
||||||
|
|
||||||
|
/// Leader-side handle for any TP-loaded model. The pool's
|
||||||
|
/// `load_dense_shard` dispatches on `config.json#/model_type` to build
|
||||||
|
/// the right variant; downstream callers (the harness's
|
||||||
|
/// `chat_completion_tp` path, `generate_step`, `clear_kv_cache`,
|
||||||
|
/// `unload_model`) all hold this enum and let the variant dispatch
|
||||||
|
/// determine the concrete forward.
|
||||||
|
///
|
||||||
|
/// Variants gated on `cuda` because the underlying TP models hold
|
||||||
|
/// `Arc<cudarc::nccl::Comm>` references — irrelevant on CPU builds.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub enum TpLeaderModel {
|
||||||
|
Qwen3(tp_qwen3::TpQwen3ForCausalLM),
|
||||||
|
Qwen3_5(tp_qwen3_5::TpQwen3_5ForCausalLM),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
impl TpLeaderModel {
|
||||||
|
pub fn forward(
|
||||||
|
&mut self,
|
||||||
|
input: &candle_core::Tensor,
|
||||||
|
offset: usize,
|
||||||
|
) -> candle_core::Result<candle_core::Tensor> {
|
||||||
|
match self {
|
||||||
|
TpLeaderModel::Qwen3(m) => m.forward(input, offset),
|
||||||
|
TpLeaderModel::Qwen3_5(m) => m.forward(input, offset),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
match self {
|
||||||
|
TpLeaderModel::Qwen3(m) => m.clear_kv_cache(),
|
||||||
|
TpLeaderModel::Qwen3_5(m) => m.clear_kv_cache(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn device(&self) -> &candle_core::Device {
|
||||||
|
match self {
|
||||||
|
TpLeaderModel::Qwen3(m) => m.device(),
|
||||||
|
TpLeaderModel::Qwen3_5(m) => m.device(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// One worker subprocess plus its bidirectional stdio handles.
|
||||||
|
struct Worker {
|
||||||
|
rank: u32,
|
||||||
|
/// Captured so the leader can log "spawned rank N on device M" and
|
||||||
|
/// future stages can re-issue Init after a CUDA reset. Unused in
|
||||||
|
/// the Stage 7a-i RPC paths themselves.
|
||||||
|
#[allow(dead_code)]
|
||||||
|
cuda_device: u32,
|
||||||
|
child: Child,
|
||||||
|
stdin: ChildStdin,
|
||||||
|
stdout: Lines<BufReader<ChildStdout>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Worker {
|
||||||
|
/// Send a request and wait for the response. Used for sequenced
|
||||||
|
/// ops like `Ping` / `Shutdown` where the caller doesn't need to
|
||||||
|
/// overlap the worker's execution with the leader's.
|
||||||
|
async fn request(&mut self, req: &WorkerRequest) -> Result<WorkerResponse> {
|
||||||
|
self.send_only(req).await?;
|
||||||
|
self.recv_only().await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Write a request without awaiting its response. Pair with
|
||||||
|
/// `recv_only` from the caller when leader and worker need to do
|
||||||
|
/// work concurrently — e.g. during `Init`, where the leader
|
||||||
|
/// itself calls `Comm::from_rank` on rank 0 in parallel with the
|
||||||
|
/// workers, then collects `InitOk` after NCCL completes.
|
||||||
|
async fn send_only(&mut self, req: &WorkerRequest) -> Result<()> {
|
||||||
|
let mut line = serde_json::to_string(req).context("serialise WorkerRequest")?;
|
||||||
|
line.push('\n');
|
||||||
|
self.stdin
|
||||||
|
.write_all(line.as_bytes())
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("write request to rank {}", self.rank))?;
|
||||||
|
self.stdin
|
||||||
|
.flush()
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("flush stdin to rank {}", self.rank))?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn recv_only(&mut self) -> Result<WorkerResponse> {
|
||||||
|
let reply = self
|
||||||
|
.stdout
|
||||||
|
.next_line()
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("read reply from rank {}", self.rank))?
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("rank {} stdout closed before reply", self.rank))?;
|
||||||
|
serde_json::from_str(&reply)
|
||||||
|
.with_context(|| format!("parse reply from rank {}: {reply:?}", self.rank))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Drain one response from every worker, classifying each via the
|
||||||
|
/// supplied checker. Always reads from every worker — even if some
|
||||||
|
/// fail — so the next call's recv doesn't pick up stale responses
|
||||||
|
/// from this one (pipe-poisoning was the cause of the
|
||||||
|
/// "ClearKvCache: expected KvCacheCleared, got GenerateStepOk" class
|
||||||
|
/// of bugs).
|
||||||
|
///
|
||||||
|
/// Returns a vector of `rank N: detail` strings for any worker that
|
||||||
|
/// errored, expected-mismatched, or failed to respond. Caller decides
|
||||||
|
/// how to combine these with the leader's outcome.
|
||||||
|
async fn drain_workers(
|
||||||
|
workers: &mut [Worker],
|
||||||
|
mut check: impl FnMut(WorkerResponse) -> std::result::Result<(), String>,
|
||||||
|
) -> Vec<String> {
|
||||||
|
let mut errs = Vec::new();
|
||||||
|
for w in workers {
|
||||||
|
match w.recv_only().await {
|
||||||
|
Ok(resp) => {
|
||||||
|
if let Err(detail) = check(resp) {
|
||||||
|
errs.push(format!("rank {} {detail}", w.rank));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => errs.push(format!("rank {} recv: {e:#}", w.rank)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
errs
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Combine a leader's `Result<Result<T>>` (the typical
|
||||||
|
/// `spawn_blocking → JoinHandle<Result<T>>` shape) with the worker
|
||||||
|
/// drain results into a single `Result<T>`. Leader failures take
|
||||||
|
/// precedence in the error message but worker errors get appended so
|
||||||
|
/// the operator sees both halves.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
fn combine_leader_workers<T>(
|
||||||
|
leader: Result<Result<T>>,
|
||||||
|
worker_errors: Vec<String>,
|
||||||
|
op: &str,
|
||||||
|
) -> Result<T> {
|
||||||
|
match leader {
|
||||||
|
Ok(Ok(value)) => {
|
||||||
|
if worker_errors.is_empty() {
|
||||||
|
Ok(value)
|
||||||
|
} else {
|
||||||
|
anyhow::bail!(
|
||||||
|
"{op}: leader succeeded but workers failed: {}",
|
||||||
|
worker_errors.join("; ")
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(Err(e)) => {
|
||||||
|
if worker_errors.is_empty() {
|
||||||
|
Err(e.context(format!("{op}: leader forward failed")))
|
||||||
|
} else {
|
||||||
|
Err(e.context(format!(
|
||||||
|
"{op}: leader forward failed and workers also failed: {}",
|
||||||
|
worker_errors.join("; ")
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(panic_err) => {
|
||||||
|
if worker_errors.is_empty() {
|
||||||
|
Err(panic_err)
|
||||||
|
} else {
|
||||||
|
Err(panic_err.context(format!(
|
||||||
|
"{op}: leader task panicked and workers failed: {}",
|
||||||
|
worker_errors.join("; ")
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A live pool of worker subprocesses. Owns the `Child` handles so
|
||||||
|
/// dropping the pool kills the children; explicit `shutdown()` is
|
||||||
|
/// the graceful path.
|
||||||
|
pub struct WorkerPool {
|
||||||
|
world_size: u32,
|
||||||
|
workers: Vec<Worker>,
|
||||||
|
/// Path to the neuron binary used to launch workers.
|
||||||
|
#[allow(dead_code)]
|
||||||
|
exe: PathBuf,
|
||||||
|
/// Leader's own NCCL rank-0 state. Defaults to empty; populated by
|
||||||
|
/// `init_nccl()`. Held here so the leader can participate in
|
||||||
|
/// collectives (rank 0) without spawning a fourth subprocess.
|
||||||
|
leader_nccl: nccl_state::NcclState,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WorkerPool {
|
||||||
|
/// Spawn `world_size - 1` worker subprocesses. Rank 0 is the
|
||||||
|
/// leader (in-process) and is *not* spawned here — the leader
|
||||||
|
/// holds rank 0's NCCL Comm and shard in its own address space.
|
||||||
|
///
|
||||||
|
/// `binary` is the path to the neuron executable to run for each
|
||||||
|
/// worker (production passes `/proc/self/exe`; tests pass the
|
||||||
|
/// sibling-binary path from `env!("CARGO_BIN_EXE_neuron")`).
|
||||||
|
/// `cuda_devices` is one entry per rank including rank 0. Worker
|
||||||
|
/// `i` (rank `i`) gets `cuda_devices[i]` as its `--cuda-device`.
|
||||||
|
pub async fn spawn(binary: &Path, world_size: u32, cuda_devices: &[u32]) -> Result<Self> {
|
||||||
|
if world_size < 2 {
|
||||||
|
anyhow::bail!(
|
||||||
|
"WorkerPool::spawn called with world_size={world_size}; \
|
||||||
|
use the single-process path for world_size < 2"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if cuda_devices.len() as u32 != world_size {
|
||||||
|
anyhow::bail!(
|
||||||
|
"expected {world_size} cuda_devices entries, got {}",
|
||||||
|
cuda_devices.len()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let exe = binary.to_path_buf();
|
||||||
|
|
||||||
|
let mut workers = Vec::with_capacity(world_size as usize - 1);
|
||||||
|
// Rank 0 stays in-process. Spawn ranks 1..world_size.
|
||||||
|
for rank in 1..world_size {
|
||||||
|
let cuda_device = cuda_devices[rank as usize];
|
||||||
|
let mut cmd = Command::new(&exe);
|
||||||
|
cmd.arg("--worker")
|
||||||
|
.arg("--rank")
|
||||||
|
.arg(rank.to_string())
|
||||||
|
.arg("--tp-size")
|
||||||
|
.arg(world_size.to_string())
|
||||||
|
.arg("--cuda-device")
|
||||||
|
.arg(cuda_device.to_string())
|
||||||
|
.stdin(Stdio::piped())
|
||||||
|
.stdout(Stdio::piped())
|
||||||
|
// Inherit stderr so worker tracing surfaces alongside
|
||||||
|
// the leader's journalctl stream.
|
||||||
|
.stderr(Stdio::inherit())
|
||||||
|
.kill_on_drop(true);
|
||||||
|
|
||||||
|
let mut child = cmd
|
||||||
|
.spawn()
|
||||||
|
.with_context(|| format!("spawn worker rank {rank}"))?;
|
||||||
|
let stdin = child
|
||||||
|
.stdin
|
||||||
|
.take()
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("rank {rank}: no stdin handle"))?;
|
||||||
|
let stdout = child
|
||||||
|
.stdout
|
||||||
|
.take()
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("rank {rank}: no stdout handle"))?;
|
||||||
|
let stdout = BufReader::new(stdout).lines();
|
||||||
|
|
||||||
|
workers.push(Worker {
|
||||||
|
rank,
|
||||||
|
cuda_device,
|
||||||
|
child,
|
||||||
|
stdin,
|
||||||
|
stdout,
|
||||||
|
});
|
||||||
|
tracing::info!(rank, cuda_device, "spawned tp worker");
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
world_size,
|
||||||
|
workers,
|
||||||
|
exe,
|
||||||
|
leader_nccl: nccl_state::NcclState::new(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Establish the NCCL communicator across the leader (rank 0) and
|
||||||
|
/// every worker subprocess. Rendezvous is via a freshly-generated
|
||||||
|
/// `Id` broadcast over the RPC stream; the actual handshake blocks
|
||||||
|
/// inside `Comm::from_rank` until all `world_size` ranks check in.
|
||||||
|
///
|
||||||
|
/// `leader_cuda_device` is the CUDA device the leader binds rank 0
|
||||||
|
/// to — typically the first entry of the `cuda_devices` slice
|
||||||
|
/// originally passed to `spawn()`.
|
||||||
|
///
|
||||||
|
/// On the non-cuda build this immediately fails because the leader
|
||||||
|
/// can't generate an `Id` without libnccl. The same call works in
|
||||||
|
/// the worker path (returning a no-cuda error response) so the
|
||||||
|
/// failure surface is uniform.
|
||||||
|
pub async fn init_nccl(&mut self, leader_cuda_device: u32) -> Result<()> {
|
||||||
|
let comm_id = nccl_state::generate_comm_id_hex()
|
||||||
|
.map_err(|m| anyhow::anyhow!("generate NCCL id: {m}"))?;
|
||||||
|
|
||||||
|
// 1. Write Init to every worker's stdin without awaiting the
|
||||||
|
// response. Workers will parse and call Comm::from_rank
|
||||||
|
// concurrently with the leader below.
|
||||||
|
for w in &mut self.workers {
|
||||||
|
let req = WorkerRequest::Init {
|
||||||
|
comm_id: comm_id.clone(),
|
||||||
|
};
|
||||||
|
w.send_only(&req).await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Leader rank 0 calls Comm::from_rank on its own device.
|
||||||
|
// Runs on spawn_blocking because NCCL's init blocks until
|
||||||
|
// every rank has called in — that's exactly the workers
|
||||||
|
// above. The leader's NcclState is moved through the
|
||||||
|
// blocking task and returned to the pool.
|
||||||
|
let leader_cfg = worker::WorkerConfig {
|
||||||
|
rank: 0,
|
||||||
|
world_size: self.world_size,
|
||||||
|
cuda_device: leader_cuda_device,
|
||||||
|
};
|
||||||
|
let comm_id_for_leader = comm_id.clone();
|
||||||
|
// Swap out the leader's NcclState into a fresh empty one so we
|
||||||
|
// can move it into spawn_blocking; restore after the task
|
||||||
|
// returns. (NcclState isn't Clone — it owns a real NCCL Comm.)
|
||||||
|
let mut leader_state = std::mem::take(&mut self.leader_nccl);
|
||||||
|
let (returned_state, leader_resp) = tokio::task::spawn_blocking(move || {
|
||||||
|
let resp = leader_state.init(leader_cfg, &comm_id_for_leader);
|
||||||
|
(leader_state, resp)
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.context("leader NCCL init task panicked")?;
|
||||||
|
self.leader_nccl = returned_state;
|
||||||
|
match leader_resp {
|
||||||
|
rpc::WorkerResponse::InitOk => {}
|
||||||
|
rpc::WorkerResponse::Error { kind, message } => {
|
||||||
|
anyhow::bail!("leader rank 0 init failed [{kind}]: {message}");
|
||||||
|
}
|
||||||
|
other => anyhow::bail!("leader rank 0 init: unexpected {other:?}"),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Read InitOk from each worker. By now every worker has
|
||||||
|
// completed its Comm::from_rank call (NCCL released them
|
||||||
|
// when the leader joined the handshake) and is writing its
|
||||||
|
// response.
|
||||||
|
for w in &mut self.workers {
|
||||||
|
let resp = w.recv_only().await?;
|
||||||
|
match &resp {
|
||||||
|
rpc::WorkerResponse::InitOk => {}
|
||||||
|
rpc::WorkerResponse::Error { kind, message } => {
|
||||||
|
anyhow::bail!("worker rank {} init failed [{kind}]: {message}", w.rank);
|
||||||
|
}
|
||||||
|
other => anyhow::bail!(
|
||||||
|
"worker rank {} init: expected InitOk, got {other:?}",
|
||||||
|
w.rank
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tracing::info!(
|
||||||
|
world_size = self.world_size,
|
||||||
|
"NCCL communicator established across all ranks"
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate the NCCL communicator: every rank `all_reduce`s a
|
||||||
|
/// sentinel `1u32` with `ReduceOp::Sum`; the expected total is
|
||||||
|
/// `world_size`. Confirms the handshake is live, not just
|
||||||
|
/// configured.
|
||||||
|
///
|
||||||
|
/// Must be called after `init_nccl()`; before that the leader has
|
||||||
|
/// no Comm and the workers reply with `nccl_not_initialised`.
|
||||||
|
pub async fn nccl_sanity_check(&mut self) -> Result<()> {
|
||||||
|
// 1. Trigger the all_reduce on every worker (write-only).
|
||||||
|
for w in &mut self.workers {
|
||||||
|
w.send_only(&WorkerRequest::NcclSanityCheck).await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Leader's own all_reduce, in spawn_blocking. NCCL operations
|
||||||
|
// block until every rank participates.
|
||||||
|
let mut leader_state = std::mem::take(&mut self.leader_nccl);
|
||||||
|
let (returned_state, leader_resp) = tokio::task::spawn_blocking(move || {
|
||||||
|
let resp = leader_state.sanity_check();
|
||||||
|
(leader_state, resp)
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.context("leader NCCL sanity task panicked")?;
|
||||||
|
self.leader_nccl = returned_state;
|
||||||
|
|
||||||
|
let expected = self.world_size;
|
||||||
|
let leader_sum = match leader_resp {
|
||||||
|
rpc::WorkerResponse::NcclSanityResult { observed_sum } => observed_sum,
|
||||||
|
rpc::WorkerResponse::Error { kind, message } => {
|
||||||
|
anyhow::bail!("leader rank 0 sanity failed [{kind}]: {message}");
|
||||||
|
}
|
||||||
|
other => anyhow::bail!("leader rank 0 sanity: unexpected {other:?}"),
|
||||||
|
};
|
||||||
|
if leader_sum != expected {
|
||||||
|
anyhow::bail!("leader observed_sum={leader_sum}, expected {expected}");
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Read sanity result from each worker. All must match
|
||||||
|
// world_size — anything else means the collective didn't
|
||||||
|
// complete consistently across ranks.
|
||||||
|
for w in &mut self.workers {
|
||||||
|
let resp = w.recv_only().await?;
|
||||||
|
match resp {
|
||||||
|
rpc::WorkerResponse::NcclSanityResult { observed_sum }
|
||||||
|
if observed_sum == expected => {}
|
||||||
|
rpc::WorkerResponse::NcclSanityResult { observed_sum } => {
|
||||||
|
anyhow::bail!(
|
||||||
|
"worker rank {} observed_sum={observed_sum}, expected {expected}",
|
||||||
|
w.rank
|
||||||
|
);
|
||||||
|
}
|
||||||
|
rpc::WorkerResponse::Error { kind, message } => {
|
||||||
|
anyhow::bail!("worker rank {} sanity failed [{kind}]: {message}", w.rank);
|
||||||
|
}
|
||||||
|
other => anyhow::bail!("worker rank {} sanity: unexpected {other:?}", w.rank),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tracing::info!(
|
||||||
|
world_size = expected,
|
||||||
|
"NCCL sanity check OK across all ranks"
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Ping every worker and return their Pong payloads in rank order.
|
||||||
|
/// Useful right after `spawn` to confirm the lifecycle plumbing is
|
||||||
|
/// intact before kicking off any heavier work.
|
||||||
|
pub async fn ping_all(&mut self) -> Result<Vec<WorkerResponse>> {
|
||||||
|
let mut out = Vec::with_capacity(self.workers.len());
|
||||||
|
for w in &mut self.workers {
|
||||||
|
let resp = w.request(&WorkerRequest::Ping).await?;
|
||||||
|
match &resp {
|
||||||
|
WorkerResponse::Pong { rank, .. } if *rank == w.rank => {}
|
||||||
|
WorkerResponse::Pong { rank, .. } => {
|
||||||
|
anyhow::bail!("rank mismatch: expected {}, got {rank}", w.rank);
|
||||||
|
}
|
||||||
|
other => anyhow::bail!("expected Pong from rank {}, got {other:?}", w.rank),
|
||||||
|
}
|
||||||
|
out.push(resp);
|
||||||
|
}
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load this rank's shard of a dense Qwen3 model on every rank.
|
||||||
|
///
|
||||||
|
/// The leader builds rank 0's `TpQwen3ForCausalLM` directly into
|
||||||
|
/// the returned `Arc<Mutex<_>>` — workers build their rank-local
|
||||||
|
/// shards in their own address spaces and confirm via
|
||||||
|
/// `LoadDenseShardOk`. All ranks see the same `safetensors_paths`;
|
||||||
|
/// `ShardedVarBuilder` slices each tensor by rank at materialisation
|
||||||
|
/// time, so the per-rank VRAM footprint is roughly `1/world_size`
|
||||||
|
/// of the full model (plus the replicated embedding/norm/lm_head).
|
||||||
|
///
|
||||||
|
/// `leader_device` is the candle `Device` the leader's shard lives
|
||||||
|
/// on — typically `Device::new_cuda(leader_cuda_device)` matching
|
||||||
|
/// the same index passed to `init_nccl`. `dtype` is the on-device
|
||||||
|
/// element type; bf16 is the canonical Qwen3 distribution dtype.
|
||||||
|
///
|
||||||
|
/// `init_nccl` must have completed first. Bails if the leader's
|
||||||
|
/// NCCL comm isn't set up yet.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub async fn load_dense_shard(
|
||||||
|
&mut self,
|
||||||
|
model_id: &str,
|
||||||
|
config_json: &str,
|
||||||
|
safetensors_paths: &[std::path::PathBuf],
|
||||||
|
leader_device: &candle_core::Device,
|
||||||
|
dtype: candle_core::DType,
|
||||||
|
quant: Option<String>,
|
||||||
|
) -> Result<std::sync::Arc<tokio::sync::Mutex<TpLeaderModel>>> {
|
||||||
|
use candle_nn::var_builder::ShardedSafeTensors;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
|
||||||
|
// Wrap the comm in SendComm immediately so it stays Send across
|
||||||
|
// the await points in this method — bare Arc<Comm> would
|
||||||
|
// poison the async fn's Send bound (Comm's raw NCCL pointer is
|
||||||
|
// !Send). The wrapper's safety contract is satisfied by the
|
||||||
|
// pool's outer Mutex serialising callers + the spawn_blocking
|
||||||
|
// thread being the only place ops are issued.
|
||||||
|
let leader_comm =
|
||||||
|
nccl_state::SendComm(self.leader_nccl.comm().ok_or_else(|| {
|
||||||
|
anyhow::anyhow!("leader NCCL not initialised; call init_nccl first")
|
||||||
|
})?);
|
||||||
|
let world_size = self.world_size;
|
||||||
|
let safetensors_str: Vec<String> = safetensors_paths
|
||||||
|
.iter()
|
||||||
|
.map(|p| p.to_string_lossy().into_owned())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// 1. Fan out the LoadDenseShard request to every worker without
|
||||||
|
// awaiting their replies — they'll build their shards in
|
||||||
|
// parallel with the leader below.
|
||||||
|
for w in &mut self.workers {
|
||||||
|
w.send_only(&WorkerRequest::LoadDenseShard {
|
||||||
|
model_id: model_id.to_string(),
|
||||||
|
config_json: config_json.to_string(),
|
||||||
|
safetensors_paths: safetensors_str.clone(),
|
||||||
|
quant: quant.clone(),
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Build rank 0's shard on the leader. Dispatch on model_type
|
||||||
|
// — for `qwen3` we build a `TpQwen3ForCausalLM`, for
|
||||||
|
// `qwen3_5` (Qwen3-Next, Qwen3.6's architecture) we build
|
||||||
|
// `TpQwen3_5ForCausalLM`. Both end up wrapped in the
|
||||||
|
// `TpLeaderModel` enum so downstream callers don't care.
|
||||||
|
let model_type = serde_json::from_str::<serde_json::Value>(config_json)
|
||||||
|
.ok()
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|v| v.get("model_type"))
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or("")
|
||||||
|
.to_string();
|
||||||
|
let paths_for_leader: Vec<std::path::PathBuf> = safetensors_paths.to_vec();
|
||||||
|
let device_for_leader = leader_device.clone();
|
||||||
|
let comm_for_leader = leader_comm;
|
||||||
|
let model_id_for_log = model_id.to_string();
|
||||||
|
let config_json_for_leader = config_json.to_string();
|
||||||
|
let quant_for_leader = quant.clone();
|
||||||
|
|
||||||
|
let leader_model = tokio::task::spawn_blocking(move || -> Result<TpLeaderModel> {
|
||||||
|
// SAFETY: same invariant as the single-GPU dense path —
|
||||||
|
// the HF cache files are treated as immutable while the
|
||||||
|
// mmap is held.
|
||||||
|
let vb = unsafe {
|
||||||
|
ShardedSafeTensors::var_builder(&paths_for_leader, dtype, &device_for_leader)
|
||||||
|
.context("build ShardedVarBuilder over safetensors")?
|
||||||
|
};
|
||||||
|
// SAFETY: as above — the HF cache files are immutable.
|
||||||
|
let mmap = unsafe {
|
||||||
|
candle_core::safetensors::MmapedSafetensors::multi(&paths_for_leader)
|
||||||
|
.context("build MmapedSafetensors for leader load")?
|
||||||
|
};
|
||||||
|
let comm = comm_for_leader.into_inner();
|
||||||
|
let loaded = match model_type.as_str() {
|
||||||
|
"qwen3" => {
|
||||||
|
let cfg: super::tp::tp_qwen3::Config = serde_json::from_str(&config_json_for_leader)
|
||||||
|
.context("parse Qwen3 Config JSON for leader load")?;
|
||||||
|
TpLeaderModel::Qwen3(super::tp::tp_qwen3::TpQwen3ForCausalLM::load(
|
||||||
|
&cfg, &vb, 0, world_size, comm,
|
||||||
|
)?)
|
||||||
|
}
|
||||||
|
"qwen3_5" => {
|
||||||
|
let cfg: super::tp::tp_qwen3_5::Config =
|
||||||
|
serde_json::from_str(&config_json_for_leader)
|
||||||
|
.context("parse Qwen3-Next Config JSON for leader load")?;
|
||||||
|
let quant_dtype =
|
||||||
|
super::tp::worker::parse_quant_string(quant_for_leader.as_deref())?;
|
||||||
|
TpLeaderModel::Qwen3_5(super::tp::tp_qwen3_5::TpQwen3_5ForCausalLM::load(
|
||||||
|
cfg,
|
||||||
|
&vb,
|
||||||
|
&mmap,
|
||||||
|
0,
|
||||||
|
world_size,
|
||||||
|
comm,
|
||||||
|
quant_dtype,
|
||||||
|
)?)
|
||||||
|
}
|
||||||
|
other => anyhow::bail!(
|
||||||
|
"TP dispatch: unsupported model_type '{other}' on leader (supported: qwen3, qwen3_5)"
|
||||||
|
),
|
||||||
|
};
|
||||||
|
tracing::info!(rank = 0, model = %model_id_for_log, model_type = %model_type, "loaded TP shard (leader)");
|
||||||
|
Ok(loaded)
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.context("leader load task panicked")??;
|
||||||
|
|
||||||
|
// 3. Collect worker confirmations. Anything other than
|
||||||
|
// LoadDenseShardOk aborts the whole load — the leader's
|
||||||
|
// already-loaded shard drops when this fn returns Err.
|
||||||
|
for w in &mut self.workers {
|
||||||
|
let resp = w.recv_only().await?;
|
||||||
|
match resp {
|
||||||
|
WorkerResponse::LoadDenseShardOk => {}
|
||||||
|
WorkerResponse::Error { kind, message } => {
|
||||||
|
anyhow::bail!("worker rank {} LoadDenseShard [{kind}]: {message}", w.rank)
|
||||||
|
}
|
||||||
|
other => anyhow::bail!(
|
||||||
|
"worker rank {} LoadDenseShard: expected LoadDenseShardOk, got {other:?}",
|
||||||
|
w.rank
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Arc::new(Mutex::new(leader_model)))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run one forward step across every rank. The leader's forward
|
||||||
|
/// returns the last-position logits as a candle Tensor on the
|
||||||
|
/// leader's device; the caller does sampling out-of-band. Workers
|
||||||
|
/// run their own forwards (the AllReduce inside row-parallel layers
|
||||||
|
/// is what lets the leader's collective complete) and reply with
|
||||||
|
/// `GenerateStepOk` — they do not ship logits over the wire.
|
||||||
|
///
|
||||||
|
/// `tokens` is the input for this step (prompt for prefill, the
|
||||||
|
/// previously-sampled token for decode). `offset` is the KV-cache
|
||||||
|
/// position before this step.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub async fn generate_step(
|
||||||
|
&mut self,
|
||||||
|
model_id: &str,
|
||||||
|
leader_model: std::sync::Arc<tokio::sync::Mutex<TpLeaderModel>>,
|
||||||
|
tokens: Vec<u32>,
|
||||||
|
offset: usize,
|
||||||
|
) -> Result<candle_core::Tensor> {
|
||||||
|
let step_start = std::time::Instant::now();
|
||||||
|
let tokens_len = tokens.len();
|
||||||
|
tracing::debug!(
|
||||||
|
model = %model_id,
|
||||||
|
tokens = tokens_len,
|
||||||
|
offset,
|
||||||
|
"WorkerPool::generate_step: fan-out"
|
||||||
|
);
|
||||||
|
// 1. Fan-out to workers.
|
||||||
|
for w in &mut self.workers {
|
||||||
|
w.send_only(&WorkerRequest::GenerateStep {
|
||||||
|
model_id: model_id.to_string(),
|
||||||
|
tokens: tokens.clone(),
|
||||||
|
offset,
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Leader's forward in spawn_blocking. The AllReduce CustomOps
|
||||||
|
// inside the row-parallel layers block until every worker's
|
||||||
|
// forward issues the matching collective.
|
||||||
|
let leader_start = std::time::Instant::now();
|
||||||
|
let leader_result = tokio::task::spawn_blocking(move || -> Result<candle_core::Tensor> {
|
||||||
|
let mut model = leader_model.blocking_lock();
|
||||||
|
let device = model.device().clone();
|
||||||
|
let input = candle_core::Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
|
||||||
|
// ForCausalLM::forward returns [B, 1, V] — squeeze both
|
||||||
|
// leading dims to the rank-1 vocab logits the sampler wants.
|
||||||
|
let logits = model.forward(&input, offset)?.squeeze(0)?.squeeze(0)?;
|
||||||
|
Ok(logits)
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.context("leader forward task panicked");
|
||||||
|
let leader_ok = matches!(leader_result, Ok(Ok(_)));
|
||||||
|
tracing::debug!(
|
||||||
|
model = %model_id,
|
||||||
|
tokens = tokens_len,
|
||||||
|
leader_ms = leader_start.elapsed().as_millis(),
|
||||||
|
leader_ok,
|
||||||
|
"WorkerPool::generate_step: leader forward returned"
|
||||||
|
);
|
||||||
|
|
||||||
|
// 3. ALWAYS drain worker responses, regardless of whether the
|
||||||
|
// leader succeeded. Skipping this on the leader's error
|
||||||
|
// path leaves stale GenerateStepOk replies in the worker
|
||||||
|
// pipes that poison the NEXT request's recv (was seeing
|
||||||
|
// "ClearKvCache: expected KvCacheCleared, got
|
||||||
|
// GenerateStepOk" the call after any forward-time failure).
|
||||||
|
let drain_start = std::time::Instant::now();
|
||||||
|
let worker_errors = drain_workers(&mut self.workers, |r| match r {
|
||||||
|
WorkerResponse::GenerateStepOk => Ok(()),
|
||||||
|
WorkerResponse::Error { kind, message } => Err(format!("[{kind}]: {message}")),
|
||||||
|
other => Err(format!("expected GenerateStepOk, got {other:?}")),
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
tracing::debug!(
|
||||||
|
model = %model_id,
|
||||||
|
drain_ms = drain_start.elapsed().as_millis(),
|
||||||
|
errors = worker_errors.len(),
|
||||||
|
total_ms = step_start.elapsed().as_millis(),
|
||||||
|
"WorkerPool::generate_step: workers drained"
|
||||||
|
);
|
||||||
|
|
||||||
|
combine_leader_workers(leader_result, worker_errors, "GenerateStep")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Reset the KV cache for `model_id` on every rank. Called at the
|
||||||
|
/// start of every inference so a fresh request doesn't attend over
|
||||||
|
/// the previous one's tokens.
|
||||||
|
pub async fn clear_kv_cache(
|
||||||
|
&mut self,
|
||||||
|
model_id: &str,
|
||||||
|
#[cfg(feature = "cuda")] leader_model: std::sync::Arc<tokio::sync::Mutex<TpLeaderModel>>,
|
||||||
|
) -> Result<()> {
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
tracing::debug!(model = %model_id, "WorkerPool::clear_kv_cache: fan-out");
|
||||||
|
for w in &mut self.workers {
|
||||||
|
w.send_only(&WorkerRequest::ClearKvCache {
|
||||||
|
model_id: model_id.to_string(),
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
{
|
||||||
|
let mut m = leader_model.lock().await;
|
||||||
|
m.clear_kv_cache();
|
||||||
|
}
|
||||||
|
// Drain workers — same rationale as `generate_step`. The
|
||||||
|
// leader's clear_kv_cache is in-process and infallible, but we
|
||||||
|
// still always drain so an error on one worker doesn't leave
|
||||||
|
// pending responses for the others.
|
||||||
|
let worker_errors = drain_workers(&mut self.workers, |r| match r {
|
||||||
|
WorkerResponse::KvCacheCleared => Ok(()),
|
||||||
|
WorkerResponse::Error { kind, message } => Err(format!("[{kind}]: {message}")),
|
||||||
|
other => Err(format!("expected KvCacheCleared, got {other:?}")),
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
tracing::debug!(
|
||||||
|
model = %model_id,
|
||||||
|
elapsed_ms = start.elapsed().as_millis(),
|
||||||
|
errors = worker_errors.len(),
|
||||||
|
"WorkerPool::clear_kv_cache: workers drained"
|
||||||
|
);
|
||||||
|
if !worker_errors.is_empty() {
|
||||||
|
anyhow::bail!("ClearKvCache: {}", worker_errors.join("; "));
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Drop this model's shards on every rank. The leader's shard is
|
||||||
|
/// expected to have been dropped by the caller (its `Arc` was held
|
||||||
|
/// in the TpLoadedModel and goes away when that's removed).
|
||||||
|
pub async fn unload_model(&mut self, model_id: &str) -> Result<()> {
|
||||||
|
for w in &mut self.workers {
|
||||||
|
w.send_only(&WorkerRequest::UnloadModel {
|
||||||
|
model_id: model_id.to_string(),
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
for w in &mut self.workers {
|
||||||
|
let resp = w.recv_only().await?;
|
||||||
|
match resp {
|
||||||
|
WorkerResponse::Unloaded => {}
|
||||||
|
WorkerResponse::Error { kind, message } => {
|
||||||
|
anyhow::bail!("worker rank {} UnloadModel [{kind}]: {message}", w.rank)
|
||||||
|
}
|
||||||
|
other => anyhow::bail!(
|
||||||
|
"worker rank {} UnloadModel: expected Unloaded, got {other:?}",
|
||||||
|
w.rank
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send `Shutdown` to every worker, await each `Bye`, and reap the
|
||||||
|
/// children. Best-effort — individual worker failures are logged
|
||||||
|
/// but don't abort the rest of the sweep.
|
||||||
|
pub async fn shutdown(mut self) -> Result<()> {
|
||||||
|
for w in &mut self.workers {
|
||||||
|
match w.request(&WorkerRequest::Shutdown).await {
|
||||||
|
Ok(WorkerResponse::Bye) => {}
|
||||||
|
Ok(other) => tracing::warn!(
|
||||||
|
rank = w.rank,
|
||||||
|
response = ?other,
|
||||||
|
"expected Bye on shutdown"
|
||||||
|
),
|
||||||
|
Err(e) => tracing::warn!(rank = w.rank, error = %e, "shutdown request failed"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for w in &mut self.workers {
|
||||||
|
match w.child.wait().await {
|
||||||
|
Ok(status) => tracing::info!(rank = w.rank, %status, "worker exited"),
|
||||||
|
Err(e) => tracing::warn!(rank = w.rank, error = %e, "wait on worker failed"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn world_size(&self) -> u32 {
|
||||||
|
self.world_size
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn binary_path(&self) -> &PathBuf {
|
||||||
|
&self.exe
|
||||||
|
}
|
||||||
|
}
|
||||||
293
crates/neuron/src/harness/tp/nccl_state.rs
Normal file
293
crates/neuron/src/harness/tp/nccl_state.rs
Normal file
@@ -0,0 +1,293 @@
|
|||||||
|
//! NCCL state held by both the worker process and the leader's pool.
|
||||||
|
//!
|
||||||
|
//! Split into its own module so the worker (`tp/worker.rs`) and the
|
||||||
|
//! leader (`tp/mod.rs`) share the same hex-encoding/decoding code and
|
||||||
|
//! the same shape of `Option<Comm>` state machine.
|
||||||
|
//!
|
||||||
|
//! When the `cuda` feature is off, `NcclState` is a zero-sized
|
||||||
|
//! placeholder that returns `Error{kind="cuda_feature_not_enabled"}`
|
||||||
|
//! from every operation. When it's on, the same struct holds the
|
||||||
|
//! actual `cudarc::nccl::Comm`.
|
||||||
|
|
||||||
|
use super::rpc::WorkerResponse;
|
||||||
|
use super::worker::WorkerConfig;
|
||||||
|
|
||||||
|
/// Encode bytes as lowercase hex. Used for ferrying NCCL `Id::internal()`
|
||||||
|
/// across the leader→worker RPC boundary inside a JSON string.
|
||||||
|
pub fn encode_hex(bytes: &[u8]) -> String {
|
||||||
|
let mut out = String::with_capacity(bytes.len() * 2);
|
||||||
|
for b in bytes {
|
||||||
|
use std::fmt::Write;
|
||||||
|
let _ = write!(out, "{b:02x}");
|
||||||
|
}
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decode lowercase-or-uppercase hex into bytes. Errors on odd length
|
||||||
|
/// or non-hex characters; the caller bubbles those up via the RPC's
|
||||||
|
/// `Error{kind="bad_request"}` variant.
|
||||||
|
pub fn decode_hex(s: &str) -> Result<Vec<u8>, String> {
|
||||||
|
if !s.len().is_multiple_of(2) {
|
||||||
|
return Err(format!("hex string has odd length {}", s.len()));
|
||||||
|
}
|
||||||
|
(0..s.len())
|
||||||
|
.step_by(2)
|
||||||
|
.map(|i| {
|
||||||
|
u8::from_str_radix(&s[i..i + 2], 16).map_err(|e| format!("bad hex byte at {i}: {e}"))
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
pub struct NcclState;
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
impl Default for NcclState {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
impl NcclState {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn init(&mut self, _cfg: WorkerConfig, _comm_id_hex: &str) -> WorkerResponse {
|
||||||
|
WorkerResponse::Error {
|
||||||
|
kind: "cuda_feature_not_enabled".into(),
|
||||||
|
message: "this neuron binary was built without --features cuda; \
|
||||||
|
NCCL Init requires CUDA"
|
||||||
|
.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn sanity_check(&mut self) -> WorkerResponse {
|
||||||
|
WorkerResponse::Error {
|
||||||
|
kind: "cuda_feature_not_enabled".into(),
|
||||||
|
message: "NCCL sanity check requires --features cuda".into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
mod cuda_impl {
|
||||||
|
use super::*;
|
||||||
|
use cudarc::driver::CudaContext;
|
||||||
|
use cudarc::nccl::{Comm, Id, ReduceOp};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
/// Number of bytes in NCCL's unique-id type; matches `Id::internal()`'s
|
||||||
|
/// `[c_char; 128]`. Wire-encoded as 256 lowercase hex chars.
|
||||||
|
const NCCL_ID_BYTES: usize = 128;
|
||||||
|
|
||||||
|
pub struct NcclState {
|
||||||
|
/// Wrapped in `Arc` so we can hand a clone to `TpQwen3ForCausalLM`
|
||||||
|
/// at load time (every row-parallel layer needs a reference to
|
||||||
|
/// run its trailing `AllReduce`). The `Arc` is the single source
|
||||||
|
/// of truth for the comm's lifetime — when the pool drops and
|
||||||
|
/// every layer that captured a clone drops, NCCL releases the
|
||||||
|
/// underlying `ncclComm_t`.
|
||||||
|
comm: Option<Arc<Comm>>,
|
||||||
|
/// Held alongside the Comm so the device isn't dropped
|
||||||
|
/// underneath the NCCL handle.
|
||||||
|
#[allow(dead_code)]
|
||||||
|
ctx: Option<Arc<CudaContext>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for NcclState {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl NcclState {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
comm: None,
|
||||||
|
ctx: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clone the comm out as an `Arc` so callers (the leader-side
|
||||||
|
/// `TpQwen3ForCausalLM::load`, or the worker's own model load)
|
||||||
|
/// can hold a reference for the lifetime of the model. Returns
|
||||||
|
/// `None` before `init` has run.
|
||||||
|
pub fn comm(&self) -> Option<Arc<Comm>> {
|
||||||
|
self.comm.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `Arc<Comm>` doesn't impl `Send` because `Comm` wraps a raw
|
||||||
|
/// `ncclComm_t` pointer. The NCCL contract is "operations against a
|
||||||
|
/// given comm must be serialised", not "the handle must stay on the
|
||||||
|
/// thread that created it" — so it's safe to move an `Arc<Comm>`
|
||||||
|
/// across threads as long as no concurrent ops are issued. The
|
||||||
|
/// pool's outer Mutex serialises us into `spawn_blocking`, so this
|
||||||
|
/// wrapper at the move boundary is the only thing missing.
|
||||||
|
///
|
||||||
|
/// `Sync` is also marked safe because the `Arc<Comm>` clones held
|
||||||
|
/// by the row-parallel layers are only used from the
|
||||||
|
/// `spawn_blocking` thread driving the forward pass; concurrent
|
||||||
|
/// access from another thread would still be a bug.
|
||||||
|
pub struct SendComm(pub Arc<Comm>);
|
||||||
|
|
||||||
|
// SAFETY: see the doc-comment above; the invariant is enforced at
|
||||||
|
// the call site (pool Mutex + single spawn_blocking thread), not at
|
||||||
|
// the type level.
|
||||||
|
unsafe impl Send for SendComm {}
|
||||||
|
unsafe impl Sync for SendComm {}
|
||||||
|
|
||||||
|
impl SendComm {
|
||||||
|
pub fn into_inner(self) -> Arc<Comm> {
|
||||||
|
self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SAFETY: `cudarc::nccl::Comm` contains a raw `ncclComm_t` pointer
|
||||||
|
// (libnccl-allocated state). NCCL requires that operations against
|
||||||
|
// one Comm be issued one at a time; we serialise access by storing
|
||||||
|
// NcclState behind a Mutex in `WorkerPool`. The Comm itself is
|
||||||
|
// move-safe — NCCL doesn't track the calling OS thread, only the
|
||||||
|
// stream the operations are dispatched against.
|
||||||
|
unsafe impl Send for NcclState {}
|
||||||
|
unsafe impl Sync for NcclState {}
|
||||||
|
|
||||||
|
/// Generate a fresh NCCL `Id` and return it hex-encoded. Used by
|
||||||
|
/// the leader to mint the shared communicator id which is then
|
||||||
|
/// broadcast to every worker via the RPC `Init` message.
|
||||||
|
pub fn generate_comm_id_hex() -> Result<String, String> {
|
||||||
|
// NcclError lacks a Display impl in cudarc 0.19.x — surface
|
||||||
|
// via Debug throughout this module.
|
||||||
|
let id = Id::new().map_err(|e| format!("Id::new(): {e:?}"))?;
|
||||||
|
let bytes_u8: [u8; NCCL_ID_BYTES] = std::array::from_fn(|i| id.internal()[i] as u8);
|
||||||
|
Ok(encode_hex(&bytes_u8))
|
||||||
|
}
|
||||||
|
|
||||||
|
impl NcclState {
|
||||||
|
pub fn init(&mut self, cfg: WorkerConfig, comm_id_hex: &str) -> WorkerResponse {
|
||||||
|
match try_init(self, cfg, comm_id_hex) {
|
||||||
|
Ok(()) => WorkerResponse::InitOk,
|
||||||
|
Err(msg) => WorkerResponse::Error {
|
||||||
|
kind: "nccl_init_failed".into(),
|
||||||
|
message: msg,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn sanity_check(&mut self) -> WorkerResponse {
|
||||||
|
let Some(comm) = self.comm.as_ref() else {
|
||||||
|
return WorkerResponse::Error {
|
||||||
|
kind: "nccl_not_initialised".into(),
|
||||||
|
message: "sanity_check requires Init to have completed first".into(),
|
||||||
|
};
|
||||||
|
};
|
||||||
|
match try_sanity_check(comm.as_ref()) {
|
||||||
|
Ok(sum) => WorkerResponse::NcclSanityResult { observed_sum: sum },
|
||||||
|
Err(msg) => WorkerResponse::Error {
|
||||||
|
kind: "nccl_sanity_failed".into(),
|
||||||
|
message: msg,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn try_init(state: &mut NcclState, cfg: WorkerConfig, comm_id_hex: &str) -> Result<(), String> {
|
||||||
|
let bytes = decode_hex(comm_id_hex)?;
|
||||||
|
if bytes.len() != NCCL_ID_BYTES {
|
||||||
|
return Err(format!(
|
||||||
|
"comm_id is {} bytes, expected {NCCL_ID_BYTES}",
|
||||||
|
bytes.len()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
let id_bytes: [std::ffi::c_char; NCCL_ID_BYTES] =
|
||||||
|
std::array::from_fn(|i| bytes[i] as std::ffi::c_char);
|
||||||
|
let id = Id::uninit(id_bytes);
|
||||||
|
|
||||||
|
let ctx = CudaContext::new(cfg.cuda_device as usize)
|
||||||
|
.map_err(|e| format!("CudaContext::new({}) failed: {e}", cfg.cuda_device))?;
|
||||||
|
let stream = ctx.default_stream();
|
||||||
|
let comm = Comm::from_rank(stream, cfg.rank as usize, cfg.world_size as usize, id)
|
||||||
|
.map_err(|e| {
|
||||||
|
format!(
|
||||||
|
"Comm::from_rank(rank={}, world={}) failed: {e:?}",
|
||||||
|
cfg.rank, cfg.world_size
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
state.ctx = Some(ctx);
|
||||||
|
// `Comm` is !Send + !Sync at the type level because it wraps a
|
||||||
|
// raw `ncclComm_t`. The `Arc` is fine in practice — we
|
||||||
|
// serialise operations through the pool's outer Mutex and the
|
||||||
|
// SendComm wrapper at thread-crossing boundaries enforces this
|
||||||
|
// at every move site. clippy's `arc_with_non_send_sync` lint
|
||||||
|
// can't see that invariant; allow once at the canonical
|
||||||
|
// construction site.
|
||||||
|
#[allow(clippy::arc_with_non_send_sync)]
|
||||||
|
{
|
||||||
|
state.comm = Some(Arc::new(comm));
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn try_sanity_check(comm: &Comm) -> Result<u32, String> {
|
||||||
|
let stream = comm.stream().clone();
|
||||||
|
let input = stream
|
||||||
|
.clone_htod(&[1u32])
|
||||||
|
.map_err(|e| format!("htod sentinel: {e}"))?;
|
||||||
|
let mut output = stream
|
||||||
|
.alloc_zeros::<u32>(1)
|
||||||
|
.map_err(|e| format!("alloc output: {e}"))?;
|
||||||
|
// cudarc::nccl::NcclError doesn't impl Display in 0.19.x —
|
||||||
|
// surface via Debug so we still see the variant + ncclResult
|
||||||
|
// code instead of a generic "{e}" failure.
|
||||||
|
comm.all_reduce(&input, &mut output, &ReduceOp::Sum)
|
||||||
|
.map_err(|e| format!("all_reduce: {e:?}"))?;
|
||||||
|
let result = stream
|
||||||
|
.clone_dtoh(&output)
|
||||||
|
.map_err(|e| format!("dtoh result: {e}"))?;
|
||||||
|
Ok(result[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub use cuda_impl::{NcclState, SendComm, generate_comm_id_hex};
|
||||||
|
|
||||||
|
/// Non-cuda stub for the leader: returns a clear marker error rather
|
||||||
|
/// than letting `init_nccl` succeed vacuously.
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
pub fn generate_comm_id_hex() -> Result<String, String> {
|
||||||
|
Err("cuda_feature_not_enabled: build with --features cuda".into())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn hex_roundtrip() {
|
||||||
|
let original: Vec<u8> = (0u8..=255).collect();
|
||||||
|
let encoded = encode_hex(&original);
|
||||||
|
assert_eq!(encoded.len(), 512);
|
||||||
|
let decoded = decode_hex(&encoded).expect("decode");
|
||||||
|
assert_eq!(decoded, original);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn hex_decode_rejects_odd_length() {
|
||||||
|
assert!(decode_hex("a").is_err());
|
||||||
|
assert!(decode_hex("abc").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn hex_decode_rejects_non_hex() {
|
||||||
|
assert!(decode_hex("zz").is_err());
|
||||||
|
assert!(decode_hex("ab_d").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn hex_encode_is_lowercase_padded() {
|
||||||
|
assert_eq!(encode_hex(&[0x0a, 0xff]), "0aff");
|
||||||
|
}
|
||||||
|
}
|
||||||
257
crates/neuron/src/harness/tp/rpc.rs
Normal file
257
crates/neuron/src/harness/tp/rpc.rs
Normal file
@@ -0,0 +1,257 @@
|
|||||||
|
//! Wire protocol between the neuron leader process and its
|
||||||
|
//! `--worker` subprocesses.
|
||||||
|
//!
|
||||||
|
//! Every frame is one newline-delimited JSON object on the worker's
|
||||||
|
//! stdin (request) or stdout (response). Both directions are tagged
|
||||||
|
//! sum types from the start so new ops in Stage 7b/7c slot in without
|
||||||
|
//! breaking compatibility — no "14 message types and a version field"
|
||||||
|
//! drift later. Adding a new variant is the canonical way to evolve
|
||||||
|
//! the protocol; existing peers that don't recognise an op return
|
||||||
|
//! `WorkerResponse::Error { kind: "unknown_op", .. }`.
|
||||||
|
//!
|
||||||
|
//! The serialised shape uses `tag = "op"` so a request looks like:
|
||||||
|
//! {"op":"ping"}
|
||||||
|
//! {"op":"init","comm_id":"a1b2..."}
|
||||||
|
//! and a response:
|
||||||
|
//! {"op":"pong","rank":0,"world_size":2,"cuda_device":0}
|
||||||
|
//! {"op":"error","kind":"nccl_init_failed","message":"..."}
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
/// Leader → worker. Worker handles one at a time; replies with exactly
|
||||||
|
/// one `WorkerResponse` per request.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(tag = "op", rename_all = "snake_case")]
|
||||||
|
pub enum WorkerRequest {
|
||||||
|
/// Liveness probe. Worker replies with `Pong` containing its own
|
||||||
|
/// identity. Used by the leader to confirm the subprocess is up
|
||||||
|
/// and ready before kicking off any heavier work.
|
||||||
|
Ping,
|
||||||
|
|
||||||
|
/// One-shot NCCL communicator setup. The leader generates the
|
||||||
|
/// `comm_id` once (rank 0 of NCCL), broadcasts it to every worker
|
||||||
|
/// via this message, then every rank (leader included) calls
|
||||||
|
/// `Comm::from_rank` with the same id — NCCL blocks until all
|
||||||
|
/// `world_size` ranks check in. The hex-encoded bytes are the
|
||||||
|
/// canonical `cudarc::nccl::Id::internal()` content.
|
||||||
|
Init {
|
||||||
|
/// Hex-encoded NCCL id bytes (128 bytes → 256 hex chars).
|
||||||
|
comm_id: String,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Sanity check: after Init, every rank runs an `all_reduce` over
|
||||||
|
/// a sentinel value (`1u32`). The expected sum is `world_size`.
|
||||||
|
/// Worker replies with the observed value so the leader can verify
|
||||||
|
/// the NCCL handshake is genuinely live, not just configured.
|
||||||
|
NcclSanityCheck,
|
||||||
|
|
||||||
|
/// Load this rank's shard of a dense Qwen3 model from mmaped
|
||||||
|
/// safetensors. The same `safetensors_paths` list is sent to every
|
||||||
|
/// rank — the ShardedVarBuilder reads only the rank-local slice of
|
||||||
|
/// each tensor at materialisation time, so the worker's VRAM
|
||||||
|
/// footprint is `1 / world_size` of the full model (plus replicated
|
||||||
|
/// embedding/norm/lm_head).
|
||||||
|
LoadDenseShard {
|
||||||
|
/// Caller-supplied id for later `GenerateStep` / `UnloadModel`
|
||||||
|
/// lookups. Typically the HF model id verbatim.
|
||||||
|
model_id: String,
|
||||||
|
/// JSON-serialised `candle_transformers::models::qwen3::Config`
|
||||||
|
/// — the same blob the leader parsed from the HF cache's
|
||||||
|
/// `config.json`. Threaded through verbatim so the worker uses
|
||||||
|
/// identical hyperparameters.
|
||||||
|
config_json: String,
|
||||||
|
/// Absolute paths the worker should mmap. The same set on every
|
||||||
|
/// rank; ShardedVarBuilder slices into them per rank.
|
||||||
|
safetensors_paths: Vec<String>,
|
||||||
|
/// Optional in-situ quantization dtype (e.g. "q5k", "q8_0",
|
||||||
|
/// "q6k"). When set, each linear-layer weight is quantized
|
||||||
|
/// at load time to the named ggml format — saves ~3-5x vs
|
||||||
|
/// bf16/f16 at the cost of some accuracy. `None` keeps the
|
||||||
|
/// weights in the on-disk dtype (typically bf16).
|
||||||
|
#[serde(default)]
|
||||||
|
quant: Option<String>,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Run one forward step on this rank's loaded model. The worker
|
||||||
|
/// reaches into its NCCL Comm for the row-parallel `AllReduce`s
|
||||||
|
/// inside the model — and so blocks on every other rank issuing the
|
||||||
|
/// same op. The leader does *not* receive logits back over RPC; it
|
||||||
|
/// runs its own rank-0 forward in parallel and uses its own logits
|
||||||
|
/// for sampling.
|
||||||
|
GenerateStep {
|
||||||
|
model_id: String,
|
||||||
|
/// Input token ids for this step. For prefill, the whole prompt;
|
||||||
|
/// for decode, a single token. Identical on every rank.
|
||||||
|
tokens: Vec<u32>,
|
||||||
|
/// KV cache offset (count of tokens already in the cache before
|
||||||
|
/// this step).
|
||||||
|
offset: usize,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Reset the KV cache for this model on this rank. Sent at the
|
||||||
|
/// start of every inference so a fresh request doesn't accidentally
|
||||||
|
/// attend over the previous one's tokens.
|
||||||
|
ClearKvCache { model_id: String },
|
||||||
|
|
||||||
|
/// Drop this rank's shard for the given model. Releases the VRAM
|
||||||
|
/// the shard's weights occupied; subsequent `GenerateStep` calls
|
||||||
|
/// against the same `model_id` return an `Error`.
|
||||||
|
UnloadModel { model_id: String },
|
||||||
|
|
||||||
|
/// Worker should release resources and exit. Worker replies `Bye`
|
||||||
|
/// and then closes stdout / exits zero. The leader reaps the
|
||||||
|
/// child via the `tokio::process::Child` it kept.
|
||||||
|
Shutdown,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Worker → leader. Always exactly one of these per `WorkerRequest`.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(tag = "op", rename_all = "snake_case")]
|
||||||
|
pub enum WorkerResponse {
|
||||||
|
/// Reply to `Ping`. Carries enough identity for the leader to log
|
||||||
|
/// what it actually got back.
|
||||||
|
Pong {
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
cuda_device: u32,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Reply to `Init`. Empty payload — success is the absence of
|
||||||
|
/// `Error`. NCCL's internal blocking handshake means by the time
|
||||||
|
/// this comes back, every other rank has also reached
|
||||||
|
/// `Comm::from_rank`.
|
||||||
|
InitOk,
|
||||||
|
|
||||||
|
/// Reply to `NcclSanityCheck`. The observed sum after a single
|
||||||
|
/// `all_reduce(SUM, 1u32)` across all ranks. The leader checks
|
||||||
|
/// this matches `world_size`.
|
||||||
|
NcclSanityResult { observed_sum: u32 },
|
||||||
|
|
||||||
|
/// Reply to `LoadDenseShard`. Empty payload — success is the
|
||||||
|
/// absence of `Error`. By the time this comes back, the rank's
|
||||||
|
/// `TpQwen3ForCausalLM` is constructed in memory and ready for
|
||||||
|
/// `GenerateStep`.
|
||||||
|
LoadDenseShardOk,
|
||||||
|
|
||||||
|
/// Reply to `GenerateStep`. Empty payload — workers don't ship
|
||||||
|
/// logits over the wire. The leader uses its own rank-0 logits;
|
||||||
|
/// workers only need to confirm the collective completed.
|
||||||
|
GenerateStepOk,
|
||||||
|
|
||||||
|
/// Reply to `ClearKvCache`. Empty payload.
|
||||||
|
KvCacheCleared,
|
||||||
|
|
||||||
|
/// Reply to `UnloadModel`. Empty payload. The named model is no
|
||||||
|
/// longer present on this rank.
|
||||||
|
Unloaded,
|
||||||
|
|
||||||
|
/// Reply to `Shutdown`. Worker exits immediately after writing this.
|
||||||
|
Bye,
|
||||||
|
|
||||||
|
/// Any request can produce this instead of its dedicated success
|
||||||
|
/// variant. `kind` is a machine-readable category so the leader
|
||||||
|
/// can branch on failure mode without string-matching `message`.
|
||||||
|
Error {
|
||||||
|
/// Short tag — `nccl_init_failed`, `unknown_op`, etc.
|
||||||
|
kind: String,
|
||||||
|
/// Human-readable detail for logs.
|
||||||
|
message: String,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
fn roundtrip<T>(value: &T) -> T
|
||||||
|
where
|
||||||
|
T: Serialize + for<'de> Deserialize<'de>,
|
||||||
|
{
|
||||||
|
serde_json::from_str(&serde_json::to_string(value).expect("serialise"))
|
||||||
|
.expect("deserialise")
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn request_ping_round_trip() {
|
||||||
|
let req = WorkerRequest::Ping;
|
||||||
|
let wire = serde_json::to_string(&req).unwrap();
|
||||||
|
assert_eq!(wire, r#"{"op":"ping"}"#);
|
||||||
|
match roundtrip(&req) {
|
||||||
|
WorkerRequest::Ping => {}
|
||||||
|
other => panic!("expected Ping, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn request_init_carries_hex_id() {
|
||||||
|
let req = WorkerRequest::Init {
|
||||||
|
comm_id: "deadbeef".into(),
|
||||||
|
};
|
||||||
|
let wire = serde_json::to_string(&req).unwrap();
|
||||||
|
assert_eq!(wire, r#"{"op":"init","comm_id":"deadbeef"}"#);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn request_shutdown_round_trip() {
|
||||||
|
assert_eq!(
|
||||||
|
serde_json::to_string(&WorkerRequest::Shutdown).unwrap(),
|
||||||
|
r#"{"op":"shutdown"}"#
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn response_pong_round_trip() {
|
||||||
|
let resp = WorkerResponse::Pong {
|
||||||
|
rank: 1,
|
||||||
|
world_size: 4,
|
||||||
|
cuda_device: 1,
|
||||||
|
};
|
||||||
|
let wire = serde_json::to_string(&resp).unwrap();
|
||||||
|
assert!(wire.contains(r#""op":"pong""#));
|
||||||
|
assert!(wire.contains(r#""rank":1"#));
|
||||||
|
assert!(wire.contains(r#""world_size":4"#));
|
||||||
|
match roundtrip(&resp) {
|
||||||
|
WorkerResponse::Pong {
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
cuda_device,
|
||||||
|
} => {
|
||||||
|
assert_eq!(rank, 1);
|
||||||
|
assert_eq!(world_size, 4);
|
||||||
|
assert_eq!(cuda_device, 1);
|
||||||
|
}
|
||||||
|
other => panic!("expected Pong, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn response_error_carries_kind_and_message() {
|
||||||
|
let resp = WorkerResponse::Error {
|
||||||
|
kind: "nccl_init_failed".into(),
|
||||||
|
message: "could not bind device".into(),
|
||||||
|
};
|
||||||
|
let wire = serde_json::to_string(&resp).unwrap();
|
||||||
|
assert!(wire.contains(r#""op":"error""#));
|
||||||
|
assert!(wire.contains(r#""kind":"nccl_init_failed""#));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn response_sanity_result_round_trip() {
|
||||||
|
let resp = WorkerResponse::NcclSanityResult { observed_sum: 4 };
|
||||||
|
match roundtrip(&resp) {
|
||||||
|
WorkerResponse::NcclSanityResult { observed_sum } => {
|
||||||
|
assert_eq!(observed_sum, 4);
|
||||||
|
}
|
||||||
|
other => panic!("expected NcclSanityResult, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Unknown ops on the wire deserialise to an error rather than
|
||||||
|
/// silently mis-matching — confirms our `serde(tag = "op")`
|
||||||
|
/// configuration rejects unknowns instead of doing fuzzy matching.
|
||||||
|
#[test]
|
||||||
|
fn unknown_op_fails_to_parse() {
|
||||||
|
let result: Result<WorkerRequest, _> = serde_json::from_str(r#"{"op":"explode"}"#);
|
||||||
|
assert!(result.is_err(), "should reject unknown op, got {result:?}");
|
||||||
|
}
|
||||||
|
}
|
||||||
283
crates/neuron/src/harness/tp/tp_linear.rs
Normal file
283
crates/neuron/src/harness/tp/tp_linear.rs
Normal file
@@ -0,0 +1,283 @@
|
|||||||
|
//! Tensor-parallel linear layers built on candle's `ShardedVarBuilder`
|
||||||
|
//! and `Shard` sharding hints.
|
||||||
|
//!
|
||||||
|
//! candle reads only the rank's slice of each weight tensor from
|
||||||
|
//! safetensors via `view.slice(start..stop)` — no full-tensor host
|
||||||
|
//! materialisation. That's a memory-efficiency win over hand-rolled
|
||||||
|
//! "load full + narrow" sharding (which the earlier
|
||||||
|
//! `sharded_linear.rs` exploration demonstrated but didn't pay for).
|
||||||
|
//!
|
||||||
|
//! Two layer types:
|
||||||
|
//!
|
||||||
|
//! - [`ColumnParallelLinear`] — output-sharded; forward is a plain
|
||||||
|
//! local matmul. The downstream consumer either accepts a sharded
|
||||||
|
//! activation (next layer is also column-parallel) or all-gathers.
|
||||||
|
//! - [`RowParallelLinear`] — input-sharded; forward = local matmul
|
||||||
|
//! then `AllReduce` `CustomOp1` to sum partials across ranks.
|
||||||
|
//!
|
||||||
|
//! Both assume **no bias** — every Qwen3-family weight layout we
|
||||||
|
//! actually target (Qwen3, Qwen3-Coder, Qwen3.6 base, etc.) sets
|
||||||
|
//! `attention_bias=false` and the MLP layers are no-bias. Adding bias
|
||||||
|
//! support is mechanical when a future model needs it; the design
|
||||||
|
//! choice would be: column-parallel shards the bias along dim 0;
|
||||||
|
//! row-parallel holds the bias only on rank 0 so the post-`AllReduce`
|
||||||
|
//! sum carries it exactly once.
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use candle_core::quantized::{GgmlDType, QMatMul, QTensor};
|
||||||
|
use candle_core::{Module, Tensor};
|
||||||
|
use candle_nn::Linear;
|
||||||
|
use candle_nn::var_builder::{Shard, ShardedVarBuilder};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
use super::all_reduce::AllReduce;
|
||||||
|
|
||||||
|
/// Linear primitive that holds either a plain `Linear` (bf16/f16/f32)
|
||||||
|
/// or a quantized `QMatMul` (Q4K/Q5K/Q6K/Q8_0/etc.).
|
||||||
|
///
|
||||||
|
/// Constructed via [`MaybeQuantLinear::from_weight`] — pass `None` to
|
||||||
|
/// keep the weight in its loaded dtype (no quantization), or
|
||||||
|
/// `Some(dtype)` to quantize at load time.
|
||||||
|
///
|
||||||
|
/// On the forward path the two arms dispatch identically: `Module::forward`
|
||||||
|
/// returns an output in the caller's input dtype (f32 fallback for the
|
||||||
|
/// quantized matmul). Subsequent ops don't need to know whether the
|
||||||
|
/// layer was quantized.
|
||||||
|
pub enum MaybeQuantLinear {
|
||||||
|
Plain(Linear),
|
||||||
|
Quant(QMatMul),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MaybeQuantLinear {
|
||||||
|
/// Build a linear from a loaded weight tensor. If `quant` is set,
|
||||||
|
/// the weight is quantized in-situ and stored as a `QMatMul`;
|
||||||
|
/// otherwise it's wrapped in a plain `Linear`.
|
||||||
|
pub fn from_weight(weight: Tensor, quant: Option<GgmlDType>) -> Result<Self> {
|
||||||
|
match quant {
|
||||||
|
Some(dtype) => {
|
||||||
|
let qt = QTensor::quantize(&weight, dtype).with_context(|| {
|
||||||
|
format!(
|
||||||
|
"QTensor::quantize to {dtype:?} for shape {:?}",
|
||||||
|
weight.shape()
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
let qmm = QMatMul::from_arc(Arc::new(qt))
|
||||||
|
.context("QMatMul::from_arc on freshly quantized weight")?;
|
||||||
|
Ok(Self::Quant(qmm))
|
||||||
|
}
|
||||||
|
None => Ok(Self::Plain(Linear::new(weight, None))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Above this M (the product of all input dims except the last)
|
||||||
|
/// dispatch the quantized matmul through `QMatMul::forward_via_f16`,
|
||||||
|
/// which dequantizes the weight to f16 once and runs cuBLAS GEMM.
|
||||||
|
/// At or below this M the GGUF GEMV kernel inside
|
||||||
|
/// `QMatMul::forward` wins (it operates on quantized blocks directly
|
||||||
|
/// and accumulates in registers).
|
||||||
|
///
|
||||||
|
/// Empirical: at M=30 on Qwen3.6-27B / RTX 5090, forward_via_f16 was
|
||||||
|
/// slightly *slower* than the GGUF GEMV kernel — the per-call dequant
|
||||||
|
/// cost (~30 MB f16 written to global memory per linear × ~480 calls
|
||||||
|
/// per prefill) eats the cuBLAS GEMM speedup at small M. The
|
||||||
|
/// crossover where the GEMM scaling actually beats the fixed dequant
|
||||||
|
/// tax sits well above M=8.
|
||||||
|
///
|
||||||
|
/// 64 is a conservative crossover that keeps short-prompt prefills
|
||||||
|
/// on the GGUF kernel (where the per-call cost is comparable to the
|
||||||
|
/// f16 path but the dequant tax is zero) and only activates the
|
||||||
|
/// dequant-then-GEMM path for long prefills where the GEMM size
|
||||||
|
/// makes amortising worth it. A proper fix is either a dequant
|
||||||
|
/// cache or a fused dequant+gemm cuda kernel — both larger projects.
|
||||||
|
const QUANT_PREFILL_M_THRESHOLD: usize = 64;
|
||||||
|
|
||||||
|
impl Module for MaybeQuantLinear {
|
||||||
|
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
|
||||||
|
match self {
|
||||||
|
Self::Plain(l) => l.forward(x),
|
||||||
|
Self::Quant(qm) => {
|
||||||
|
// Decode vs prefill split. `M` is the "rows of x" the
|
||||||
|
// matmul will iterate over — every dim except the last
|
||||||
|
// (which is in_features). For decode (`seq_len == 1`
|
||||||
|
// with batch 1) M is 1; for prefill with L>>1 it's L
|
||||||
|
// (or B*L).
|
||||||
|
let dims = x.dims();
|
||||||
|
let m: usize = dims.iter().take(dims.len() - 1).product();
|
||||||
|
|
||||||
|
if m > QUANT_PREFILL_M_THRESHOLD {
|
||||||
|
// Prefill: dequantize the weight once into f16,
|
||||||
|
// then run a real cuBLAS-backed GEMM. The cost of
|
||||||
|
// the dequant is amortised across all M tokens.
|
||||||
|
// `forward_via_f16` handles the dtype round-trip
|
||||||
|
// internally (output matches input dtype).
|
||||||
|
return qm.forward_via_f16(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode (M <= threshold): use the on-the-fly GGUF
|
||||||
|
// GEMV kernel via `QMatMul::forward`. That kernel
|
||||||
|
// requires f32 inputs (it accumulates in f32 from the
|
||||||
|
// dequantized quant blocks); cast in/out at the
|
||||||
|
// boundary.
|
||||||
|
let in_dtype = x.dtype();
|
||||||
|
let x_f32 = if in_dtype == candle_core::DType::F32 {
|
||||||
|
x.clone()
|
||||||
|
} else {
|
||||||
|
x.to_dtype(candle_core::DType::F32)?
|
||||||
|
};
|
||||||
|
let y = qm.forward(&x_f32)?;
|
||||||
|
if y.dtype() == in_dtype {
|
||||||
|
Ok(y)
|
||||||
|
} else {
|
||||||
|
y.to_dtype(in_dtype)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper to build a [`Shard`] hint for a given dimension.
|
||||||
|
pub(crate) fn shard(dim: usize, rank: u32, world_size: u32) -> Shard {
|
||||||
|
Shard {
|
||||||
|
dim,
|
||||||
|
rank: rank as usize,
|
||||||
|
world_size: world_size as usize,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Output-dim sharded linear (column-parallel). Holds a
|
||||||
|
/// [`MaybeQuantLinear`] whose underlying weight is this rank's slice
|
||||||
|
/// of the full `[out_features, in_features]` tensor along dim 0.
|
||||||
|
pub struct ColumnParallelLinear {
|
||||||
|
inner: MaybeQuantLinear,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ColumnParallelLinear {
|
||||||
|
/// Load this rank's column-parallel slice from a
|
||||||
|
/// `ShardedVarBuilder`. The provided `vb` must already be `pp`-ed
|
||||||
|
/// to the layer's path (e.g. `vb.pp("model.layers.0.self_attn.q_proj")`).
|
||||||
|
///
|
||||||
|
/// Backward-compatible variant — no in-situ quantization. For
|
||||||
|
/// quantized loads, use [`Self::load_with_quant`].
|
||||||
|
pub fn load(vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
|
||||||
|
Self::load_with_quant(vb, rank, world_size, None)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Like [`Self::load`] but quantizes the per-rank weight in-situ
|
||||||
|
/// when `quant` is `Some(dtype)`. Saves ~3-5x vs bf16/f16.
|
||||||
|
pub fn load_with_quant(
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
quant: Option<GgmlDType>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let weight = vb
|
||||||
|
.get_with_hints((), "weight", shard(0, rank, world_size))
|
||||||
|
.with_context(|| format!("load column-parallel '{}' weight", vb.prefix()))?;
|
||||||
|
let inner = MaybeQuantLinear::from_weight(weight, quant)
|
||||||
|
.with_context(|| format!("wrap column-parallel '{}'", vb.prefix()))?;
|
||||||
|
Ok(Self { inner })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for ColumnParallelLinear {
|
||||||
|
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
|
||||||
|
self.inner.forward(x)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Input-dim sharded linear (row-parallel).
|
||||||
|
///
|
||||||
|
/// Holds a sharded [`MaybeQuantLinear`] plus an `AllReduce` op the
|
||||||
|
/// forward chains after the local matmul to recover the full activation.
|
||||||
|
pub struct RowParallelLinear {
|
||||||
|
inner: MaybeQuantLinear,
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
all_reduce: AllReduce,
|
||||||
|
/// Whether the AllReduce should run. Column-parallel ↔ row-parallel
|
||||||
|
/// is a pair: the column output is sharded, the row input is
|
||||||
|
/// sharded, and the AllReduce gives back the full output. For
|
||||||
|
/// `world_size = 1` the AllReduce is a no-op so we skip it.
|
||||||
|
needs_reduce: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RowParallelLinear {
|
||||||
|
/// Load this rank's row-parallel slice from a `ShardedVarBuilder`.
|
||||||
|
///
|
||||||
|
/// Under `cuda`, `comm` is the NCCL communicator the row-parallel
|
||||||
|
/// `AllReduce` runs against. On CPU builds the parameter is
|
||||||
|
/// elided — forward returns the partial sum, which is the *wrong*
|
||||||
|
/// answer for inference but lets us compile-check the model.
|
||||||
|
///
|
||||||
|
/// Backward-compatible variant — no in-situ quantization. For
|
||||||
|
/// quantized loads, use [`Self::load_with_quant`].
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub fn load(
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
comm: std::sync::Arc<cudarc::nccl::Comm>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Self::load_with_quant(vb, rank, world_size, comm, None)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Like [`Self::load`] but quantizes the per-rank weight in-situ.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub fn load_with_quant(
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
comm: std::sync::Arc<cudarc::nccl::Comm>,
|
||||||
|
quant: Option<GgmlDType>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let weight = vb
|
||||||
|
.get_with_hints((), "weight", shard(1, rank, world_size))
|
||||||
|
.with_context(|| format!("load row-parallel '{}' weight", vb.prefix()))?;
|
||||||
|
let inner = MaybeQuantLinear::from_weight(weight, quant)
|
||||||
|
.with_context(|| format!("wrap row-parallel '{}'", vb.prefix()))?;
|
||||||
|
Ok(Self {
|
||||||
|
inner,
|
||||||
|
all_reduce: AllReduce::new(comm),
|
||||||
|
needs_reduce: world_size > 1,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
pub fn load(vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
|
||||||
|
Self::load_with_quant(vb, rank, world_size, None)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
pub fn load_with_quant(
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
quant: Option<GgmlDType>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let weight = vb
|
||||||
|
.get_with_hints((), "weight", shard(1, rank, world_size))
|
||||||
|
.with_context(|| format!("load row-parallel '{}' weight", vb.prefix()))?;
|
||||||
|
let inner = MaybeQuantLinear::from_weight(weight, quant)
|
||||||
|
.with_context(|| format!("wrap row-parallel '{}'", vb.prefix()))?;
|
||||||
|
Ok(Self {
|
||||||
|
inner,
|
||||||
|
needs_reduce: world_size > 1,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for RowParallelLinear {
|
||||||
|
/// Local matmul followed by an `AllReduce` (when `cuda` and
|
||||||
|
/// `world_size > 1`). On CPU or single-rank, returns the partial
|
||||||
|
/// output directly — which is *only* correct for `world_size == 1`.
|
||||||
|
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
|
||||||
|
let local = self.inner.forward(x)?;
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
if self.needs_reduce {
|
||||||
|
return local.apply_op1_no_bwd(&self.all_reduce);
|
||||||
|
}
|
||||||
|
let _ = self.needs_reduce;
|
||||||
|
Ok(local)
|
||||||
|
}
|
||||||
|
}
|
||||||
605
crates/neuron/src/harness/tp/tp_qwen3.rs
Normal file
605
crates/neuron/src/harness/tp/tp_qwen3.rs
Normal file
@@ -0,0 +1,605 @@
|
|||||||
|
//! Tensor-parallel Qwen3 dense model.
|
||||||
|
//!
|
||||||
|
//! Mirrors `candle_transformers::models::qwen3` structurally, but with:
|
||||||
|
//!
|
||||||
|
//! - Attention's `q_proj` / `k_proj` / `v_proj` as
|
||||||
|
//! [`ColumnParallelLinear`] (output sharded along the head dimension —
|
||||||
|
//! per-rank `num_heads = total/world_size`, ditto for kv heads).
|
||||||
|
//! - Attention's `o_proj` as [`RowParallelLinear`] (input sharded; the
|
||||||
|
//! trailing `AllReduce` recovers the full activation).
|
||||||
|
//! - MLP's `gate_proj` / `up_proj` as [`ColumnParallelLinear`] (sharded
|
||||||
|
//! along `intermediate_size`).
|
||||||
|
//! - MLP's `down_proj` as [`RowParallelLinear`].
|
||||||
|
//! - `embed_tokens`, all `RmsNorm`s, and `lm_head` **replicated** on
|
||||||
|
//! every rank. The per-rank duplicate weight is bounded and lets us
|
||||||
|
//! skip the embedding all-gather and the lm-head column-shard +
|
||||||
|
//! all-gather; both are pure latency optimisations that don't change
|
||||||
|
//! correctness.
|
||||||
|
//! - `kv_cache` holds the per-rank slice of K/V already (because they
|
||||||
|
//! came out of a column-parallel projection). No cache resharding
|
||||||
|
//! needed across ranks.
|
||||||
|
//!
|
||||||
|
//! Divisibility requirement, checked at load time:
|
||||||
|
//!
|
||||||
|
//! - `num_attention_heads % world_size == 0`
|
||||||
|
//! - `num_key_value_heads % world_size == 0`
|
||||||
|
//! - `intermediate_size % world_size == 0`
|
||||||
|
//!
|
||||||
|
//! Anything else bails — the safetensors slice would lose data otherwise.
|
||||||
|
//! This is the same divisibility-bail pattern that landed in
|
||||||
|
//! `EricLBuehler/mistral.rs` PR #2054.
|
||||||
|
//!
|
||||||
|
//! Replicated tensors (norms, embedding, lm_head) are loaded by asking
|
||||||
|
//! the `ShardedVarBuilder` for the full tensor via `vb.get(shape, name)`
|
||||||
|
//! — which defaults to `Shard { world_size: 1 }` and falls through to
|
||||||
|
//! the unsharded backend path.
|
||||||
|
|
||||||
|
use anyhow::{Context, Result, bail};
|
||||||
|
use candle_core::{DType, Device, IndexOp, Module, Tensor};
|
||||||
|
use candle_nn::var_builder::ShardedVarBuilder;
|
||||||
|
use candle_nn::{Activation, Embedding, Linear, RmsNorm, kv_cache::ConcatKvCache};
|
||||||
|
use candle_transformers::utils::repeat_kv;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
use cudarc::nccl::Comm;
|
||||||
|
|
||||||
|
use super::tp_linear::{ColumnParallelLinear, RowParallelLinear};
|
||||||
|
|
||||||
|
pub use candle_transformers::models::qwen3::Config;
|
||||||
|
|
||||||
|
/// Replicated rotary-embedding lookup. Re-implementation of the
|
||||||
|
/// `pub(crate)` candle equivalent — we can't reach into the upstream
|
||||||
|
/// type, so the inv-freq / sin / cos construction lives here.
|
||||||
|
pub(crate) struct Qwen3RotaryEmbedding {
|
||||||
|
sin: Tensor,
|
||||||
|
cos: Tensor,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Qwen3RotaryEmbedding {
|
||||||
|
pub(crate) fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
|
||||||
|
let dim = cfg.head_dim;
|
||||||
|
let max_seq_len = cfg.max_position_embeddings;
|
||||||
|
let inv_freq: Vec<_> = (0..dim)
|
||||||
|
.step_by(2)
|
||||||
|
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
|
||||||
|
.collect();
|
||||||
|
let inv_freq_len = inv_freq.len();
|
||||||
|
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
|
||||||
|
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.reshape((max_seq_len, 1))?;
|
||||||
|
let freqs = t.matmul(&inv_freq)?;
|
||||||
|
Ok(Self {
|
||||||
|
sin: freqs.sin()?.to_dtype(dtype)?,
|
||||||
|
cos: freqs.cos()?.to_dtype(dtype)?,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn apply(
|
||||||
|
&self,
|
||||||
|
q: &Tensor,
|
||||||
|
k: &Tensor,
|
||||||
|
offset: usize,
|
||||||
|
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
|
let (_, _, seq_len, _) = q.dims4()?;
|
||||||
|
let cos = self.cos.narrow(0, offset, seq_len)?;
|
||||||
|
let sin = self.sin.narrow(0, offset, seq_len)?;
|
||||||
|
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
|
||||||
|
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
|
||||||
|
Ok((q_embed, k_embed))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper: load a replicated tensor by asking the ShardedVarBuilder for
|
||||||
|
/// the full tensor (world_size=1 hint falls through to SimpleBackend).
|
||||||
|
fn load_replicated<S: Into<candle_core::Shape>>(
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
shape: S,
|
||||||
|
name: &str,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
vb.get(shape, name)
|
||||||
|
.with_context(|| format!("load replicated '{}/{name}'", vb.prefix()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_rms_norm(vb: &ShardedVarBuilder, size: usize, eps: f64) -> Result<RmsNorm> {
|
||||||
|
let weight = load_replicated(vb, size, "weight")?;
|
||||||
|
Ok(RmsNorm::new(weight, eps))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// TP MLP. SwiGLU = `down(silu(gate(x)) * up(x))`.
|
||||||
|
pub(crate) struct TpQwen3MLP {
|
||||||
|
gate_proj: ColumnParallelLinear,
|
||||||
|
up_proj: ColumnParallelLinear,
|
||||||
|
down_proj: RowParallelLinear,
|
||||||
|
act_fn: Activation,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TpQwen3MLP {
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub fn load(
|
||||||
|
cfg: &Config,
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
comm: Arc<Comm>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
if !cfg.intermediate_size.is_multiple_of(world_size as usize) {
|
||||||
|
bail!(
|
||||||
|
"intermediate_size {} not divisible by world_size {}",
|
||||||
|
cfg.intermediate_size,
|
||||||
|
world_size
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Ok(Self {
|
||||||
|
gate_proj: ColumnParallelLinear::load(&vb.pp("gate_proj"), rank, world_size)?,
|
||||||
|
up_proj: ColumnParallelLinear::load(&vb.pp("up_proj"), rank, world_size)?,
|
||||||
|
down_proj: RowParallelLinear::load(&vb.pp("down_proj"), rank, world_size, comm)?,
|
||||||
|
act_fn: cfg.hidden_act,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
pub fn load(cfg: &Config, vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
|
||||||
|
if !cfg.intermediate_size.is_multiple_of(world_size as usize) {
|
||||||
|
bail!(
|
||||||
|
"intermediate_size {} not divisible by world_size {}",
|
||||||
|
cfg.intermediate_size,
|
||||||
|
world_size
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Ok(Self {
|
||||||
|
gate_proj: ColumnParallelLinear::load(&vb.pp("gate_proj"), rank, world_size)?,
|
||||||
|
up_proj: ColumnParallelLinear::load(&vb.pp("up_proj"), rank, world_size)?,
|
||||||
|
down_proj: RowParallelLinear::load(&vb.pp("down_proj"), rank, world_size)?,
|
||||||
|
act_fn: cfg.hidden_act,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for TpQwen3MLP {
|
||||||
|
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
|
||||||
|
let lhs = x.apply(&self.gate_proj)?.apply(&self.act_fn)?;
|
||||||
|
let rhs = x.apply(&self.up_proj)?;
|
||||||
|
(lhs * rhs)?.apply(&self.down_proj)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// TP attention. Carries per-rank head counts and the q/k per-head
|
||||||
|
/// RmsNorms (which are replicated and operate on a flattened B*H*L
|
||||||
|
/// axis, so the same code path works irrespective of how H was split).
|
||||||
|
pub(crate) struct TpQwen3Attention {
|
||||||
|
q_proj: ColumnParallelLinear,
|
||||||
|
k_proj: ColumnParallelLinear,
|
||||||
|
v_proj: ColumnParallelLinear,
|
||||||
|
o_proj: RowParallelLinear,
|
||||||
|
q_norm: RmsNorm,
|
||||||
|
k_norm: RmsNorm,
|
||||||
|
local_num_heads: usize,
|
||||||
|
local_num_kv_heads: usize,
|
||||||
|
num_kv_groups: usize,
|
||||||
|
head_dim: usize,
|
||||||
|
local_hidden_size: usize,
|
||||||
|
rotary_emb: Arc<Qwen3RotaryEmbedding>,
|
||||||
|
kv_cache: ConcatKvCache,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TpQwen3Attention {
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub fn load(
|
||||||
|
cfg: &Config,
|
||||||
|
rotary_emb: Arc<Qwen3RotaryEmbedding>,
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
comm: Arc<Comm>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Self::load_inner(
|
||||||
|
cfg,
|
||||||
|
rotary_emb,
|
||||||
|
vb,
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
comm,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
pub fn load(
|
||||||
|
cfg: &Config,
|
||||||
|
rotary_emb: Arc<Qwen3RotaryEmbedding>,
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Self::load_inner(cfg, rotary_emb, vb, rank, world_size)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_inner(
|
||||||
|
cfg: &Config,
|
||||||
|
rotary_emb: Arc<Qwen3RotaryEmbedding>,
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
#[cfg(feature = "cuda")] comm: Arc<Comm>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
if cfg.use_sliding_window {
|
||||||
|
bail!("sliding window is not yet supported in the TP path");
|
||||||
|
}
|
||||||
|
if cfg.attention_bias {
|
||||||
|
bail!("attention_bias=true is not supported by ColumnParallel/RowParallelLinear yet");
|
||||||
|
}
|
||||||
|
let ws = world_size as usize;
|
||||||
|
if !cfg.num_attention_heads.is_multiple_of(ws) {
|
||||||
|
bail!(
|
||||||
|
"num_attention_heads {} not divisible by world_size {}",
|
||||||
|
cfg.num_attention_heads,
|
||||||
|
world_size
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if !cfg.num_key_value_heads.is_multiple_of(ws) {
|
||||||
|
bail!(
|
||||||
|
"num_key_value_heads {} not divisible by world_size {}",
|
||||||
|
cfg.num_key_value_heads,
|
||||||
|
world_size
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let head_dim = cfg.head_dim;
|
||||||
|
let local_num_heads = cfg.num_attention_heads / ws;
|
||||||
|
let local_num_kv_heads = cfg.num_key_value_heads / ws;
|
||||||
|
let num_kv_groups = local_num_heads / local_num_kv_heads;
|
||||||
|
let local_hidden_size = head_dim * local_num_heads;
|
||||||
|
|
||||||
|
let q_proj = ColumnParallelLinear::load(&vb.pp("q_proj"), rank, world_size)?;
|
||||||
|
let k_proj = ColumnParallelLinear::load(&vb.pp("k_proj"), rank, world_size)?;
|
||||||
|
let v_proj = ColumnParallelLinear::load(&vb.pp("v_proj"), rank, world_size)?;
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
let o_proj = RowParallelLinear::load(&vb.pp("o_proj"), rank, world_size, comm)?;
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
let o_proj = RowParallelLinear::load(&vb.pp("o_proj"), rank, world_size)?;
|
||||||
|
|
||||||
|
let q_norm = load_rms_norm(&vb.pp("q_norm"), head_dim, cfg.rms_norm_eps)?;
|
||||||
|
let k_norm = load_rms_norm(&vb.pp("k_norm"), head_dim, cfg.rms_norm_eps)?;
|
||||||
|
|
||||||
|
// dim=2 because we cat along the seq axis of (B, H, L, D) tensors.
|
||||||
|
let kv_cache = ConcatKvCache::new(2);
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
q_proj,
|
||||||
|
k_proj,
|
||||||
|
v_proj,
|
||||||
|
o_proj,
|
||||||
|
q_norm,
|
||||||
|
k_norm,
|
||||||
|
local_num_heads,
|
||||||
|
local_num_kv_heads,
|
||||||
|
num_kv_groups,
|
||||||
|
head_dim,
|
||||||
|
local_hidden_size,
|
||||||
|
rotary_emb,
|
||||||
|
kv_cache,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(
|
||||||
|
&mut self,
|
||||||
|
x: &Tensor,
|
||||||
|
attn_mask: Option<&Tensor>,
|
||||||
|
offset: usize,
|
||||||
|
) -> candle_core::Result<Tensor> {
|
||||||
|
let (b, l, _) = x.dims3()?;
|
||||||
|
|
||||||
|
// 1. Projections (column-parallel → output is sharded).
|
||||||
|
let q = self.q_proj.forward(x)?;
|
||||||
|
let k = self.k_proj.forward(x)?;
|
||||||
|
let v = self.v_proj.forward(x)?;
|
||||||
|
|
||||||
|
// 2. Reshape: (B, L, H, D) → (B, H, L, D).
|
||||||
|
let q = q
|
||||||
|
.reshape((b, l, self.local_num_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?;
|
||||||
|
let k = k
|
||||||
|
.reshape((b, l, self.local_num_kv_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?;
|
||||||
|
let v = v
|
||||||
|
.reshape((b, l, self.local_num_kv_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?;
|
||||||
|
|
||||||
|
// 3. Per-head RmsNorm (replicated weight, flat input).
|
||||||
|
let q_flat = q.flatten(0, 2)?;
|
||||||
|
let k_flat = k.flatten(0, 2)?;
|
||||||
|
let q_flat = self.q_norm.forward(&q_flat)?;
|
||||||
|
let k_flat = self.k_norm.forward(&k_flat)?;
|
||||||
|
let q = q_flat.reshape((b, self.local_num_heads, l, self.head_dim))?;
|
||||||
|
let k = k_flat.reshape((b, self.local_num_kv_heads, l, self.head_dim))?;
|
||||||
|
|
||||||
|
// 4. Rotary.
|
||||||
|
let (q, k) = self.rotary_emb.apply(&q, &k, offset)?;
|
||||||
|
|
||||||
|
// 5. Accumulate KV.
|
||||||
|
let (k, v) = self.kv_cache.append(&k, &v)?;
|
||||||
|
|
||||||
|
// 6. GQA repeat_kv on the rank-local K/V.
|
||||||
|
let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
|
||||||
|
let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;
|
||||||
|
|
||||||
|
// 7. Attention scores.
|
||||||
|
let scale = 1.0 / (self.head_dim as f64).sqrt();
|
||||||
|
let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
|
||||||
|
if let Some(m) = attn_mask {
|
||||||
|
scores = scores.broadcast_add(m)?;
|
||||||
|
}
|
||||||
|
let probs = candle_nn::ops::softmax_last_dim(&scores)?;
|
||||||
|
let ctx = probs.matmul(&v)?;
|
||||||
|
|
||||||
|
// 8. Output projection (row-parallel → AllReduce inside).
|
||||||
|
ctx.transpose(1, 2)?
|
||||||
|
.reshape((b, l, self.local_hidden_size))?
|
||||||
|
.apply(&self.o_proj)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
self.kv_cache.reset();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct TpDecoderLayer {
|
||||||
|
self_attn: TpQwen3Attention,
|
||||||
|
mlp: TpQwen3MLP,
|
||||||
|
ln1: RmsNorm,
|
||||||
|
ln2: RmsNorm,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TpDecoderLayer {
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
fn load(
|
||||||
|
cfg: &Config,
|
||||||
|
rotary_emb: Arc<Qwen3RotaryEmbedding>,
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
comm: Arc<Comm>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let self_attn = TpQwen3Attention::load(
|
||||||
|
cfg,
|
||||||
|
rotary_emb,
|
||||||
|
&vb.pp("self_attn"),
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
comm.clone(),
|
||||||
|
)?;
|
||||||
|
let mlp = TpQwen3MLP::load(cfg, &vb.pp("mlp"), rank, world_size, comm)?;
|
||||||
|
let ln1 = load_rms_norm(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?;
|
||||||
|
let ln2 = load_rms_norm(
|
||||||
|
&vb.pp("post_attention_layernorm"),
|
||||||
|
cfg.hidden_size,
|
||||||
|
cfg.rms_norm_eps,
|
||||||
|
)?;
|
||||||
|
Ok(Self {
|
||||||
|
self_attn,
|
||||||
|
mlp,
|
||||||
|
ln1,
|
||||||
|
ln2,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
fn load(
|
||||||
|
cfg: &Config,
|
||||||
|
rotary_emb: Arc<Qwen3RotaryEmbedding>,
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let self_attn =
|
||||||
|
TpQwen3Attention::load(cfg, rotary_emb, &vb.pp("self_attn"), rank, world_size)?;
|
||||||
|
let mlp = TpQwen3MLP::load(cfg, &vb.pp("mlp"), rank, world_size)?;
|
||||||
|
let ln1 = load_rms_norm(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?;
|
||||||
|
let ln2 = load_rms_norm(
|
||||||
|
&vb.pp("post_attention_layernorm"),
|
||||||
|
cfg.hidden_size,
|
||||||
|
cfg.rms_norm_eps,
|
||||||
|
)?;
|
||||||
|
Ok(Self {
|
||||||
|
self_attn,
|
||||||
|
mlp,
|
||||||
|
ln1,
|
||||||
|
ln2,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(
|
||||||
|
&mut self,
|
||||||
|
x: &Tensor,
|
||||||
|
mask: Option<&Tensor>,
|
||||||
|
offset: usize,
|
||||||
|
) -> candle_core::Result<Tensor> {
|
||||||
|
let h = self.ln1.forward(x)?;
|
||||||
|
let h = self.self_attn.forward(&h, mask, offset)?;
|
||||||
|
let x = (x + h)?;
|
||||||
|
let h2 = self.ln2.forward(&x)?;
|
||||||
|
let h2 = h2.apply(&self.mlp)?;
|
||||||
|
x + h2
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clear_kv_cache(&mut self) {
|
||||||
|
self.self_attn.clear_kv_cache();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Base TP Qwen3 transformer — embedding, decoder stack, final norm.
|
||||||
|
/// The lm_head sits on top in [`TpQwen3ForCausalLM`].
|
||||||
|
pub struct TpQwen3Model {
|
||||||
|
embed_tokens: Embedding,
|
||||||
|
layers: Vec<TpDecoderLayer>,
|
||||||
|
norm: RmsNorm,
|
||||||
|
device: Device,
|
||||||
|
dtype: DType,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TpQwen3Model {
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub fn load(
|
||||||
|
cfg: &Config,
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
comm: Arc<Comm>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let dtype = vb.dtype();
|
||||||
|
let device = vb.device().clone();
|
||||||
|
let rotary = Arc::new(Qwen3RotaryEmbedding::new(dtype, cfg, &device)?);
|
||||||
|
|
||||||
|
let embed_vb = vb.pp("model.embed_tokens");
|
||||||
|
let embed_weight = load_replicated(&embed_vb, (cfg.vocab_size, cfg.hidden_size), "weight")?;
|
||||||
|
let embed_tokens = Embedding::new(embed_weight, cfg.hidden_size);
|
||||||
|
|
||||||
|
let vb_l = vb.pp("model.layers");
|
||||||
|
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||||
|
for i in 0..cfg.num_hidden_layers {
|
||||||
|
layers.push(TpDecoderLayer::load(
|
||||||
|
cfg,
|
||||||
|
rotary.clone(),
|
||||||
|
&vb_l.pp(i),
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
comm.clone(),
|
||||||
|
)?);
|
||||||
|
}
|
||||||
|
let norm = load_rms_norm(&vb.pp("model.norm"), cfg.hidden_size, cfg.rms_norm_eps)?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
embed_tokens,
|
||||||
|
layers,
|
||||||
|
norm,
|
||||||
|
device,
|
||||||
|
dtype,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
pub fn load(cfg: &Config, vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
|
||||||
|
let dtype = vb.dtype();
|
||||||
|
let device = vb.device().clone();
|
||||||
|
let rotary = Arc::new(Qwen3RotaryEmbedding::new(dtype, cfg, &device)?);
|
||||||
|
|
||||||
|
let embed_vb = vb.pp("model.embed_tokens");
|
||||||
|
let embed_weight = load_replicated(&embed_vb, (cfg.vocab_size, cfg.hidden_size), "weight")?;
|
||||||
|
let embed_tokens = Embedding::new(embed_weight, cfg.hidden_size);
|
||||||
|
|
||||||
|
let vb_l = vb.pp("model.layers");
|
||||||
|
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||||
|
for i in 0..cfg.num_hidden_layers {
|
||||||
|
layers.push(TpDecoderLayer::load(
|
||||||
|
cfg,
|
||||||
|
rotary.clone(),
|
||||||
|
&vb_l.pp(i),
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
)?);
|
||||||
|
}
|
||||||
|
let norm = load_rms_norm(&vb.pp("model.norm"), cfg.hidden_size, cfg.rms_norm_eps)?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
embed_tokens,
|
||||||
|
layers,
|
||||||
|
norm,
|
||||||
|
device,
|
||||||
|
dtype,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn embed_weight(&self) -> &Tensor {
|
||||||
|
self.embed_tokens.embeddings()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
for l in &mut self.layers {
|
||||||
|
l.clear_kv_cache();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> candle_core::Result<Tensor> {
|
||||||
|
let minf = f32::NEG_INFINITY;
|
||||||
|
let mask: Vec<_> = (0..tgt)
|
||||||
|
.flat_map(|i| (0..(tgt + offset)).map(move |j| if j <= i + offset { 0. } else { minf }))
|
||||||
|
.collect();
|
||||||
|
Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
|
||||||
|
let (b, l) = input.dims2()?;
|
||||||
|
let mut h = self.embed_tokens.forward(input)?;
|
||||||
|
|
||||||
|
let causal = if l == 1 {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(self.causal_mask(b, l, offset)?)
|
||||||
|
};
|
||||||
|
|
||||||
|
for layer in &mut self.layers {
|
||||||
|
h = layer.forward(&h, causal.as_ref(), offset)?;
|
||||||
|
}
|
||||||
|
self.norm.forward(&h)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// TP Qwen3 with a (replicated) language-model head on top.
|
||||||
|
pub struct TpQwen3ForCausalLM {
|
||||||
|
base: TpQwen3Model,
|
||||||
|
lm_head: Linear,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TpQwen3ForCausalLM {
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub fn load(
|
||||||
|
cfg: &Config,
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
comm: Arc<Comm>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let base = TpQwen3Model::load(cfg, vb, rank, world_size, comm)?;
|
||||||
|
let lm_head = build_lm_head(cfg, vb, &base)?;
|
||||||
|
Ok(Self { base, lm_head })
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
pub fn load(cfg: &Config, vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
|
||||||
|
let base = TpQwen3Model::load(cfg, vb, rank, world_size)?;
|
||||||
|
let lm_head = build_lm_head(cfg, vb, &base)?;
|
||||||
|
Ok(Self { base, lm_head })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
|
||||||
|
let (_, l) = input.dims2()?;
|
||||||
|
let hidden = self.base.forward(input, offset)?;
|
||||||
|
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
self.base.clear_kv_cache();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn device(&self) -> &Device {
|
||||||
|
&self.base.device
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn dtype(&self) -> DType {
|
||||||
|
self.base.dtype
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_lm_head(cfg: &Config, vb: &ShardedVarBuilder, base: &TpQwen3Model) -> Result<Linear> {
|
||||||
|
if cfg.tie_word_embeddings {
|
||||||
|
Ok(Linear::new(base.embed_weight().clone(), None))
|
||||||
|
} else {
|
||||||
|
let weight = load_replicated(
|
||||||
|
&vb.pp("lm_head"),
|
||||||
|
(cfg.vocab_size, cfg.hidden_size),
|
||||||
|
"weight",
|
||||||
|
)?;
|
||||||
|
Ok(Linear::new(weight, None))
|
||||||
|
}
|
||||||
|
}
|
||||||
1131
crates/neuron/src/harness/tp/tp_qwen3_5.rs
Normal file
1131
crates/neuron/src/harness/tp/tp_qwen3_5.rs
Normal file
File diff suppressed because it is too large
Load Diff
502
crates/neuron/src/harness/tp/worker.rs
Normal file
502
crates/neuron/src/harness/tp/worker.rs
Normal file
@@ -0,0 +1,502 @@
|
|||||||
|
//! Entry point for `neuron --worker`.
|
||||||
|
//!
|
||||||
|
//! The worker reads one newline-delimited JSON `WorkerRequest` from
|
||||||
|
//! stdin per loop iteration, dispatches synchronously, and writes
|
||||||
|
//! exactly one `WorkerResponse` JSON line to stdout. tracing goes to
|
||||||
|
//! stderr so it doesn't collide with the RPC stream.
|
||||||
|
//!
|
||||||
|
//! NCCL operations (`Init`, `NcclSanityCheck`) and model lifecycle ops
|
||||||
|
//! (`LoadDenseShard`, `GenerateStep`, `ClearKvCache`, `UnloadModel`)
|
||||||
|
//! are real when built with the `cuda` feature; without it they reply
|
||||||
|
//! with `Error{kind="cuda_feature_not_enabled"}` so the leader can tell
|
||||||
|
//! the difference between a misconfigured build and a genuine NCCL or
|
||||||
|
//! model failure.
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||||
|
|
||||||
|
use super::nccl_state::NcclState;
|
||||||
|
use super::rpc::{WorkerRequest, WorkerResponse};
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
use super::tp_qwen3::TpQwen3ForCausalLM;
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
use super::tp_qwen3_5::TpQwen3_5ForCausalLM;
|
||||||
|
|
||||||
|
/// Worker-side discriminator over the architectures we can load via
|
||||||
|
/// `LoadDenseShard`. Mirrors `super::TpLeaderModel` on the leader
|
||||||
|
/// side — the dispatch happens on the `model_type` extracted from the
|
||||||
|
/// config JSON.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
enum WorkerModel {
|
||||||
|
Qwen3(TpQwen3ForCausalLM),
|
||||||
|
Qwen3_5(TpQwen3_5ForCausalLM),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
impl WorkerModel {
|
||||||
|
fn forward(
|
||||||
|
&mut self,
|
||||||
|
input: &candle_core::Tensor,
|
||||||
|
offset: usize,
|
||||||
|
) -> candle_core::Result<candle_core::Tensor> {
|
||||||
|
match self {
|
||||||
|
WorkerModel::Qwen3(m) => m.forward(input, offset),
|
||||||
|
WorkerModel::Qwen3_5(m) => m.forward(input, offset),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clear_kv_cache(&mut self) {
|
||||||
|
match self {
|
||||||
|
WorkerModel::Qwen3(m) => m.clear_kv_cache(),
|
||||||
|
WorkerModel::Qwen3_5(m) => m.clear_kv_cache(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn device(&self) -> &candle_core::Device {
|
||||||
|
match self {
|
||||||
|
WorkerModel::Qwen3(m) => m.device(),
|
||||||
|
WorkerModel::Qwen3_5(m) => m.device(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct WorkerConfig {
|
||||||
|
pub rank: u32,
|
||||||
|
pub world_size: u32,
|
||||||
|
pub cuda_device: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Drive the worker RPC loop until `Shutdown` or EOF on stdin.
|
||||||
|
pub async fn run(config: WorkerConfig) -> Result<()> {
|
||||||
|
tracing::info!(
|
||||||
|
rank = config.rank,
|
||||||
|
world_size = config.world_size,
|
||||||
|
cuda_device = config.cuda_device,
|
||||||
|
"tp worker starting"
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut state = WorkerState::new(config);
|
||||||
|
let stdin = tokio::io::stdin();
|
||||||
|
let mut reader = BufReader::new(stdin).lines();
|
||||||
|
let mut stdout = tokio::io::stdout();
|
||||||
|
|
||||||
|
while let Some(line) = reader.next_line().await? {
|
||||||
|
if line.trim().is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let req: WorkerRequest = match serde_json::from_str(&line) {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => {
|
||||||
|
let resp = WorkerResponse::Error {
|
||||||
|
kind: "bad_request".into(),
|
||||||
|
message: format!("parse {line:?}: {e}"),
|
||||||
|
};
|
||||||
|
write_response(&mut stdout, &resp).await?;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let resp = state.handle(req).await;
|
||||||
|
let is_bye = matches!(resp, WorkerResponse::Bye);
|
||||||
|
write_response(&mut stdout, &resp).await?;
|
||||||
|
if is_bye {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!(rank = config.rank, "tp worker exiting");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn write_response(stdout: &mut tokio::io::Stdout, resp: &WorkerResponse) -> Result<()> {
|
||||||
|
let mut line = serde_json::to_string(resp)?;
|
||||||
|
line.push('\n');
|
||||||
|
stdout.write_all(line.as_bytes()).await?;
|
||||||
|
stdout.flush().await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// One rank's local state. Owns the rank's NCCL communicator (via
|
||||||
|
/// `NcclState`) and the rank's shard of every loaded model.
|
||||||
|
struct WorkerState {
|
||||||
|
config: WorkerConfig,
|
||||||
|
nccl: NcclState,
|
||||||
|
/// Loaded model shards keyed by `model_id`. Each entry wraps the
|
||||||
|
/// rank's TP architecture handle (Qwen3 or Qwen3-Next) — the
|
||||||
|
/// column/row-parallel layers hold an `Arc<Comm>` cloned from
|
||||||
|
/// `nccl`. Cuda-only: the underlying types reference cudarc types
|
||||||
|
/// that don't exist without the cuda feature.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
models: HashMap<String, WorkerModel>,
|
||||||
|
/// Placeholder so the non-cuda build keeps the same field name set
|
||||||
|
/// and `WorkerState::new` reads the same on both.
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
#[allow(dead_code)]
|
||||||
|
models: HashMap<String, ()>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WorkerState {
|
||||||
|
fn new(config: WorkerConfig) -> Self {
|
||||||
|
Self {
|
||||||
|
config,
|
||||||
|
nccl: NcclState::new(),
|
||||||
|
models: HashMap::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle(&mut self, req: WorkerRequest) -> WorkerResponse {
|
||||||
|
match req {
|
||||||
|
WorkerRequest::Ping => WorkerResponse::Pong {
|
||||||
|
rank: self.config.rank,
|
||||||
|
world_size: self.config.world_size,
|
||||||
|
cuda_device: self.config.cuda_device,
|
||||||
|
},
|
||||||
|
WorkerRequest::Init { comm_id } => self.nccl.init(self.config, &comm_id),
|
||||||
|
WorkerRequest::NcclSanityCheck => self.nccl.sanity_check(),
|
||||||
|
WorkerRequest::LoadDenseShard {
|
||||||
|
model_id,
|
||||||
|
config_json,
|
||||||
|
safetensors_paths,
|
||||||
|
quant,
|
||||||
|
} => self.handle_load_dense_shard(model_id, config_json, safetensors_paths, quant),
|
||||||
|
WorkerRequest::GenerateStep {
|
||||||
|
model_id,
|
||||||
|
tokens,
|
||||||
|
offset,
|
||||||
|
} => self.handle_generate_step(&model_id, tokens, offset),
|
||||||
|
WorkerRequest::ClearKvCache { model_id } => self.handle_clear_kv_cache(&model_id),
|
||||||
|
WorkerRequest::UnloadModel { model_id } => self.handle_unload_model(&model_id),
|
||||||
|
WorkerRequest::Shutdown => WorkerResponse::Bye,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
fn handle_load_dense_shard(
|
||||||
|
&mut self,
|
||||||
|
model_id: String,
|
||||||
|
config_json: String,
|
||||||
|
safetensors_paths: Vec<String>,
|
||||||
|
quant: Option<String>,
|
||||||
|
) -> WorkerResponse {
|
||||||
|
use crate::harness::arch::qwen3_5 as qwen3_5_arch;
|
||||||
|
use candle_core::{DType, Device};
|
||||||
|
use candle_nn::var_builder::ShardedSafeTensors;
|
||||||
|
use candle_transformers::models::qwen3 as qwen3_dense;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
let quant_dtype = match parse_quant_string(quant.as_deref()) {
|
||||||
|
Ok(q) => q,
|
||||||
|
Err(e) => {
|
||||||
|
return WorkerResponse::Error {
|
||||||
|
kind: "bad_request".into(),
|
||||||
|
message: format!("parse quant: {e}"),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if self.models.contains_key(&model_id) {
|
||||||
|
return WorkerResponse::Error {
|
||||||
|
kind: "already_loaded".into(),
|
||||||
|
message: format!("model '{model_id}' already loaded on this rank"),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
let comm = match self.nccl.comm() {
|
||||||
|
Some(c) => c,
|
||||||
|
None => {
|
||||||
|
return WorkerResponse::Error {
|
||||||
|
kind: "nccl_not_initialised".into(),
|
||||||
|
message: "LoadDenseShard requires Init to have completed first".into(),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Peek at model_type so we know which architecture to build.
|
||||||
|
let model_type = serde_json::from_str::<serde_json::Value>(&config_json)
|
||||||
|
.ok()
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|v| v.get("model_type"))
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or("")
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
let device = match Device::new_cuda(self.config.cuda_device as usize) {
|
||||||
|
Ok(d) => d,
|
||||||
|
Err(e) => {
|
||||||
|
return WorkerResponse::Error {
|
||||||
|
kind: "cuda_unavailable".into(),
|
||||||
|
message: format!("Device::new_cuda({}) failed: {e}", self.config.cuda_device),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let paths: Vec<PathBuf> = safetensors_paths.into_iter().map(PathBuf::from).collect();
|
||||||
|
// SAFETY: same invariant as the single-GPU dense path — the HF
|
||||||
|
// cache files are treated as immutable while the mmap is held.
|
||||||
|
let vb = match unsafe { ShardedSafeTensors::var_builder(&paths, DType::BF16, &device) } {
|
||||||
|
Ok(v) => v,
|
||||||
|
Err(e) => {
|
||||||
|
return WorkerResponse::Error {
|
||||||
|
kind: "load_failed".into(),
|
||||||
|
message: format!("ShardedSafeTensors::var_builder: {e}"),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
// Separate mmap of the same paths for the direct fused-region
|
||||||
|
// loader in `fused_load`. Linux's page cache shares the
|
||||||
|
// underlying pages between the two mmaps; the cost is one
|
||||||
|
// extra set of safetensors-header parses.
|
||||||
|
let mmap = match unsafe { candle_core::safetensors::MmapedSafetensors::multi(&paths) } {
|
||||||
|
Ok(m) => m,
|
||||||
|
Err(e) => {
|
||||||
|
return WorkerResponse::Error {
|
||||||
|
kind: "load_failed".into(),
|
||||||
|
message: format!("MmapedSafetensors::multi: {e}"),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let loaded = match model_type.as_str() {
|
||||||
|
"qwen3" => {
|
||||||
|
let cfg: qwen3_dense::Config = match serde_json::from_str(&config_json) {
|
||||||
|
Ok(c) => c,
|
||||||
|
Err(e) => {
|
||||||
|
return WorkerResponse::Error {
|
||||||
|
kind: "bad_request".into(),
|
||||||
|
message: format!("parse Qwen3 Config JSON: {e}"),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
match TpQwen3ForCausalLM::load(
|
||||||
|
&cfg,
|
||||||
|
&vb,
|
||||||
|
self.config.rank,
|
||||||
|
self.config.world_size,
|
||||||
|
comm,
|
||||||
|
) {
|
||||||
|
Ok(m) => WorkerModel::Qwen3(m),
|
||||||
|
Err(e) => {
|
||||||
|
return WorkerResponse::Error {
|
||||||
|
kind: "load_failed".into(),
|
||||||
|
message: format!("TpQwen3ForCausalLM::load: {e:#}"),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"qwen3_5" => {
|
||||||
|
let cfg: qwen3_5_arch::Config = match serde_json::from_str(&config_json) {
|
||||||
|
Ok(c) => c,
|
||||||
|
Err(e) => {
|
||||||
|
return WorkerResponse::Error {
|
||||||
|
kind: "bad_request".into(),
|
||||||
|
message: format!("parse Qwen3-Next Config JSON: {e}"),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
match TpQwen3_5ForCausalLM::load(
|
||||||
|
cfg,
|
||||||
|
&vb,
|
||||||
|
&mmap,
|
||||||
|
self.config.rank,
|
||||||
|
self.config.world_size,
|
||||||
|
comm,
|
||||||
|
quant_dtype,
|
||||||
|
) {
|
||||||
|
Ok(m) => WorkerModel::Qwen3_5(m),
|
||||||
|
Err(e) => {
|
||||||
|
return WorkerResponse::Error {
|
||||||
|
kind: "load_failed".into(),
|
||||||
|
message: format!("TpQwen3_5ForCausalLM::load: {e:#}"),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
other => {
|
||||||
|
return WorkerResponse::Error {
|
||||||
|
kind: "unsupported_arch".into(),
|
||||||
|
message: format!(
|
||||||
|
"worker: unsupported model_type '{other}' (supported: qwen3, qwen3_5)"
|
||||||
|
),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
self.models.insert(model_id.clone(), loaded);
|
||||||
|
tracing::info!(
|
||||||
|
rank = self.config.rank,
|
||||||
|
model = %model_id,
|
||||||
|
model_type = %model_type,
|
||||||
|
"loaded TP shard"
|
||||||
|
);
|
||||||
|
WorkerResponse::LoadDenseShardOk
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
fn handle_load_dense_shard(
|
||||||
|
&mut self,
|
||||||
|
_model_id: String,
|
||||||
|
_config_json: String,
|
||||||
|
_safetensors_paths: Vec<String>,
|
||||||
|
_quant: Option<String>,
|
||||||
|
) -> WorkerResponse {
|
||||||
|
WorkerResponse::Error {
|
||||||
|
kind: "cuda_feature_not_enabled".into(),
|
||||||
|
message: "LoadDenseShard requires --features cuda".into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
fn handle_generate_step(
|
||||||
|
&mut self,
|
||||||
|
model_id: &str,
|
||||||
|
tokens: Vec<u32>,
|
||||||
|
offset: usize,
|
||||||
|
) -> WorkerResponse {
|
||||||
|
use candle_core::Tensor;
|
||||||
|
|
||||||
|
let Some(model) = self.models.get_mut(model_id) else {
|
||||||
|
return WorkerResponse::Error {
|
||||||
|
kind: "model_not_loaded".into(),
|
||||||
|
message: format!("model '{model_id}' not loaded on rank {}", self.config.rank),
|
||||||
|
};
|
||||||
|
};
|
||||||
|
let device = model.device().clone();
|
||||||
|
let input = match Tensor::new(tokens.as_slice(), &device).and_then(|t| t.unsqueeze(0)) {
|
||||||
|
Ok(t) => t,
|
||||||
|
Err(e) => {
|
||||||
|
return WorkerResponse::Error {
|
||||||
|
kind: "forward_failed".into(),
|
||||||
|
message: format!("build input tensor: {e}"),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
tracing::debug!(
|
||||||
|
rank = self.config.rank,
|
||||||
|
model = %model_id,
|
||||||
|
tokens = tokens.len(),
|
||||||
|
offset,
|
||||||
|
"worker GenerateStep: forward starting"
|
||||||
|
);
|
||||||
|
// Drop the resulting logits — the leader uses its own copy from
|
||||||
|
// rank 0. The forward's value here is the NCCL collectives it
|
||||||
|
// issues, which let the leader's rank-0 forward make progress.
|
||||||
|
if let Err(e) = model.forward(&input, offset) {
|
||||||
|
tracing::warn!(
|
||||||
|
rank = self.config.rank,
|
||||||
|
model = %model_id,
|
||||||
|
elapsed_ms = start.elapsed().as_millis(),
|
||||||
|
error = %e,
|
||||||
|
"worker GenerateStep: forward failed"
|
||||||
|
);
|
||||||
|
return WorkerResponse::Error {
|
||||||
|
kind: "forward_failed".into(),
|
||||||
|
message: format!("TP forward: {e}"),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
tracing::debug!(
|
||||||
|
rank = self.config.rank,
|
||||||
|
model = %model_id,
|
||||||
|
elapsed_ms = start.elapsed().as_millis(),
|
||||||
|
"worker GenerateStep: forward done"
|
||||||
|
);
|
||||||
|
WorkerResponse::GenerateStepOk
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
fn handle_generate_step(
|
||||||
|
&mut self,
|
||||||
|
_model_id: &str,
|
||||||
|
_tokens: Vec<u32>,
|
||||||
|
_offset: usize,
|
||||||
|
) -> WorkerResponse {
|
||||||
|
WorkerResponse::Error {
|
||||||
|
kind: "cuda_feature_not_enabled".into(),
|
||||||
|
message: "GenerateStep requires --features cuda".into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
fn handle_clear_kv_cache(&mut self, model_id: &str) -> WorkerResponse {
|
||||||
|
let Some(model) = self.models.get_mut(model_id) else {
|
||||||
|
return WorkerResponse::Error {
|
||||||
|
kind: "model_not_loaded".into(),
|
||||||
|
message: format!("model '{model_id}' not loaded on rank {}", self.config.rank),
|
||||||
|
};
|
||||||
|
};
|
||||||
|
model.clear_kv_cache();
|
||||||
|
WorkerResponse::KvCacheCleared
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
fn handle_clear_kv_cache(&mut self, _model_id: &str) -> WorkerResponse {
|
||||||
|
WorkerResponse::Error {
|
||||||
|
kind: "cuda_feature_not_enabled".into(),
|
||||||
|
message: "ClearKvCache requires --features cuda".into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
fn handle_unload_model(&mut self, model_id: &str) -> WorkerResponse {
|
||||||
|
if self.models.remove(model_id).is_none() {
|
||||||
|
return WorkerResponse::Error {
|
||||||
|
kind: "model_not_loaded".into(),
|
||||||
|
message: format!("model '{model_id}' not loaded on rank {}", self.config.rank),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
tracing::info!(rank = self.config.rank, model = %model_id, "unloaded TP shard");
|
||||||
|
WorkerResponse::Unloaded
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
fn handle_unload_model(&mut self, _model_id: &str) -> WorkerResponse {
|
||||||
|
WorkerResponse::Error {
|
||||||
|
kind: "cuda_feature_not_enabled".into(),
|
||||||
|
message: "UnloadModel requires --features cuda".into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse a `ModelSpec.quant` string into a `GgmlDType`. Accepts the
|
||||||
|
/// common ggml format names (case-insensitive). `None` and `Some("")`
|
||||||
|
/// both map to "no quantization".
|
||||||
|
///
|
||||||
|
/// Supported: `q4_0`, `q4_1`, `q5_0`, `q5_1`, `q8_0`, `q8_1`,
|
||||||
|
/// `q2k`/`q2_k`, `q3k`/`q3_k`, `q4k`/`q4_k`, `q5k`/`q5_k`,
|
||||||
|
/// `q6k`/`q6_k`, `q8k`/`q8_k`, `f16`, `bf16`, `f32`. The underscore
|
||||||
|
/// is optional and the prefix is case-insensitive.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub(crate) fn parse_quant_string(
|
||||||
|
s: Option<&str>,
|
||||||
|
) -> anyhow::Result<Option<candle_core::quantized::GgmlDType>> {
|
||||||
|
use candle_core::quantized::GgmlDType;
|
||||||
|
let s = match s {
|
||||||
|
Some(s) if !s.is_empty() => s,
|
||||||
|
_ => return Ok(None),
|
||||||
|
};
|
||||||
|
let normalised = s.to_ascii_lowercase().replace('_', "");
|
||||||
|
let dtype = match normalised.as_str() {
|
||||||
|
"q40" => GgmlDType::Q4_0,
|
||||||
|
"q41" => GgmlDType::Q4_1,
|
||||||
|
"q50" => GgmlDType::Q5_0,
|
||||||
|
"q51" => GgmlDType::Q5_1,
|
||||||
|
"q80" => GgmlDType::Q8_0,
|
||||||
|
"q81" => GgmlDType::Q8_1,
|
||||||
|
"q2k" => GgmlDType::Q2K,
|
||||||
|
"q3k" => GgmlDType::Q3K,
|
||||||
|
"q4k" | "q4km" => GgmlDType::Q4K,
|
||||||
|
"q5k" | "q5km" => GgmlDType::Q5K,
|
||||||
|
"q6k" => GgmlDType::Q6K,
|
||||||
|
"q8k" => GgmlDType::Q8K,
|
||||||
|
"f16" => GgmlDType::F16,
|
||||||
|
"bf16" => GgmlDType::BF16,
|
||||||
|
"f32" => GgmlDType::F32,
|
||||||
|
other => anyhow::bail!(
|
||||||
|
"unknown quant '{other}' (expected one of: q4_0, q4_1, q5_0, q5_1, q8_0, \
|
||||||
|
q8_1, q2k, q3k, q4k, q5k, q6k, q8k, f16, bf16, f32)"
|
||||||
|
),
|
||||||
|
};
|
||||||
|
Ok(Some(dtype))
|
||||||
|
}
|
||||||
@@ -1,5 +1,7 @@
|
|||||||
pub mod api;
|
pub mod api;
|
||||||
pub mod config;
|
pub mod config;
|
||||||
|
pub mod cuda;
|
||||||
pub mod discovery;
|
pub mod discovery;
|
||||||
pub mod harness;
|
pub mod harness;
|
||||||
pub mod health;
|
pub mod health;
|
||||||
|
pub mod startup;
|
||||||
|
|||||||
@@ -1,21 +1,66 @@
|
|||||||
use anyhow::Result;
|
use anyhow::{Context, Result};
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use neuron::{api, config::NeuronConfig, discovery, harness::HarnessRegistry, health};
|
use neuron::{
|
||||||
|
api,
|
||||||
|
config::NeuronConfig,
|
||||||
|
discovery,
|
||||||
|
harness::{HarnessRegistry, tp},
|
||||||
|
health, startup,
|
||||||
|
};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
use tracing_subscriber::EnvFilter;
|
use tracing_subscriber::EnvFilter;
|
||||||
|
|
||||||
|
/// Top-level CLI. The same binary runs as either the public neuron
|
||||||
|
/// daemon (default), a tensor-parallel worker subprocess (when
|
||||||
|
/// `--worker` is set, spawned by the leader on the same host), or a
|
||||||
|
/// one-shot TP NCCL handshake check (when `--tp-smoke` is set).
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
#[command(name = "neuron")]
|
#[command(name = "neuron")]
|
||||||
#[command(about = "Per-node daemon for cortex inference clusters")]
|
#[command(about = "Per-node daemon for cortex inference clusters")]
|
||||||
#[command(version)]
|
#[command(version)]
|
||||||
struct Args {
|
struct Args {
|
||||||
/// Port to listen on (overrides config file).
|
/// Run in tensor-parallel worker mode. The leader process spawns
|
||||||
|
/// one of these per non-zero NCCL rank and drives it over
|
||||||
|
/// newline-delimited JSON on stdin/stdout. Worker mode skips
|
||||||
|
/// discovery, the HTTP listener, and the health poller — it's a
|
||||||
|
/// pure RPC loop.
|
||||||
|
#[arg(long, default_value_t = false)]
|
||||||
|
worker: bool,
|
||||||
|
|
||||||
|
/// Run a one-shot TP smoke test: spawn `--tp-size - 1` worker
|
||||||
|
/// subprocesses on `--cuda-devices`, build the NCCL communicator,
|
||||||
|
/// run an `AllReduce` sanity check across every rank, and exit.
|
||||||
|
/// Used to validate the TP plumbing in isolation from model load
|
||||||
|
/// and inference. Diagnostic-only — not exposed through the daemon
|
||||||
|
/// HTTP API.
|
||||||
|
#[arg(long, default_value_t = false)]
|
||||||
|
tp_smoke: bool,
|
||||||
|
|
||||||
|
/// NCCL rank for worker mode. Ignored when `--worker` is not set.
|
||||||
|
#[arg(long, default_value_t = 0)]
|
||||||
|
rank: u32,
|
||||||
|
|
||||||
|
/// Total NCCL world size for worker mode or TP smoke mode.
|
||||||
|
#[arg(long, default_value_t = 1)]
|
||||||
|
tp_size: u32,
|
||||||
|
|
||||||
|
/// CUDA device index for worker mode. Ignored when `--worker` is
|
||||||
|
/// not set.
|
||||||
|
#[arg(long, default_value_t = 0)]
|
||||||
|
cuda_device: u32,
|
||||||
|
|
||||||
|
/// Comma-separated CUDA device indices for TP smoke mode (one per
|
||||||
|
/// rank, starting with rank 0). Must have `tp_size` entries.
|
||||||
|
#[arg(long, value_delimiter = ',')]
|
||||||
|
cuda_devices: Vec<u32>,
|
||||||
|
|
||||||
|
/// Port to listen on (overrides config file). Daemon mode only.
|
||||||
#[arg(short, long)]
|
#[arg(short, long)]
|
||||||
port: Option<u16>,
|
port: Option<u16>,
|
||||||
|
|
||||||
/// Path to the neuron config file.
|
/// Path to the neuron config file. Daemon mode only.
|
||||||
#[arg(short, long, default_value = "neuron.toml")]
|
#[arg(short, long, default_value = "neuron.toml")]
|
||||||
config: String,
|
config: String,
|
||||||
}
|
}
|
||||||
@@ -23,6 +68,7 @@ struct Args {
|
|||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> Result<()> {
|
async fn main() -> Result<()> {
|
||||||
tracing_subscriber::fmt()
|
tracing_subscriber::fmt()
|
||||||
|
.with_writer(std::io::stderr)
|
||||||
.with_env_filter(
|
.with_env_filter(
|
||||||
EnvFilter::try_from_default_env()
|
EnvFilter::try_from_default_env()
|
||||||
.unwrap_or_else(|_| EnvFilter::new("info,neuron=debug")),
|
.unwrap_or_else(|_| EnvFilter::new("info,neuron=debug")),
|
||||||
@@ -31,12 +77,85 @@ async fn main() -> Result<()> {
|
|||||||
|
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
|
|
||||||
|
if args.worker {
|
||||||
|
return tp::worker::run(tp::worker::WorkerConfig {
|
||||||
|
rank: args.rank,
|
||||||
|
world_size: args.tp_size,
|
||||||
|
cuda_device: args.cuda_device,
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
if args.tp_smoke {
|
||||||
|
return tp_smoke(args.tp_size, args.cuda_devices).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
daemon(args).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// One-shot tensor-parallel handshake. Spawns N-1 worker subprocesses
|
||||||
|
/// (rank 0 stays in this process), builds the NCCL communicator across
|
||||||
|
/// the full world, runs an AllReduce sanity check, and shuts everyone
|
||||||
|
/// down. Output is plain log lines on stderr + a final summary on
|
||||||
|
/// stdout in `key=value` form so an outer script can parse it.
|
||||||
|
async fn tp_smoke(tp_size: u32, cuda_devices: Vec<u32>) -> Result<()> {
|
||||||
|
if tp_size < 2 {
|
||||||
|
anyhow::bail!("--tp-size must be at least 2 (got {tp_size})");
|
||||||
|
}
|
||||||
|
if cuda_devices.len() as u32 != tp_size {
|
||||||
|
anyhow::bail!(
|
||||||
|
"--cuda-devices must list exactly {tp_size} entries (got {})",
|
||||||
|
cuda_devices.len()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let exe = std::env::current_exe().context("resolve current_exe for worker spawn")?;
|
||||||
|
let leader_device = cuda_devices[0];
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
tp_size,
|
||||||
|
?cuda_devices,
|
||||||
|
binary = %exe.display(),
|
||||||
|
"tp-smoke: spawning worker pool"
|
||||||
|
);
|
||||||
|
let mut pool = tp::WorkerPool::spawn(&exe, tp_size, &cuda_devices).await?;
|
||||||
|
|
||||||
|
tracing::info!("tp-smoke: pinging every worker");
|
||||||
|
let pongs = pool.ping_all().await?;
|
||||||
|
for p in &pongs {
|
||||||
|
tracing::info!(?p, "tp-smoke: pong");
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!(leader_device, "tp-smoke: initialising NCCL");
|
||||||
|
pool.init_nccl(leader_device).await?;
|
||||||
|
|
||||||
|
tracing::info!("tp-smoke: running AllReduce sanity check");
|
||||||
|
pool.nccl_sanity_check().await?;
|
||||||
|
|
||||||
|
tracing::info!("tp-smoke: shutting down pool");
|
||||||
|
pool.shutdown().await?;
|
||||||
|
|
||||||
|
println!("status=ok");
|
||||||
|
println!("tp_size={tp_size}");
|
||||||
|
println!(
|
||||||
|
"cuda_devices={}",
|
||||||
|
cuda_devices
|
||||||
|
.iter()
|
||||||
|
.map(|d| d.to_string())
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(",")
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn daemon(args: Args) -> Result<()> {
|
||||||
let cfg = NeuronConfig::load(&args.config).unwrap_or_else(|e| {
|
let cfg = NeuronConfig::load(&args.config).unwrap_or_else(|e| {
|
||||||
tracing::warn!(path = %args.config, error = %e, "config not found, using defaults");
|
tracing::warn!(path = %args.config, error = %e, "config not found, using defaults");
|
||||||
NeuronConfig::default()
|
NeuronConfig::default()
|
||||||
});
|
});
|
||||||
|
|
||||||
let port = args.port.unwrap_or(cfg.port);
|
let port = args.port.unwrap_or(cfg.port);
|
||||||
|
let bind_url = format!("http://localhost:{port}");
|
||||||
let start_time = Instant::now();
|
let start_time = Instant::now();
|
||||||
|
|
||||||
tracing::info!("running hardware discovery");
|
tracing::info!("running hardware discovery");
|
||||||
@@ -47,9 +166,18 @@ async fn main() -> Result<()> {
|
|||||||
"discovery complete"
|
"discovery complete"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Build harness registry from config.
|
// Build harness registry from config. In-process harnesses (candle)
|
||||||
let registry = HarnessRegistry::from_configs(&cfg.harnesses);
|
// need to know neuron's own bind URL so they can return it from
|
||||||
|
// inference_endpoint.
|
||||||
|
let registry = HarnessRegistry::from_configs(&cfg.harnesses, &bind_url, &cfg.harness);
|
||||||
discovery_result.harnesses = registry.names();
|
discovery_result.harnesses = registry.names();
|
||||||
|
let candle = registry.candle();
|
||||||
|
|
||||||
|
// Activation: load default models before binding the listener.
|
||||||
|
// Each load may take tens of seconds to several minutes depending
|
||||||
|
// on model size and HF cache state — keep TimeoutStartSec in the
|
||||||
|
// systemd unit generous enough to cover the slowest entry.
|
||||||
|
startup::load_default_models(®istry, &cfg.default_models).await;
|
||||||
|
|
||||||
let health_cache = Arc::new(health::HealthCache::new());
|
let health_cache = Arc::new(health::HealthCache::new());
|
||||||
health_cache
|
health_cache
|
||||||
@@ -65,13 +193,24 @@ async fn main() -> Result<()> {
|
|||||||
discovery: discovery_result,
|
discovery: discovery_result,
|
||||||
health_cache,
|
health_cache,
|
||||||
registry: RwLock::new(registry),
|
registry: RwLock::new(registry),
|
||||||
|
candle,
|
||||||
});
|
});
|
||||||
|
|
||||||
let app = api::neuron_routes().with_state(state);
|
let app = api::neuron_routes().with_state(Arc::clone(&state));
|
||||||
let addr: std::net::SocketAddr = format!("0.0.0.0:{port}").parse()?;
|
let addr: std::net::SocketAddr = format!("0.0.0.0:{port}").parse()?;
|
||||||
tracing::info!("neuron listening on {addr}");
|
tracing::info!("neuron listening on {addr}");
|
||||||
let listener = tokio::net::TcpListener::bind(addr).await?;
|
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||||
axum::serve(listener, app).await?;
|
axum::serve(listener, app)
|
||||||
|
.with_graceful_shutdown(startup::shutdown_signal())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// Deactivation: serve has returned (graceful shutdown signal
|
||||||
|
// received and connections drained). Release CUDA contexts / VRAM
|
||||||
|
// by unloading every model before exiting; systemd's TimeoutStopSec
|
||||||
|
// bounds how long this phase may take.
|
||||||
|
let registry = state.registry.read().await;
|
||||||
|
startup::unload_all_models(®istry).await;
|
||||||
|
tracing::info!("shutdown complete");
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
97
crates/neuron/src/startup.rs
Normal file
97
crates/neuron/src/startup.rs
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
//! Activation- and deactivation-time orchestration.
|
||||||
|
//!
|
||||||
|
//! Wired from `main.rs` around the HTTP listener — activation runs
|
||||||
|
//! before bind, deactivation runs after axum returns from its
|
||||||
|
//! graceful-shutdown future. Kept in its own module so the logic is
|
||||||
|
//! unit-testable without spinning up a full neuron process.
|
||||||
|
|
||||||
|
use crate::harness::HarnessRegistry;
|
||||||
|
use cortex_core::harness::ModelSpec;
|
||||||
|
use std::time::Instant;
|
||||||
|
use tokio::signal;
|
||||||
|
|
||||||
|
/// Load each spec sequentially against the registry, treating
|
||||||
|
/// individual failures as warnings rather than fatal errors.
|
||||||
|
///
|
||||||
|
/// VRAM contention makes parallel loads risky; the sequential path is
|
||||||
|
/// boring but correct. The function logs elapsed time per load so an
|
||||||
|
/// operator can see which model is hogging activation.
|
||||||
|
pub async fn load_default_models(registry: &HarnessRegistry, specs: &[ModelSpec]) {
|
||||||
|
if specs.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
tracing::info!(count = specs.len(), "loading default models");
|
||||||
|
for spec in specs {
|
||||||
|
let start = Instant::now();
|
||||||
|
match registry.load_model(spec).await {
|
||||||
|
Ok(()) => tracing::info!(
|
||||||
|
model = %spec.model_id,
|
||||||
|
elapsed_ms = start.elapsed().as_millis() as u64,
|
||||||
|
"loaded default model"
|
||||||
|
),
|
||||||
|
Err(e) => tracing::warn!(
|
||||||
|
model = %spec.model_id,
|
||||||
|
error = %e,
|
||||||
|
elapsed_ms = start.elapsed().as_millis() as u64,
|
||||||
|
"failed to load default model, continuing"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Future that resolves on SIGINT (Ctrl-C) or SIGTERM (systemd stop).
|
||||||
|
///
|
||||||
|
/// Wired into `axum::serve(...).with_graceful_shutdown(shutdown_signal())`
|
||||||
|
/// so the HTTP listener stops accepting new connections, lets in-flight
|
||||||
|
/// requests drain, and then yields control back to main for cleanup.
|
||||||
|
pub async fn shutdown_signal() {
|
||||||
|
let ctrl_c = async {
|
||||||
|
signal::ctrl_c().await.ok();
|
||||||
|
};
|
||||||
|
let terminate = async {
|
||||||
|
signal::unix::signal(signal::unix::SignalKind::terminate())
|
||||||
|
.expect("install SIGTERM handler")
|
||||||
|
.recv()
|
||||||
|
.await;
|
||||||
|
};
|
||||||
|
tokio::select! {
|
||||||
|
_ = ctrl_c => tracing::info!("received SIGINT, shutting down"),
|
||||||
|
_ = terminate => tracing::info!("received SIGTERM, shutting down"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Unload every model currently registered. Called from `main.rs` after
|
||||||
|
/// axum's graceful shutdown future resolves, so CUDA contexts and VRAM
|
||||||
|
/// are released before the process exits rather than left to the OS to
|
||||||
|
/// reclaim. Per-model failures are logged and skipped — keep cleanup
|
||||||
|
/// going even when one harness is unhealthy.
|
||||||
|
pub async fn unload_all_models(registry: &HarnessRegistry) {
|
||||||
|
let listed = match registry.list_all_models().await {
|
||||||
|
Ok(m) => m,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(error = %e, "failed to list models during shutdown");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if listed.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!(count = listed.len(), "unloading models for shutdown");
|
||||||
|
for model in listed {
|
||||||
|
let start = Instant::now();
|
||||||
|
match registry.unload_model(&model.id).await {
|
||||||
|
Ok(()) => tracing::info!(
|
||||||
|
model = %model.id,
|
||||||
|
elapsed_ms = start.elapsed().as_millis() as u64,
|
||||||
|
"unloaded"
|
||||||
|
),
|
||||||
|
Err(e) => tracing::warn!(
|
||||||
|
model = %model.id,
|
||||||
|
error = %e,
|
||||||
|
"unload failed during shutdown"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
56
crates/neuron/tests/activation.rs
Normal file
56
crates/neuron/tests/activation.rs
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
//! Activation-time behaviour: load_default_models continues past
|
||||||
|
//! individual failures so a single broken catalogue entry doesn't
|
||||||
|
//! prevent the rest of the fleet from starting.
|
||||||
|
|
||||||
|
use cortex_core::harness::{HarnessConfig, ModelSpec};
|
||||||
|
use neuron::config::HarnessSettings;
|
||||||
|
use neuron::harness::HarnessRegistry;
|
||||||
|
use neuron::startup;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_load_default_models_skips_unknown_harness() {
|
||||||
|
let registry = HarnessRegistry::from_configs(
|
||||||
|
&[HarnessConfig {
|
||||||
|
name: "candle".into(),
|
||||||
|
}],
|
||||||
|
"http://localhost:0",
|
||||||
|
&HarnessSettings::default(),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Both entries fail synchronously inside the registry — no network
|
||||||
|
// call escapes (the harness lookup mismatches before hf-hub is
|
||||||
|
// touched). The function should still return cleanly.
|
||||||
|
let specs = vec![
|
||||||
|
ModelSpec {
|
||||||
|
model_id: "model-a".into(),
|
||||||
|
harness: "no-such-harness".into(),
|
||||||
|
quant: None,
|
||||||
|
tensor_parallel: None,
|
||||||
|
devices: None,
|
||||||
|
},
|
||||||
|
ModelSpec {
|
||||||
|
model_id: "model-b".into(),
|
||||||
|
harness: "no-such-harness".into(),
|
||||||
|
quant: None,
|
||||||
|
tensor_parallel: None,
|
||||||
|
devices: None,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
startup::load_default_models(®istry, &specs).await;
|
||||||
|
|
||||||
|
let listed = registry
|
||||||
|
.list_all_models()
|
||||||
|
.await
|
||||||
|
.expect("list_all_models should succeed");
|
||||||
|
assert!(
|
||||||
|
listed.is_empty(),
|
||||||
|
"no models should be loaded after failed entries"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_load_default_models_empty_is_noop() {
|
||||||
|
let registry = HarnessRegistry::new();
|
||||||
|
startup::load_default_models(®istry, &[]).await;
|
||||||
|
}
|
||||||
@@ -14,6 +14,7 @@ async fn spawn_neuron(discovery: DiscoveryResponse) -> String {
|
|||||||
discovery,
|
discovery,
|
||||||
health_cache,
|
health_cache,
|
||||||
registry: RwLock::new(registry),
|
registry: RwLock::new(registry),
|
||||||
|
candle: None,
|
||||||
});
|
});
|
||||||
|
|
||||||
let app = api::neuron_routes().with_state(state);
|
let app = api::neuron_routes().with_state(state);
|
||||||
@@ -135,56 +136,30 @@ async fn test_models_empty_registry() {
|
|||||||
assert!(body.as_array().unwrap().is_empty());
|
assert!(body.as_array().unwrap().is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Spawn a mock mistral.rs backend and a neuron with the mistralrs harness
|
/// Verify the candle harness registers, list is empty by default, and a
|
||||||
/// pointing at it, then test the full model lifecycle through neuron's API.
|
/// load attempt for an obviously-bogus model id returns a 4xx error
|
||||||
|
/// without crashing the daemon. Real load/unload exercising actual GGUF
|
||||||
|
/// download is covered by `tests/candle_lifecycle.rs` (cuda-integration).
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_models_via_mistralrs_harness() {
|
async fn test_candle_harness_registers_and_rejects_bogus_model() {
|
||||||
use axum::routing::{get, post};
|
|
||||||
use axum::{Json, Router};
|
|
||||||
use cortex_core::harness::HarnessConfig;
|
use cortex_core::harness::HarnessConfig;
|
||||||
use serde_json::Value;
|
use neuron::config::HarnessSettings;
|
||||||
|
|
||||||
// Mock mistral.rs backend.
|
let registry = HarnessRegistry::from_configs(
|
||||||
let mock_app = Router::new()
|
&[HarnessConfig {
|
||||||
.route(
|
name: "candle".into(),
|
||||||
"/v1/models",
|
}],
|
||||||
get(|| async {
|
"http://localhost:13131",
|
||||||
Json(json!({
|
&HarnessSettings::default(),
|
||||||
"data": [
|
|
||||||
{"id": "test-model", "status": "loaded"},
|
|
||||||
{"id": "other-model", "status": "unloaded"}
|
|
||||||
]
|
|
||||||
}))
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
.route(
|
|
||||||
"/v1/models/unload",
|
|
||||||
post(|Json(_body): Json<Value>| async { Json(json!({"status": "ok"})) }),
|
|
||||||
)
|
|
||||||
.route(
|
|
||||||
"/v1/models/reload",
|
|
||||||
post(|Json(_body): Json<Value>| async { Json(json!({"status": "ok"})) }),
|
|
||||||
);
|
);
|
||||||
|
|
||||||
let mock_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
let candle = registry.candle();
|
||||||
let mock_addr = mock_listener.local_addr().unwrap();
|
|
||||||
tokio::spawn(async move {
|
|
||||||
axum::serve(mock_listener, mock_app).await.unwrap();
|
|
||||||
});
|
|
||||||
let mock_url = format!("http://{mock_addr}");
|
|
||||||
|
|
||||||
// Build neuron with mistralrs harness pointing at mock.
|
|
||||||
let registry = HarnessRegistry::from_configs(&[HarnessConfig {
|
|
||||||
name: "mistralrs".into(),
|
|
||||||
endpoint: Some(mock_url.clone()),
|
|
||||||
systemd_unit: None,
|
|
||||||
}]);
|
|
||||||
|
|
||||||
let health_cache = Arc::new(HealthCache::new());
|
let health_cache = Arc::new(HealthCache::new());
|
||||||
let state = Arc::new(NeuronState {
|
let state = Arc::new(NeuronState {
|
||||||
discovery: fake_discovery(),
|
discovery: fake_discovery(),
|
||||||
health_cache,
|
health_cache,
|
||||||
registry: RwLock::new(registry),
|
registry: RwLock::new(registry),
|
||||||
|
candle,
|
||||||
});
|
});
|
||||||
|
|
||||||
let app = api::neuron_routes().with_state(state);
|
let app = api::neuron_routes().with_state(state);
|
||||||
@@ -197,7 +172,6 @@ async fn test_models_via_mistralrs_harness() {
|
|||||||
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
|
|
||||||
// GET /models — should return models from mock mistralrs.
|
|
||||||
let resp = client
|
let resp = client
|
||||||
.get(format!("{neuron_url}/models"))
|
.get(format!("{neuron_url}/models"))
|
||||||
.send()
|
.send()
|
||||||
@@ -205,45 +179,140 @@ async fn test_models_via_mistralrs_harness() {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(resp.status(), 200);
|
assert_eq!(resp.status(), 200);
|
||||||
let models: Vec<serde_json::Value> = resp.json().await.unwrap();
|
let models: Vec<serde_json::Value> = resp.json().await.unwrap();
|
||||||
assert_eq!(models.len(), 2);
|
assert!(models.is_empty());
|
||||||
assert_eq!(models[0]["id"], "test-model");
|
|
||||||
assert_eq!(models[0]["harness"], "mistralrs");
|
|
||||||
assert_eq!(models[0]["status"], "loaded");
|
|
||||||
assert_eq!(models[1]["id"], "other-model");
|
|
||||||
assert_eq!(models[1]["status"], "unloaded");
|
|
||||||
|
|
||||||
// GET /models/test-model/endpoint — should return mock URL.
|
// Sending a wrong-harness spec should be rejected synchronously
|
||||||
let resp = client
|
// without touching the network or the model registry.
|
||||||
.get(format!("{neuron_url}/models/test-model/endpoint"))
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert_eq!(resp.status(), 200);
|
|
||||||
let body: serde_json::Value = resp.json().await.unwrap();
|
|
||||||
assert_eq!(body["url"], mock_url);
|
|
||||||
|
|
||||||
// POST /models/unload — should succeed.
|
|
||||||
let resp = client
|
|
||||||
.post(format!("{neuron_url}/models/unload"))
|
|
||||||
.json(&json!({"model_id": "test-model"}))
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert_eq!(resp.status(), 200);
|
|
||||||
let body: serde_json::Value = resp.json().await.unwrap();
|
|
||||||
assert_eq!(body["status"], "unloaded");
|
|
||||||
|
|
||||||
// POST /models/load — should succeed.
|
|
||||||
let resp = client
|
let resp = client
|
||||||
.post(format!("{neuron_url}/models/load"))
|
.post(format!("{neuron_url}/models/load"))
|
||||||
|
.json(&json!({"model_id": "definitely/not-real", "harness": "not-candle"}))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(resp.status(), 400);
|
||||||
|
|
||||||
|
// Registry still empty.
|
||||||
|
let resp = client
|
||||||
|
.get(format!("{neuron_url}/models"))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let models: Vec<serde_json::Value> = resp.json().await.unwrap();
|
||||||
|
assert!(models.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `/v1/chat/completions` returns 503 when no candle harness is registered.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_chat_completions_no_candle_harness() {
|
||||||
|
let registry = HarnessRegistry::new();
|
||||||
|
let health_cache = Arc::new(HealthCache::new());
|
||||||
|
let state = Arc::new(NeuronState {
|
||||||
|
discovery: fake_discovery(),
|
||||||
|
health_cache,
|
||||||
|
registry: RwLock::new(registry),
|
||||||
|
candle: None,
|
||||||
|
});
|
||||||
|
let app = api::neuron_routes().with_state(state);
|
||||||
|
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
axum::serve(listener, app).await.unwrap();
|
||||||
|
});
|
||||||
|
let url = format!("http://{addr}");
|
||||||
|
|
||||||
|
let resp = reqwest::Client::new()
|
||||||
|
.post(format!("{url}/v1/chat/completions"))
|
||||||
.json(&json!({
|
.json(&json!({
|
||||||
"model_id": "test-model",
|
"model": "anything",
|
||||||
"harness": "mistralrs"
|
"messages": [{"role": "user", "content": "hi"}]
|
||||||
}))
|
}))
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(resp.status(), 200);
|
assert_eq!(resp.status(), 503);
|
||||||
let body: serde_json::Value = resp.json().await.unwrap();
|
}
|
||||||
assert_eq!(body["status"], "loaded");
|
|
||||||
|
/// `/v1/chat/completions` returns 404 when the requested model isn't loaded.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_chat_completions_model_not_loaded() {
|
||||||
|
use cortex_core::harness::HarnessConfig;
|
||||||
|
use neuron::config::HarnessSettings;
|
||||||
|
|
||||||
|
let registry = HarnessRegistry::from_configs(
|
||||||
|
&[HarnessConfig {
|
||||||
|
name: "candle".into(),
|
||||||
|
}],
|
||||||
|
"http://localhost:0",
|
||||||
|
&HarnessSettings::default(),
|
||||||
|
);
|
||||||
|
let candle = registry.candle();
|
||||||
|
let health_cache = Arc::new(HealthCache::new());
|
||||||
|
let state = Arc::new(NeuronState {
|
||||||
|
discovery: fake_discovery(),
|
||||||
|
health_cache,
|
||||||
|
registry: RwLock::new(registry),
|
||||||
|
candle,
|
||||||
|
});
|
||||||
|
let app = api::neuron_routes().with_state(state);
|
||||||
|
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
axum::serve(listener, app).await.unwrap();
|
||||||
|
});
|
||||||
|
let url = format!("http://{addr}");
|
||||||
|
|
||||||
|
let resp = reqwest::Client::new()
|
||||||
|
.post(format!("{url}/v1/chat/completions"))
|
||||||
|
.json(&json!({
|
||||||
|
"model": "definitely/not-loaded",
|
||||||
|
"messages": [{"role": "user", "content": "hi"}]
|
||||||
|
}))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(resp.status(), 404);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `/v1/chat/completions` with `stream: true` returns 404 when the
|
||||||
|
/// model isn't loaded — same surface as the non-streaming path. The
|
||||||
|
/// streaming code only kicks in once the model lookup succeeds.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_chat_completions_streaming_model_not_loaded() {
|
||||||
|
use cortex_core::harness::HarnessConfig;
|
||||||
|
use neuron::config::HarnessSettings;
|
||||||
|
|
||||||
|
let registry = HarnessRegistry::from_configs(
|
||||||
|
&[HarnessConfig {
|
||||||
|
name: "candle".into(),
|
||||||
|
}],
|
||||||
|
"http://localhost:0",
|
||||||
|
&HarnessSettings::default(),
|
||||||
|
);
|
||||||
|
let candle = registry.candle();
|
||||||
|
let health_cache = Arc::new(HealthCache::new());
|
||||||
|
let state = Arc::new(NeuronState {
|
||||||
|
discovery: fake_discovery(),
|
||||||
|
health_cache,
|
||||||
|
registry: RwLock::new(registry),
|
||||||
|
candle,
|
||||||
|
});
|
||||||
|
let app = api::neuron_routes().with_state(state);
|
||||||
|
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
axum::serve(listener, app).await.unwrap();
|
||||||
|
});
|
||||||
|
let url = format!("http://{addr}");
|
||||||
|
|
||||||
|
let resp = reqwest::Client::new()
|
||||||
|
.post(format!("{url}/v1/chat/completions"))
|
||||||
|
.json(&json!({
|
||||||
|
"model": "definitely/not-loaded",
|
||||||
|
"messages": [{"role": "user", "content": "hi"}],
|
||||||
|
"stream": true
|
||||||
|
}))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(resp.status(), 404);
|
||||||
}
|
}
|
||||||
|
|||||||
87
crates/neuron/tests/candle_lifecycle.rs
Normal file
87
crates/neuron/tests/candle_lifecycle.rs
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
//! Real model load/unload lifecycle through the candle harness.
|
||||||
|
//!
|
||||||
|
//! Gated behind the `cuda-integration` feature because it downloads a
|
||||||
|
//! real (small) GGUF from HuggingFace and materialises tensors on the
|
||||||
|
//! configured device. Run on a host with network access and either a
|
||||||
|
//! CUDA GPU (when built with `--features cuda`) or enough CPU RAM to
|
||||||
|
//! hold the model.
|
||||||
|
//!
|
||||||
|
//! Usage:
|
||||||
|
//! cargo test -p neuron --features cuda-integration --test candle_lifecycle
|
||||||
|
//!
|
||||||
|
//! Optional environment variables:
|
||||||
|
//! NEURON_TEST_MODEL_ID — HuggingFace repo to load (default: a small
|
||||||
|
//! public Qwen3 GGUF repo).
|
||||||
|
//! NEURON_TEST_QUANT — quant substring matched against GGUF
|
||||||
|
//! filenames (default: "Q4_K_M").
|
||||||
|
//! HF_HOME — HuggingFace cache directory.
|
||||||
|
|
||||||
|
#![cfg(feature = "cuda-integration")]
|
||||||
|
|
||||||
|
use cortex_core::harness::{HarnessConfig, ModelSpec};
|
||||||
|
use neuron::config::HarnessSettings;
|
||||||
|
use neuron::harness::HarnessRegistry;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_candle_qwen3_load_unload_lifecycle() {
|
||||||
|
let _ = tracing_subscriber::fmt()
|
||||||
|
.with_test_writer()
|
||||||
|
.with_env_filter("info,neuron=debug")
|
||||||
|
.try_init();
|
||||||
|
|
||||||
|
let model_id = std::env::var("NEURON_TEST_MODEL_ID")
|
||||||
|
.unwrap_or_else(|_| "Qwen/Qwen3-0.6B-GGUF".to_string());
|
||||||
|
let quant = std::env::var("NEURON_TEST_QUANT").unwrap_or_else(|_| "Q4_K_M".to_string());
|
||||||
|
|
||||||
|
let mut settings = HarnessSettings::default();
|
||||||
|
if let Ok(home) = std::env::var("HF_HOME") {
|
||||||
|
settings.candle.hf_cache = Some(PathBuf::from(home));
|
||||||
|
}
|
||||||
|
|
||||||
|
let registry = HarnessRegistry::from_configs(
|
||||||
|
&[HarnessConfig {
|
||||||
|
name: "candle".into(),
|
||||||
|
}],
|
||||||
|
"http://localhost:13131",
|
||||||
|
&settings,
|
||||||
|
);
|
||||||
|
|
||||||
|
let spec = ModelSpec {
|
||||||
|
model_id: model_id.clone(),
|
||||||
|
harness: "candle".into(),
|
||||||
|
quant: Some(quant),
|
||||||
|
tensor_parallel: None,
|
||||||
|
devices: Some(vec![0]),
|
||||||
|
};
|
||||||
|
|
||||||
|
registry
|
||||||
|
.load_model(&spec)
|
||||||
|
.await
|
||||||
|
.expect("load_model should succeed");
|
||||||
|
|
||||||
|
let models = registry.list_all_models().await.expect("list_all_models");
|
||||||
|
assert_eq!(models.len(), 1, "expected exactly one loaded model");
|
||||||
|
assert_eq!(models[0].id, model_id);
|
||||||
|
assert_eq!(models[0].harness, "candle");
|
||||||
|
assert_eq!(models[0].status, "loaded");
|
||||||
|
|
||||||
|
let url = registry.inference_endpoint(&model_id).await;
|
||||||
|
assert_eq!(url, Some("http://localhost:13131".into()));
|
||||||
|
|
||||||
|
// Re-loading the same model should be rejected.
|
||||||
|
let again = registry.load_model(&spec).await;
|
||||||
|
assert!(again.is_err(), "second load should error");
|
||||||
|
|
||||||
|
registry
|
||||||
|
.unload_model(&model_id)
|
||||||
|
.await
|
||||||
|
.expect("unload_model should succeed");
|
||||||
|
|
||||||
|
let models = registry.list_all_models().await.expect("list_all_models");
|
||||||
|
assert!(models.is_empty(), "registry should be empty after unload");
|
||||||
|
|
||||||
|
// Unloading a model that isn't loaded should error.
|
||||||
|
let err = registry.unload_model(&model_id).await;
|
||||||
|
assert!(err.is_err(), "unload of missing model should error");
|
||||||
|
}
|
||||||
32
crates/neuron/tests/shutdown.rs
Normal file
32
crates/neuron/tests/shutdown.rs
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
//! Deactivation behaviour: unload_all_models tolerates an empty
|
||||||
|
//! registry and continues past per-model unload failures.
|
||||||
|
|
||||||
|
use cortex_core::harness::HarnessConfig;
|
||||||
|
use neuron::config::HarnessSettings;
|
||||||
|
use neuron::harness::HarnessRegistry;
|
||||||
|
use neuron::startup;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_unload_all_models_empty_registry_is_noop() {
|
||||||
|
let registry = HarnessRegistry::new();
|
||||||
|
startup::unload_all_models(®istry).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_unload_all_models_with_no_loaded_models() {
|
||||||
|
let registry = HarnessRegistry::from_configs(
|
||||||
|
&[HarnessConfig {
|
||||||
|
name: "candle".into(),
|
||||||
|
}],
|
||||||
|
"http://localhost:0",
|
||||||
|
&HarnessSettings::default(),
|
||||||
|
);
|
||||||
|
|
||||||
|
startup::unload_all_models(®istry).await;
|
||||||
|
|
||||||
|
let listed = registry
|
||||||
|
.list_all_models()
|
||||||
|
.await
|
||||||
|
.expect("list_all_models should still succeed after shutdown cleanup");
|
||||||
|
assert!(listed.is_empty());
|
||||||
|
}
|
||||||
145
crates/neuron/tests/tp_worker_lifecycle.rs
Normal file
145
crates/neuron/tests/tp_worker_lifecycle.rs
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
//! Stage 7a-i: confirm the TP worker subprocess lifecycle round-trips.
|
||||||
|
//!
|
||||||
|
//! Spawns two worker subprocesses via the leader→worker stdio RPC,
|
||||||
|
//! pings each, and cleanly shuts them down. No CUDA required —
|
||||||
|
//! `Init` and `NcclSanityCheck` are stubbed in 7a-i, so this test
|
||||||
|
//! runs on any host the workspace builds on.
|
||||||
|
|
||||||
|
use neuron::harness::tp::{WorkerPool, rpc::WorkerResponse};
|
||||||
|
|
||||||
|
/// Path to the neuron binary built by cargo for this test process.
|
||||||
|
/// cargo populates `CARGO_BIN_EXE_neuron` at compile time for sibling-
|
||||||
|
/// binary tests; production paths in main.rs use `/proc/self/exe`.
|
||||||
|
const NEURON_BIN: &str = env!("CARGO_BIN_EXE_neuron");
|
||||||
|
|
||||||
|
/// Two workers (so we spawn one subprocess: rank 0 is in-process,
|
||||||
|
/// rank 1 is the child). Verify the spawned worker responds to Ping
|
||||||
|
/// with its own identity, then shut it down cleanly.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_spawn_ping_shutdown() {
|
||||||
|
// cuda_devices: rank 0 → device 0 (leader, unused here),
|
||||||
|
// rank 1 → device 1 (worker; not actually opened in 7a-i).
|
||||||
|
let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 2, &[0, 1])
|
||||||
|
.await
|
||||||
|
.expect("spawn worker pool");
|
||||||
|
|
||||||
|
let pongs = pool.ping_all().await.expect("ping all workers");
|
||||||
|
assert_eq!(pongs.len(), 1, "expected one Pong (rank 1 only)");
|
||||||
|
match &pongs[0] {
|
||||||
|
WorkerResponse::Pong {
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
cuda_device,
|
||||||
|
} => {
|
||||||
|
assert_eq!(*rank, 1);
|
||||||
|
assert_eq!(*world_size, 2);
|
||||||
|
assert_eq!(*cuda_device, 1);
|
||||||
|
}
|
||||||
|
other => panic!("expected Pong, got {other:?}"),
|
||||||
|
}
|
||||||
|
|
||||||
|
pool.shutdown().await.expect("clean shutdown");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Three workers — exercise the loop in `ping_all` / `shutdown`.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_spawn_three_workers() {
|
||||||
|
let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 3, &[0, 1, 2])
|
||||||
|
.await
|
||||||
|
.expect("spawn worker pool");
|
||||||
|
|
||||||
|
let pongs = pool.ping_all().await.expect("ping all workers");
|
||||||
|
assert_eq!(pongs.len(), 2, "expected two Pongs (ranks 1 and 2)");
|
||||||
|
for (i, resp) in pongs.iter().enumerate() {
|
||||||
|
match resp {
|
||||||
|
WorkerResponse::Pong {
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
cuda_device,
|
||||||
|
} => {
|
||||||
|
let expected_rank = (i + 1) as u32;
|
||||||
|
assert_eq!(*rank, expected_rank);
|
||||||
|
assert_eq!(*world_size, 3);
|
||||||
|
assert_eq!(*cuda_device, expected_rank);
|
||||||
|
}
|
||||||
|
other => panic!("expected Pong, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pool.shutdown().await.expect("clean shutdown");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 7a-ii: without the cuda feature, Init must fail with a clear
|
||||||
|
/// `cuda_feature_not_enabled` marker rather than silently succeeding.
|
||||||
|
/// This is the local-dev-box test; the real NCCL handshake is exercised
|
||||||
|
/// by `tp_worker_lifecycle_cuda.rs` (gated on `cuda-integration`).
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_init_returns_cuda_feature_not_enabled_without_cuda() {
|
||||||
|
use neuron::harness::tp::rpc::WorkerRequest;
|
||||||
|
use std::process::Stdio;
|
||||||
|
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||||
|
use tokio::process::Command;
|
||||||
|
|
||||||
|
// Spawn a single worker by hand to send Init directly (the pool's
|
||||||
|
// public API doesn't expose Init yet — that lands in 7a-ii).
|
||||||
|
let mut child = Command::new(NEURON_BIN)
|
||||||
|
.arg("--worker")
|
||||||
|
.arg("--rank")
|
||||||
|
.arg("1")
|
||||||
|
.arg("--tp-size")
|
||||||
|
.arg("2")
|
||||||
|
.arg("--cuda-device")
|
||||||
|
.arg("1")
|
||||||
|
.stdin(Stdio::piped())
|
||||||
|
.stdout(Stdio::piped())
|
||||||
|
.stderr(Stdio::null())
|
||||||
|
.kill_on_drop(true)
|
||||||
|
.spawn()
|
||||||
|
.expect("spawn worker");
|
||||||
|
|
||||||
|
let mut stdin = child.stdin.take().expect("stdin");
|
||||||
|
let stdout = child.stdout.take().expect("stdout");
|
||||||
|
let mut lines = BufReader::new(stdout).lines();
|
||||||
|
|
||||||
|
let req = WorkerRequest::Init {
|
||||||
|
comm_id: "ff".repeat(128),
|
||||||
|
};
|
||||||
|
let mut payload = serde_json::to_string(&req).unwrap();
|
||||||
|
payload.push('\n');
|
||||||
|
stdin.write_all(payload.as_bytes()).await.unwrap();
|
||||||
|
stdin.flush().await.unwrap();
|
||||||
|
|
||||||
|
let reply = lines
|
||||||
|
.next_line()
|
||||||
|
.await
|
||||||
|
.expect("read line")
|
||||||
|
.expect("got line");
|
||||||
|
let resp: WorkerResponse = serde_json::from_str(&reply).expect("parse reply");
|
||||||
|
match resp {
|
||||||
|
WorkerResponse::Error { kind, .. } => {
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
{
|
||||||
|
// With cuda enabled the response depends on whether
|
||||||
|
// CUDA hardware is actually present. Accept either
|
||||||
|
// the success contract or a real NCCL failure.
|
||||||
|
let _ = kind;
|
||||||
|
}
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
assert_eq!(kind, "cuda_feature_not_enabled");
|
||||||
|
}
|
||||||
|
WorkerResponse::InitOk => {
|
||||||
|
// Real NCCL succeeded — only possible with cuda feature
|
||||||
|
// AND a working NCCL stack AND another rank actually
|
||||||
|
// joining. Don't fail; just acknowledge.
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
panic!("InitOk without cuda feature is impossible");
|
||||||
|
}
|
||||||
|
other => panic!("expected Error or InitOk, got {other:?}"),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean shutdown.
|
||||||
|
stdin.write_all(b"{\"op\":\"shutdown\"}\n").await.unwrap();
|
||||||
|
stdin.flush().await.unwrap();
|
||||||
|
let _ = lines.next_line().await; // Bye
|
||||||
|
let _ = child.wait().await;
|
||||||
|
}
|
||||||
43
crates/neuron/tests/tp_worker_lifecycle_cuda.rs
Normal file
43
crates/neuron/tests/tp_worker_lifecycle_cuda.rs
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
//! Stage 7a-ii: real NCCL handshake across the worker pool.
|
||||||
|
//!
|
||||||
|
//! Gated behind the `cuda-integration` feature because it requires
|
||||||
|
//! libnccl AND multiple CUDA devices on the running host. Run on
|
||||||
|
//! beast (2× RTX 5090) via:
|
||||||
|
//!
|
||||||
|
//! cargo test -p neuron --features cuda-integration \
|
||||||
|
//! --test tp_worker_lifecycle_cuda
|
||||||
|
//!
|
||||||
|
//! Steps: spawn N-1 workers, call `init_nccl`, run `nccl_sanity_check`
|
||||||
|
//! (every rank `all_reduce`s `1u32` with Sum; expected total =
|
||||||
|
//! world_size), shut down cleanly.
|
||||||
|
|
||||||
|
#![cfg(feature = "cuda-integration")]
|
||||||
|
|
||||||
|
use neuron::harness::tp::WorkerPool;
|
||||||
|
|
||||||
|
const NEURON_BIN: &str = env!("CARGO_BIN_EXE_neuron");
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_init_and_sanity_check_two_ranks() {
|
||||||
|
let _ = tracing_subscriber::fmt()
|
||||||
|
.with_test_writer()
|
||||||
|
.with_env_filter("info,neuron=debug")
|
||||||
|
.try_init();
|
||||||
|
|
||||||
|
// 2 ranks: leader = rank 0 on device 0, worker = rank 1 on device 1.
|
||||||
|
let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 2, &[0, 1])
|
||||||
|
.await
|
||||||
|
.expect("spawn worker pool");
|
||||||
|
|
||||||
|
pool.ping_all().await.expect("pong all workers");
|
||||||
|
|
||||||
|
pool.init_nccl(0)
|
||||||
|
.await
|
||||||
|
.expect("init_nccl: NCCL handshake across all ranks");
|
||||||
|
|
||||||
|
pool.nccl_sanity_check()
|
||||||
|
.await
|
||||||
|
.expect("nccl_sanity_check: observed_sum == world_size on all ranks");
|
||||||
|
|
||||||
|
pool.shutdown().await.expect("clean shutdown");
|
||||||
|
}
|
||||||
7
data/cortex-firewalld.xml
Normal file
7
data/cortex-firewalld.xml
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
<service>
|
||||||
|
<short>cortex</short>
|
||||||
|
<description>Cortex — inference gateway for multi-node GPU clusters</description>
|
||||||
|
<port protocol="tcp" port="31313"/>
|
||||||
|
<port protocol="tcp" port="31314"/>
|
||||||
|
</service>
|
||||||
6
data/neuron-firewalld.xml
Normal file
6
data/neuron-firewalld.xml
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
<service>
|
||||||
|
<short>helexa-neuron</short>
|
||||||
|
<description>Neuron — per-node GPU discovery and harness daemon for cortex</description>
|
||||||
|
<port protocol="tcp" port="13131"/>
|
||||||
|
</service>
|
||||||
3
data/neuron-sysusers.conf
Normal file
3
data/neuron-sysusers.conf
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
g neuron - -
|
||||||
|
u neuron - "Neuron GPU node daemon" /var/lib/neuron /sbin/nologin
|
||||||
|
m neuron neuron
|
||||||
@@ -5,11 +5,27 @@ Wants=network-online.target
|
|||||||
|
|
||||||
[Service]
|
[Service]
|
||||||
Type=simple
|
Type=simple
|
||||||
ExecStart=/usr/bin/neuron --config /etc/cortex/neuron.toml
|
ExecStart=/usr/bin/neuron --config /etc/neuron/neuron.toml
|
||||||
Restart=on-failure
|
Restart=on-failure
|
||||||
RestartSec=5
|
RestartSec=5
|
||||||
User=cortex
|
User=neuron
|
||||||
Group=cortex
|
Group=neuron
|
||||||
|
# /var/lib/neuron is the neuron user's $HOME — hf-hub writes its
|
||||||
|
# default cache there (~/.cache/huggingface/hub). Without this directive
|
||||||
|
# systemd doesn't create the directory and hf-hub downloads fail with
|
||||||
|
# "fetch GGUF <file>: failed to create cache dir".
|
||||||
|
StateDirectory=neuron
|
||||||
|
StateDirectoryMode=0755
|
||||||
|
# Loading default_models from neuron.toml happens before the HTTP
|
||||||
|
# listener binds; large models can take many minutes to download and
|
||||||
|
# materialise on first activation. systemd's default TimeoutStartSec
|
||||||
|
# (90s) is far too short; allow 30 minutes.
|
||||||
|
TimeoutStartSec=1800s
|
||||||
|
# On stop, neuron drains in-flight requests then unloads every model
|
||||||
|
# to release CUDA contexts cleanly. Allow generous time for big-model
|
||||||
|
# unloads; systemd will SIGKILL after this bound.
|
||||||
|
TimeoutStopSec=120s
|
||||||
|
KillSignal=SIGTERM
|
||||||
|
|
||||||
[Install]
|
[Install]
|
||||||
WantedBy=multi-user.target
|
WantedBy=multi-user.target
|
||||||
|
|||||||
101
helexa-neuron.spec
Normal file
101
helexa-neuron.spec
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
Name: helexa-neuron
|
||||||
|
Version: 0.1.16
|
||||||
|
Release: 1%{?dist}
|
||||||
|
Summary: Per-node GPU discovery and harness management daemon for cortex
|
||||||
|
# Package name disambiguates from Fedora's existing "neuron" package
|
||||||
|
# (NEURON neural simulation environment from Yale). Binary, systemd
|
||||||
|
# unit, and system user are still called "neuron" for brevity.
|
||||||
|
|
||||||
|
License: GPL-3.0-or-later
|
||||||
|
URL: https://git.lair.cafe/helexa/cortex
|
||||||
|
Source0: %{name}-%{version}.tar.gz
|
||||||
|
Source1: %{name}-%{version}-vendor.tar.gz
|
||||||
|
|
||||||
|
ExclusiveArch: x86_64
|
||||||
|
|
||||||
|
BuildRequires: rust >= 1.85
|
||||||
|
BuildRequires: cargo
|
||||||
|
BuildRequires: gcc
|
||||||
|
BuildRequires: gcc-c++
|
||||||
|
BuildRequires: cmake
|
||||||
|
BuildRequires: perl-interpreter
|
||||||
|
BuildRequires: pkgconfig(openssl)
|
||||||
|
BuildRequires: systemd-rpm-macros
|
||||||
|
|
||||||
|
Requires(pre): shadow-utils
|
||||||
|
Requires: systemd
|
||||||
|
Requires: firewalld-filesystem
|
||||||
|
|
||||||
|
# systemd-rpm-macros ships a unit dep generator that parses User=/Group=
|
||||||
|
# from our .service file and emits Requires: user(neuron)/group(neuron).
|
||||||
|
# rpm's sysusers provides-generator emits the unversioned form for groups
|
||||||
|
# but only a versioned user(neuron) = <base64> for users with GECOS/home/
|
||||||
|
# shell. Provide the unversioned user(neuron) explicitly so dnf can resolve
|
||||||
|
# the auto-generated Requires. Without this, dnf5 silently filters the
|
||||||
|
# package and reports "Nothing to do".
|
||||||
|
Provides: user(neuron)
|
||||||
|
|
||||||
|
%description
|
||||||
|
Neuron is a per-node daemon for cortex inference clusters. It discovers
|
||||||
|
local GPU hardware via nvidia-smi, runs in-process inference via
|
||||||
|
huggingface/candle, and exposes an HTTP API for model lifecycle
|
||||||
|
management (load, unload, list, inference endpoint).
|
||||||
|
|
||||||
|
%prep
|
||||||
|
%autosetup
|
||||||
|
tar xf %{SOURCE1}
|
||||||
|
mkdir -p .cargo
|
||||||
|
cat > .cargo/config.toml << 'EOF'
|
||||||
|
[source.crates-io]
|
||||||
|
replace-with = "vendored-sources"
|
||||||
|
|
||||||
|
[source.vendored-sources]
|
||||||
|
directory = "vendor"
|
||||||
|
EOF
|
||||||
|
|
||||||
|
%build
|
||||||
|
cargo build --release -p neuron
|
||||||
|
|
||||||
|
%install
|
||||||
|
install -Dm755 target/release/neuron %{buildroot}%{_bindir}/neuron
|
||||||
|
install -Dm644 data/neuron.service %{buildroot}%{_unitdir}/neuron.service
|
||||||
|
install -Dm644 data/neuron-sysusers.conf %{buildroot}%{_sysusersdir}/neuron.conf
|
||||||
|
install -Dm644 data/neuron-firewalld.xml %{buildroot}%{_prefix}/lib/firewalld/services/helexa-neuron.xml
|
||||||
|
install -dm755 %{buildroot}%{_sysconfdir}/neuron
|
||||||
|
install -Dm644 neuron.example.toml %{buildroot}%{_sysconfdir}/neuron/neuron.toml
|
||||||
|
|
||||||
|
%pre
|
||||||
|
%sysusers_create_compat %{_builddir}/%{name}-%{version}/data/neuron-sysusers.conf
|
||||||
|
|
||||||
|
%post
|
||||||
|
%systemd_post neuron.service
|
||||||
|
|
||||||
|
%preun
|
||||||
|
%systemd_preun neuron.service
|
||||||
|
|
||||||
|
%postun
|
||||||
|
%systemd_postun_with_restart neuron.service
|
||||||
|
|
||||||
|
%files
|
||||||
|
%license LICENSE
|
||||||
|
%doc README.md
|
||||||
|
%{_bindir}/neuron
|
||||||
|
%{_unitdir}/neuron.service
|
||||||
|
%{_sysusersdir}/neuron.conf
|
||||||
|
%{_prefix}/lib/firewalld/services/helexa-neuron.xml
|
||||||
|
%dir %{_sysconfdir}/neuron
|
||||||
|
%config(noreplace) %{_sysconfdir}/neuron/neuron.toml
|
||||||
|
|
||||||
|
%changelog
|
||||||
|
* Thu Apr 16 2026 Gitea Actions <actions@git.lair.cafe> - 0.1.16-1
|
||||||
|
- chore: ignore local deploy script
|
||||||
|
- chore: move default ports out of common-collision ranges
|
||||||
|
- ci: drop actions/cache for cargo registry and target
|
||||||
|
|
||||||
|
* Thu Apr 16 2026 Gitea Actions <actions@git.lair.cafe> - 0.1.14-1
|
||||||
|
- ci: publish both packages to a single helexa/helexa COPR project
|
||||||
|
- fix(rpm): rename neuron package to helexa-neuron
|
||||||
|
- ci: commit generated %changelog entries back to main
|
||||||
|
|
||||||
|
* Wed Apr 15 2026 Rob Thijssen <grenade@rob.tn> - 0.1.0-1
|
||||||
|
- Initial package
|
||||||
@@ -2,28 +2,49 @@
|
|||||||
#
|
#
|
||||||
# Copy to /etc/cortex/models.toml and adjust for your environment.
|
# Copy to /etc/cortex/models.toml and adjust for your environment.
|
||||||
# Describes how to serve each model. Cortex matches these profiles
|
# Describes how to serve each model. Cortex matches these profiles
|
||||||
# against discovered neuron topologies for placement decisions.
|
# against discovered neuron topologies for placement decisions; the
|
||||||
|
# resulting `(catalogue × topology)` set is what `GET /v1/models`
|
||||||
|
# returns and what the router can cold-load on demand.
|
||||||
|
#
|
||||||
|
# Field reference:
|
||||||
|
# id - HuggingFace model id, exact match.
|
||||||
|
# harness - which engine handles inference (currently "candle").
|
||||||
|
# quant - GGUF quantisation tag for the file in the HF repo
|
||||||
|
# (e.g. "Q4_K_M"). Omit/empty for the dense
|
||||||
|
# safetensors path. TP requires dense.
|
||||||
|
# vram_mb - rough estimate; advisory only, not enforced.
|
||||||
|
# min_devices - GPU count this profile needs. TP profiles use
|
||||||
|
# the same value as the tensor-parallel size.
|
||||||
|
# min_device_vram_mb - each device must meet this VRAM floor for the
|
||||||
|
# neuron to be considered "feasible".
|
||||||
|
# pinned_on - optional whitelist of neuron names. Non-empty
|
||||||
|
# narrows feasibility to just those neurons and
|
||||||
|
# protects the model from LRU eviction there.
|
||||||
|
|
||||||
|
# Tensor-parallel target — needs a neuron with at least 2 large GPUs.
|
||||||
|
# The example pins to a specific neuron name; adjust or remove the
|
||||||
|
# pinned_on entry for your own fleet.
|
||||||
[[models]]
|
[[models]]
|
||||||
id = "your-org/large-model"
|
id = "Qwen/Qwen3.6-27B"
|
||||||
harness = "mistralrs"
|
harness = "candle"
|
||||||
quant = "Q4_K_M"
|
vram_mb = 54000
|
||||||
vram_mb = 19000
|
|
||||||
min_devices = 2
|
min_devices = 2
|
||||||
min_device_vram_mb = 10000
|
min_device_vram_mb = 24000
|
||||||
pinned_on = ["gpu-large"]
|
pinned_on = ["your-multi-gpu-neuron"]
|
||||||
|
|
||||||
|
# Mid-size dense model — fits on any single GPU with ≥16 GB VRAM.
|
||||||
[[models]]
|
[[models]]
|
||||||
id = "your-org/medium-model"
|
id = "Qwen/Qwen3-8B"
|
||||||
harness = "mistralrs"
|
harness = "candle"
|
||||||
quant = "Q6_K"
|
vram_mb = 18000
|
||||||
vram_mb = 12000
|
|
||||||
min_devices = 1
|
min_devices = 1
|
||||||
pinned_on = ["gpu-medium"]
|
min_device_vram_mb = 16000
|
||||||
|
|
||||||
|
# Small GGUF quantised — runs on any small GPU.
|
||||||
[[models]]
|
[[models]]
|
||||||
id = "your-org/embedding-model"
|
id = "unsloth/Qwen3-0.6B-GGUF"
|
||||||
harness = "mistralrs"
|
harness = "candle"
|
||||||
quant = "Q8_0"
|
quant = "Q4_K_M"
|
||||||
vram_mb = 8000
|
vram_mb = 500
|
||||||
min_devices = 1
|
min_devices = 1
|
||||||
|
min_device_vram_mb = 4000
|
||||||
|
|||||||
@@ -1,16 +1,53 @@
|
|||||||
# neuron.example.toml — example configuration
|
# neuron.example.toml — example configuration
|
||||||
#
|
#
|
||||||
# Copy to /etc/cortex/neuron.toml and adjust for your environment.
|
# Copy to /etc/neuron/neuron.toml and adjust for your environment.
|
||||||
#
|
#
|
||||||
# Environment variable overrides use NEURON_ prefix with __ separators:
|
# Environment variable overrides use NEURON_ prefix with __ separators:
|
||||||
# NEURON_PORT=9090
|
# NEURON_PORT=13131
|
||||||
|
|
||||||
port = 9090
|
port = 13131
|
||||||
|
|
||||||
# -- Harnesses ---------------------------------------------------------------
|
# -- Harnesses ---------------------------------------------------------------
|
||||||
# Each [[harnesses]] entry declares an inference engine managed by neuron.
|
# Each [[harnesses]] entry enables an inference engine. Currently only
|
||||||
|
# "candle" is supported — it runs in-process and uses huggingface/candle
|
||||||
|
# for inference on local CUDA devices (or CPU when CUDA is unavailable).
|
||||||
|
|
||||||
[[harnesses]]
|
[[harnesses]]
|
||||||
name = "mistralrs"
|
name = "candle"
|
||||||
endpoint = "http://localhost:8080"
|
|
||||||
systemd_unit = "mistralrs.service"
|
# -- Candle harness settings -------------------------------------------------
|
||||||
|
# Optional tuning for the candle harness.
|
||||||
|
|
||||||
|
[harness.candle]
|
||||||
|
# HuggingFace cache directory for model weights.
|
||||||
|
#
|
||||||
|
# Resolution order (first hit wins):
|
||||||
|
# 1. `hf_cache` here in this file.
|
||||||
|
# 2. `HF_HUB_CACHE` env var — same convention as the Python
|
||||||
|
# `huggingface_hub` library, so an existing cache directory shared
|
||||||
|
# with other tooling can be reused without per-tool config.
|
||||||
|
# 3. `HF_HOME` env var (cache appended as `$HF_HOME/hub`).
|
||||||
|
# 4. hf-hub's default (`~/.cache/huggingface/hub`).
|
||||||
|
#
|
||||||
|
# For per-host overrides (e.g. one neuron has an SSD with prefetched
|
||||||
|
# weights), prefer a systemd drop-in over editing this file:
|
||||||
|
# /etc/systemd/system/neuron.service.d/local.conf:
|
||||||
|
# [Service]
|
||||||
|
# Environment=HF_HUB_CACHE=/archive/hf-cache
|
||||||
|
# hf_cache = "/var/lib/neuron/hf-cache"
|
||||||
|
|
||||||
|
# -- Default models ----------------------------------------------------------
|
||||||
|
# Models listed here are loaded automatically when the neuron service
|
||||||
|
# activates. Loading is sequential — a slow or failing entry doesn't
|
||||||
|
# block the rest of the fleet, but it does push out the time before
|
||||||
|
# neuron starts serving HTTP, so keep the list short. Operators can
|
||||||
|
# load additional models on demand via POST /models/load.
|
||||||
|
#
|
||||||
|
# Make sure data/neuron.service's TimeoutStartSec is generous enough to
|
||||||
|
# cover the slowest entry's first-time download + materialisation.
|
||||||
|
|
||||||
|
# [[default_models]]
|
||||||
|
# model_id = "Qwen/Qwen3-0.6B-GGUF"
|
||||||
|
# harness = "candle"
|
||||||
|
# quant = "Q4_K_M"
|
||||||
|
# devices = [0]
|
||||||
|
|||||||
81
neuron.spec
81
neuron.spec
@@ -1,81 +0,0 @@
|
|||||||
Name: neuron
|
|
||||||
Version: 0.1.2
|
|
||||||
Release: 1%{?dist}
|
|
||||||
Summary: Per-node GPU discovery and harness management daemon for cortex
|
|
||||||
|
|
||||||
License: GPL-3.0-or-later
|
|
||||||
URL: https://git.lair.cafe/helexa/cortex
|
|
||||||
Source0: %{name}-%{version}.tar.gz
|
|
||||||
Source1: %{name}-%{version}-vendor.tar.gz
|
|
||||||
|
|
||||||
ExclusiveArch: x86_64
|
|
||||||
|
|
||||||
BuildRequires: rust >= 1.85
|
|
||||||
BuildRequires: cargo
|
|
||||||
BuildRequires: gcc
|
|
||||||
BuildRequires: gcc-c++
|
|
||||||
BuildRequires: cmake
|
|
||||||
BuildRequires: perl-interpreter
|
|
||||||
BuildRequires: pkgconfig(openssl)
|
|
||||||
BuildRequires: systemd-rpm-macros
|
|
||||||
|
|
||||||
Requires(pre): shadow-utils
|
|
||||||
Requires: systemd
|
|
||||||
|
|
||||||
# rpm's sysusers provides-generator only emits versioned user(cortex) when
|
|
||||||
# the u-line has GECOS/home/shell fields. %attr(,,cortex) in %files emits
|
|
||||||
# an unversioned Requires: user(cortex), so we provide it explicitly.
|
|
||||||
Provides: user(cortex)
|
|
||||||
Provides: group(cortex)
|
|
||||||
|
|
||||||
%description
|
|
||||||
Neuron is a per-node daemon for cortex inference clusters. It discovers
|
|
||||||
local GPU hardware via nvidia-smi, manages inference harnesses (mistral.rs,
|
|
||||||
llama.cpp), and exposes an HTTP API for model lifecycle management.
|
|
||||||
|
|
||||||
%prep
|
|
||||||
%autosetup
|
|
||||||
tar xf %{SOURCE1}
|
|
||||||
mkdir -p .cargo
|
|
||||||
cat > .cargo/config.toml << 'EOF'
|
|
||||||
[source.crates-io]
|
|
||||||
replace-with = "vendored-sources"
|
|
||||||
|
|
||||||
[source.vendored-sources]
|
|
||||||
directory = "vendor"
|
|
||||||
EOF
|
|
||||||
|
|
||||||
%build
|
|
||||||
cargo build --release -p neuron
|
|
||||||
|
|
||||||
%install
|
|
||||||
install -Dm755 target/release/neuron %{buildroot}%{_bindir}/neuron
|
|
||||||
install -Dm644 data/neuron.service %{buildroot}%{_unitdir}/neuron.service
|
|
||||||
install -Dm644 data/cortex-sysusers.conf %{buildroot}%{_sysusersdir}/neuron.conf
|
|
||||||
install -dm750 %{buildroot}%{_sysconfdir}/cortex
|
|
||||||
install -Dm640 neuron.example.toml %{buildroot}%{_sysconfdir}/cortex/neuron.toml
|
|
||||||
|
|
||||||
%pre
|
|
||||||
%sysusers_create_compat %{_builddir}/%{name}-%{version}/data/cortex-sysusers.conf
|
|
||||||
|
|
||||||
%post
|
|
||||||
%systemd_post neuron.service
|
|
||||||
|
|
||||||
%preun
|
|
||||||
%systemd_preun neuron.service
|
|
||||||
|
|
||||||
%postun
|
|
||||||
%systemd_postun_with_restart neuron.service
|
|
||||||
|
|
||||||
%files
|
|
||||||
%license LICENSE
|
|
||||||
%doc README.md
|
|
||||||
%{_bindir}/neuron
|
|
||||||
%{_unitdir}/neuron.service
|
|
||||||
%{_sysusersdir}/neuron.conf
|
|
||||||
%dir %attr(750,root,cortex) %{_sysconfdir}/cortex
|
|
||||||
%config(noreplace) %attr(640,root,cortex) %{_sysconfdir}/cortex/neuron.toml
|
|
||||||
|
|
||||||
%changelog
|
|
||||||
* Tue Apr 15 2026 Rob Thijssen <grenade@rob.tn> - 0.1.0-1
|
|
||||||
- Initial package
|
|
||||||
106
rpm/cortex-prerelease.spec
Normal file
106
rpm/cortex-prerelease.spec
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
# Prebuilt-binary spec for cortex.
|
||||||
|
#
|
||||||
|
# Unlike cortex.spec (which builds from source via cargo), this spec
|
||||||
|
# wraps a pre-built `cortex` binary produced by an upstream CI job and
|
||||||
|
# packages it for rpm.lair.cafe. The %build phase is a no-op.
|
||||||
|
#
|
||||||
|
# Required defines at rpmbuild time:
|
||||||
|
# cortex_version e.g. "0.1.16"
|
||||||
|
# cortex_prerelease e.g. "0.1.20260518140530.gitabcdef0"
|
||||||
|
# ^^^^^^^^^^^^^^^^^^ ^^^^^^^^
|
||||||
|
# commit time (sec) commit sha
|
||||||
|
# (used as Release; the timestamp prefix
|
||||||
|
# keeps same-day builds strictly ordered.)
|
||||||
|
|
||||||
|
%global _build_id_links none
|
||||||
|
%global debug_package %{nil}
|
||||||
|
%global __strip /usr/bin/true
|
||||||
|
|
||||||
|
%{!?cortex_version: %global cortex_version 0.0.0}
|
||||||
|
%if 0%{?cortex_prerelease:1}
|
||||||
|
%global cortex_release %{cortex_prerelease}
|
||||||
|
%else
|
||||||
|
%global cortex_release 1
|
||||||
|
%endif
|
||||||
|
|
||||||
|
Name: cortex
|
||||||
|
Version: %{cortex_version}
|
||||||
|
Release: %{cortex_release}%{?dist}
|
||||||
|
Summary: Inference gateway for multi-node GPU clusters (prebuilt)
|
||||||
|
|
||||||
|
License: GPL-3.0-or-later
|
||||||
|
URL: https://git.lair.cafe/helexa/cortex
|
||||||
|
|
||||||
|
Source0: cortex
|
||||||
|
Source1: cortex.service
|
||||||
|
Source2: cortex-sysusers.conf
|
||||||
|
Source3: cortex-firewalld.xml
|
||||||
|
Source4: cortex.example.toml
|
||||||
|
Source5: models.example.toml
|
||||||
|
Source6: LICENSE
|
||||||
|
|
||||||
|
ExclusiveArch: x86_64
|
||||||
|
|
||||||
|
Requires(pre): shadow-utils
|
||||||
|
Requires: systemd
|
||||||
|
Requires: firewalld-filesystem
|
||||||
|
|
||||||
|
Provides: user(cortex)
|
||||||
|
|
||||||
|
%description
|
||||||
|
Cortex is a Rust reverse-proxy that sits in front of multiple neuron
|
||||||
|
inference daemons and presents a unified OpenAI and Anthropic
|
||||||
|
compatible API surface.
|
||||||
|
|
||||||
|
This package wraps a binary built upstream in CI; the source-build
|
||||||
|
spec (cortex.spec) remains available for stable releases.
|
||||||
|
|
||||||
|
%prep
|
||||||
|
cp %{SOURCE0} ./cortex
|
||||||
|
cp %{SOURCE1} .
|
||||||
|
cp %{SOURCE2} .
|
||||||
|
cp %{SOURCE3} .
|
||||||
|
cp %{SOURCE4} .
|
||||||
|
cp %{SOURCE5} .
|
||||||
|
cp %{SOURCE6} .
|
||||||
|
|
||||||
|
%build
|
||||||
|
# Already built in the upstream CI build job.
|
||||||
|
|
||||||
|
%install
|
||||||
|
install -Dm755 cortex %{buildroot}%{_bindir}/cortex
|
||||||
|
install -Dm644 cortex.service %{buildroot}%{_unitdir}/cortex.service
|
||||||
|
install -Dm644 cortex-sysusers.conf %{buildroot}%{_sysusersdir}/cortex.conf
|
||||||
|
install -Dm644 cortex-firewalld.xml %{buildroot}%{_prefix}/lib/firewalld/services/cortex.xml
|
||||||
|
install -dm755 %{buildroot}%{_sysconfdir}/cortex
|
||||||
|
install -Dm644 cortex.example.toml %{buildroot}%{_sysconfdir}/cortex/cortex.toml
|
||||||
|
install -Dm644 models.example.toml %{buildroot}%{_sysconfdir}/cortex/models.toml
|
||||||
|
|
||||||
|
%pre
|
||||||
|
getent group cortex >/dev/null || groupadd -r cortex
|
||||||
|
getent passwd cortex >/dev/null || \
|
||||||
|
useradd -r -g cortex -d /var/lib/cortex -s /sbin/nologin \
|
||||||
|
-c "Cortex inference gateway" cortex
|
||||||
|
|
||||||
|
%post
|
||||||
|
%systemd_post cortex.service
|
||||||
|
|
||||||
|
%preun
|
||||||
|
%systemd_preun cortex.service
|
||||||
|
|
||||||
|
%postun
|
||||||
|
%systemd_postun_with_restart cortex.service
|
||||||
|
|
||||||
|
%files
|
||||||
|
%license LICENSE
|
||||||
|
%{_bindir}/cortex
|
||||||
|
%{_unitdir}/cortex.service
|
||||||
|
%{_sysusersdir}/cortex.conf
|
||||||
|
%{_prefix}/lib/firewalld/services/cortex.xml
|
||||||
|
%dir %{_sysconfdir}/cortex
|
||||||
|
%config(noreplace) %{_sysconfdir}/cortex/cortex.toml
|
||||||
|
%config(noreplace) %{_sysconfdir}/cortex/models.toml
|
||||||
|
|
||||||
|
%changelog
|
||||||
|
* Mon May 18 2026 Gitea Actions <actions@git.lair.cafe> - %{cortex_version}-%{cortex_release}
|
||||||
|
- Prerelease build from upstream CI binary.
|
||||||
126
rpm/helexa-neuron-prerelease.spec
Normal file
126
rpm/helexa-neuron-prerelease.spec
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
# Prebuilt-binary spec for helexa-neuron flavoured by CUDA compute capability.
|
||||||
|
#
|
||||||
|
# Unlike helexa-neuron.spec (which builds from source via cargo), this
|
||||||
|
# spec wraps a pre-built `neuron-{flavour}` binary produced by an
|
||||||
|
# upstream CI job and packages it for rpm.lair.cafe. The %build phase
|
||||||
|
# is a no-op.
|
||||||
|
#
|
||||||
|
# Required defines at rpmbuild time:
|
||||||
|
# neuron_version e.g. "0.1.16"
|
||||||
|
# neuron_flavour e.g. "ada", "blackwell" — matches the CI build
|
||||||
|
# matrix's compute_cap label.
|
||||||
|
# neuron_prerelease e.g. "0.1.20260518140530.gitabcdef0"
|
||||||
|
# ^^^^^^^^^^^^^^^^^^ ^^^^^^^^
|
||||||
|
# commit time (sec) commit sha
|
||||||
|
# (used as Release; the timestamp prefix
|
||||||
|
# keeps same-day builds strictly ordered.)
|
||||||
|
#
|
||||||
|
# One flavour can be installed at a time on a given host; flavour
|
||||||
|
# packages Conflict with each other.
|
||||||
|
|
||||||
|
%global _build_id_links none
|
||||||
|
%global debug_package %{nil}
|
||||||
|
%global __strip /usr/bin/true
|
||||||
|
|
||||||
|
%{!?neuron_version: %global neuron_version 0.0.0}
|
||||||
|
%{!?neuron_flavour: %global neuron_flavour blackwell}
|
||||||
|
%if 0%{?neuron_prerelease:1}
|
||||||
|
%global neuron_release %{neuron_prerelease}
|
||||||
|
%else
|
||||||
|
%global neuron_release 1
|
||||||
|
%endif
|
||||||
|
|
||||||
|
Name: helexa-neuron-%{neuron_flavour}
|
||||||
|
Version: %{neuron_version}
|
||||||
|
Release: %{neuron_release}%{?dist}
|
||||||
|
Summary: Per-node GPU inference daemon (candle, %{neuron_flavour} flavour)
|
||||||
|
|
||||||
|
License: GPL-3.0-or-later
|
||||||
|
URL: https://git.lair.cafe/helexa/cortex
|
||||||
|
|
||||||
|
Source0: neuron-%{neuron_flavour}
|
||||||
|
Source1: neuron.service
|
||||||
|
Source2: neuron-sysusers.conf
|
||||||
|
Source3: neuron-firewalld.xml
|
||||||
|
Source4: neuron.example.toml
|
||||||
|
Source5: LICENSE
|
||||||
|
|
||||||
|
ExclusiveArch: x86_64
|
||||||
|
|
||||||
|
# Binary links against the CUDA runtime, cuDNN, NCCL, etc. Suppress
|
||||||
|
# auto-detected exact soname deps — users may have CUDA from various
|
||||||
|
# sources (rpmfusion, nvidia-direct) at different compatible versions;
|
||||||
|
# a runtime dlopen failure surfaces a clearer error than rpm dep
|
||||||
|
# resolution would.
|
||||||
|
%global __requires_exclude ^lib(cuda|cudart|cudnn|cublas|cublasLt|curand|nvrtc|nccl)
|
||||||
|
|
||||||
|
Requires(pre): shadow-utils
|
||||||
|
Requires: systemd
|
||||||
|
Requires: firewalld-filesystem
|
||||||
|
|
||||||
|
Provides: helexa-neuron = %{neuron_version}-%{neuron_release}
|
||||||
|
Provides: user(neuron)
|
||||||
|
|
||||||
|
# Mutual exclusion across flavours and the source-build variant.
|
||||||
|
Conflicts: helexa-neuron
|
||||||
|
Conflicts: helexa-neuron-ada
|
||||||
|
Conflicts: helexa-neuron-ampere
|
||||||
|
Conflicts: helexa-neuron-blackwell
|
||||||
|
# (The Conflicts: with self is filtered by rpm at install time.)
|
||||||
|
|
||||||
|
%description
|
||||||
|
Neuron is the per-node daemon for cortex inference clusters. It
|
||||||
|
discovers local GPU hardware via nvidia-smi, runs in-process
|
||||||
|
inference via huggingface/candle, and exposes an HTTP API for model
|
||||||
|
lifecycle management (load, unload, list, inference endpoint).
|
||||||
|
|
||||||
|
This is the %{neuron_flavour} flavour, built for that CUDA compute
|
||||||
|
capability. Install the flavour matching the GPUs on this host.
|
||||||
|
|
||||||
|
%prep
|
||||||
|
cp %{SOURCE0} ./neuron
|
||||||
|
cp %{SOURCE1} .
|
||||||
|
cp %{SOURCE2} .
|
||||||
|
cp %{SOURCE3} .
|
||||||
|
cp %{SOURCE4} .
|
||||||
|
cp %{SOURCE5} .
|
||||||
|
|
||||||
|
%build
|
||||||
|
# Already built in the upstream CI build job (with --features cuda).
|
||||||
|
|
||||||
|
%install
|
||||||
|
install -Dm755 neuron %{buildroot}%{_bindir}/neuron
|
||||||
|
install -Dm644 neuron.service %{buildroot}%{_unitdir}/neuron.service
|
||||||
|
install -Dm644 neuron-sysusers.conf %{buildroot}%{_sysusersdir}/neuron.conf
|
||||||
|
install -Dm644 neuron-firewalld.xml %{buildroot}%{_prefix}/lib/firewalld/services/helexa-neuron.xml
|
||||||
|
install -dm755 %{buildroot}%{_sysconfdir}/neuron
|
||||||
|
install -Dm644 neuron.example.toml %{buildroot}%{_sysconfdir}/neuron/neuron.toml
|
||||||
|
|
||||||
|
%pre
|
||||||
|
getent group neuron >/dev/null || groupadd -r neuron
|
||||||
|
getent passwd neuron >/dev/null || \
|
||||||
|
useradd -r -g neuron -d /var/lib/neuron -s /sbin/nologin \
|
||||||
|
-G video,render \
|
||||||
|
-c "Neuron GPU node daemon" neuron
|
||||||
|
|
||||||
|
%post
|
||||||
|
%systemd_post neuron.service
|
||||||
|
|
||||||
|
%preun
|
||||||
|
%systemd_preun neuron.service
|
||||||
|
|
||||||
|
%postun
|
||||||
|
%systemd_postun_with_restart neuron.service
|
||||||
|
|
||||||
|
%files
|
||||||
|
%license LICENSE
|
||||||
|
%{_bindir}/neuron
|
||||||
|
%{_unitdir}/neuron.service
|
||||||
|
%{_sysusersdir}/neuron.conf
|
||||||
|
%{_prefix}/lib/firewalld/services/helexa-neuron.xml
|
||||||
|
%dir %{_sysconfdir}/neuron
|
||||||
|
%config(noreplace) %{_sysconfdir}/neuron/neuron.toml
|
||||||
|
|
||||||
|
%changelog
|
||||||
|
* Mon May 18 2026 Gitea Actions <actions@git.lair.cafe> - %{neuron_version}-%{neuron_release}
|
||||||
|
- Prerelease build from upstream CI binary (%{neuron_flavour} flavour).
|
||||||
1
rpm/rpmmacros
Normal file
1
rpm/rpmmacros
Normal file
@@ -0,0 +1 @@
|
|||||||
|
%_openpgp_sign_id @GPG_NAME@
|
||||||
275
script/deploy.sh
Executable file
275
script/deploy.sh
Executable file
@@ -0,0 +1,275 @@
|
|||||||
|
#!/bin/env bash
|
||||||
|
#
|
||||||
|
# Rolling deploy across the helexa fleet, driven by asset/manifest.yml.
|
||||||
|
# Installs / upgrades cortex on the gateway host and the appropriate
|
||||||
|
# helexa-neuron-<flavour> package on each neuron host, then restarts
|
||||||
|
# their services.
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||||
|
REPO_DIR="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
||||||
|
MANIFEST="${REPO_DIR}/asset/manifest.yml"
|
||||||
|
|
||||||
|
if [[ ! -f "${MANIFEST}" ]]; then
|
||||||
|
echo "fatal: manifest not found at ${MANIFEST}" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Parse the manifest with yq. NOTE: this expects the pip-installed yq
|
||||||
|
# (a jq wrapper using jq syntax) — `pip install yq`. The Fedora rpm
|
||||||
|
# `yq` is mikefarah/yq and uses different (yaml-native) syntax; if a
|
||||||
|
# host has that one instead these queries will fail.
|
||||||
|
cortex_host=$(yq -r '.cortex.host' "${MANIFEST}")
|
||||||
|
|
||||||
|
# Emit one TAB-separated 'host\tflavour' line per neuron.
|
||||||
|
mapfile -t neuron_entries < <(
|
||||||
|
yq -r '.neurons[] | .host + "\t" + .flavour' "${MANIFEST}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return the installed package's "version-release" string, or
|
||||||
|
# "(not installed)" when rpm reports the package as absent. Capture
|
||||||
|
# rpm's output into a variable so its "package X is not installed"
|
||||||
|
# stdout message (rpm writes that to stdout, not stderr, when -q fails)
|
||||||
|
# doesn't leak into the result.
|
||||||
|
installed_nvr() {
|
||||||
|
local host="$1" pkg="$2"
|
||||||
|
local nvr
|
||||||
|
if nvr=$(ssh "${host}" "rpm -q --qf '%{version}-%{release}' ${pkg} 2>/dev/null"); then
|
||||||
|
echo "${nvr}"
|
||||||
|
else
|
||||||
|
echo "(not installed)"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# Ensure the rpm.lair.cafe unstable repo is configured AND enabled on
|
||||||
|
# the remote host.
|
||||||
|
#
|
||||||
|
# The upstream .repo file at https://rpm.lair.cafe/lair-cafe-unstable.repo
|
||||||
|
# ships with `enabled=0` so a host that just fetched it won't start
|
||||||
|
# pulling unstable packages by accident. We have to explicitly flip
|
||||||
|
# enabled=1 via `dnf config-manager setopt`. Both addrepo and setopt
|
||||||
|
# are idempotent.
|
||||||
|
#
|
||||||
|
# Non-fatal — if either step fails the subsequent `dnf install` will
|
||||||
|
# surface a clearer diagnostic on its own.
|
||||||
|
ensure_lair_repo() {
|
||||||
|
local host="$1"
|
||||||
|
if ! ssh "${host}" "test -f /etc/yum.repos.d/lair-cafe-unstable.repo" 2>/dev/null; then
|
||||||
|
echo "[${host}] adding rpm.lair.cafe unstable repo"
|
||||||
|
if ! ssh "${host}" sudo dnf config-manager addrepo \
|
||||||
|
--from-repofile=https://rpm.lair.cafe/lair-cafe-unstable.repo \
|
||||||
|
>/dev/null 2>&1; then
|
||||||
|
echo "[${host}] WARNING: failed to add lair.cafe repo file (proceeding anyway)"
|
||||||
|
return 0
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
# The .repo file ships enabled=0; flip it on. Cheap, idempotent.
|
||||||
|
if ! ssh "${host}" sudo dnf config-manager setopt \
|
||||||
|
lair-cafe-unstable.enabled=1 >/dev/null 2>&1; then
|
||||||
|
echo "[${host}] WARNING: failed to enable lair-cafe-unstable (proceeding anyway)"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# Ensure libcudnn.so.9 is resolvable on the remote host so the
|
||||||
|
# neuron binary (built with --features cudnn) doesn't fail at startup
|
||||||
|
# with "cannot open shared object file: No such file or directory".
|
||||||
|
#
|
||||||
|
# Probes ldconfig first — if cuDNN was installed manually (.tar/.run
|
||||||
|
# install), it'll be cached by ldconfig and we don't touch it.
|
||||||
|
# Otherwise adds NVIDIA's RHEL9 CUDA repo (the Fedora 43 CUDA repo
|
||||||
|
# doesn't ship cuDNN packages — only the RHEL9 one does) and installs
|
||||||
|
# libcudnn9-cuda-13.
|
||||||
|
ensure_cudnn_runtime() {
|
||||||
|
local host="$1"
|
||||||
|
if ssh "${host}" "ldconfig -p | grep -q libcudnn.so.9" 2>/dev/null; then
|
||||||
|
return 0
|
||||||
|
fi
|
||||||
|
echo "[${host}] installing cuDNN runtime"
|
||||||
|
if ! ssh "${host}" "test -f /etc/yum.repos.d/cuda-rhel9-x86_64.repo" 2>/dev/null; then
|
||||||
|
if ! ssh "${host}" sudo dnf config-manager addrepo \
|
||||||
|
--from-repofile=https://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo \
|
||||||
|
>/dev/null 2>&1; then
|
||||||
|
echo "[${host}] WARNING: failed to add rhel9 CUDA repo (proceeding anyway)"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
if ! ssh "${host}" sudo dnf install -y libcudnn9-cuda-13 >/dev/null 2>&1; then
|
||||||
|
echo "[${host}] WARNING: failed to install libcudnn9-cuda-13"
|
||||||
|
echo "[${host}] neuron may fail to start; install cuDNN manually if so"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# True when the named package needs to be installed or upgraded on the
|
||||||
|
# remote host — either it's not present, or a newer version exists in
|
||||||
|
# the repo. False only when the installed version is current.
|
||||||
|
#
|
||||||
|
# `dnf check-update <pkg>` returns 0 when the package isn't installed
|
||||||
|
# at all (there's nothing to update), so we have to probe with rpm -q
|
||||||
|
# first to distinguish "absent" from "current". Other dnf failures
|
||||||
|
# collapse into "needs update" so the subsequent install step surfaces
|
||||||
|
# the real diagnostic rather than this check swallowing it.
|
||||||
|
needs_update() {
|
||||||
|
local host="$1" pkg="$2"
|
||||||
|
# Not installed → needs work.
|
||||||
|
if ! ssh "${host}" "rpm -q ${pkg}" >/dev/null 2>&1; then
|
||||||
|
return 0
|
||||||
|
fi
|
||||||
|
# Installed; ask dnf whether the repo has something newer.
|
||||||
|
if ssh "${host}" sudo dnf check-update --refresh -q "${pkg}" >/dev/null 2>&1; then
|
||||||
|
return 1
|
||||||
|
else
|
||||||
|
return 0
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# True if the named package is currently installed on the remote host.
|
||||||
|
# Used to decide between `dnf install` (fresh) and `dnf upgrade` (stale):
|
||||||
|
# dnf5's `install` is a no-op when the package is already present at
|
||||||
|
# any version — it does NOT auto-upgrade to the latest available — so
|
||||||
|
# the wrong command silently leaves the host on an old build.
|
||||||
|
is_installed() {
|
||||||
|
local host="$1" pkg="$2"
|
||||||
|
ssh "${host}" "rpm -q ${pkg}" >/dev/null 2>&1
|
||||||
|
}
|
||||||
|
|
||||||
|
# Install or upgrade the named package on the remote, picking the
|
||||||
|
# right dnf verb based on the installed-or-not state. Returns 0 with
|
||||||
|
# dnf's combined stdout/stderr captured in __DNF_OUTPUT__ on success,
|
||||||
|
# and 1 with the same captured output on failure.
|
||||||
|
__DNF_OUTPUT__=""
|
||||||
|
install_or_upgrade() {
|
||||||
|
local host="$1" pkg="$2"
|
||||||
|
local cmd
|
||||||
|
if is_installed "${host}" "${pkg}"; then
|
||||||
|
cmd="upgrade"
|
||||||
|
else
|
||||||
|
cmd="install"
|
||||||
|
fi
|
||||||
|
if __DNF_OUTPUT__=$(
|
||||||
|
ssh "${host}" sudo dnf "${cmd}" --refresh --allowerasing -y "${pkg}" 2>&1
|
||||||
|
); then
|
||||||
|
return 0
|
||||||
|
else
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# cortex (gateway)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
ensure_lair_repo "${cortex_host}"
|
||||||
|
cortex_nvr=$(installed_nvr "${cortex_host}" cortex)
|
||||||
|
if needs_update "${cortex_host}" cortex; then
|
||||||
|
echo "[${cortex_host}] cortex update available (current: ${cortex_nvr})"
|
||||||
|
# Stop the service only if the unit file exists — fresh installs
|
||||||
|
# don't have it, and `systemctl stop` on a missing unit returns
|
||||||
|
# non-zero, which would otherwise short-circuit the install branch
|
||||||
|
# under set -e.
|
||||||
|
if ssh "${cortex_host}" "[ ! -f /usr/lib/systemd/system/cortex.service ] || sudo systemctl stop cortex.service"; then
|
||||||
|
echo "[${cortex_host}] stopped cortex service"
|
||||||
|
if install_or_upgrade "${cortex_host}" cortex; then
|
||||||
|
cortex_nvr=$(installed_nvr "${cortex_host}" cortex)
|
||||||
|
echo "[${cortex_host}] installed/upgraded cortex to ${cortex_nvr}"
|
||||||
|
else
|
||||||
|
echo "[${cortex_host}] failed to install/upgrade cortex:"
|
||||||
|
echo "${__DNF_OUTPUT__}" | sed "s/^/[${cortex_host}] /"
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
echo "[${cortex_host}] failed to stop cortex service"
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
echo "[${cortex_host}] cortex is up to date (${cortex_nvr})"
|
||||||
|
ssh "${cortex_host}" sudo systemctl stop cortex.service || true
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Sync cortex.toml whether the package was upgraded or not — the config
|
||||||
|
# can change without a package bump.
|
||||||
|
if rsync \
|
||||||
|
--archive \
|
||||||
|
--compress \
|
||||||
|
--rsync-path 'sudo rsync' \
|
||||||
|
--chown root:root \
|
||||||
|
--chmod 644 \
|
||||||
|
"${REPO_DIR}/cortex.toml" \
|
||||||
|
"${cortex_host}:/etc/cortex/cortex.toml"; then
|
||||||
|
echo "[${cortex_host}] sync'd cortex.toml"
|
||||||
|
else
|
||||||
|
echo "[${cortex_host}] failed to sync cortex.toml"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Sync models.toml on the same lifecycle as cortex.toml — operator-owned,
|
||||||
|
# gitignored, drives /v1/models catalogue × topology resolution.
|
||||||
|
if [[ -f "${REPO_DIR}/models.toml" ]]; then
|
||||||
|
if rsync \
|
||||||
|
--archive \
|
||||||
|
--compress \
|
||||||
|
--rsync-path 'sudo rsync' \
|
||||||
|
--chown root:root \
|
||||||
|
--chmod 644 \
|
||||||
|
"${REPO_DIR}/models.toml" \
|
||||||
|
"${cortex_host}:/etc/cortex/models.toml"; then
|
||||||
|
echo "[${cortex_host}] sync'd models.toml"
|
||||||
|
else
|
||||||
|
echo "[${cortex_host}] failed to sync models.toml"
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
echo "[${cortex_host}] no local models.toml — leaving /etc/cortex/models.toml untouched"
|
||||||
|
fi
|
||||||
|
|
||||||
|
ssh "${cortex_host}" sudo systemctl daemon-reload
|
||||||
|
if ssh "${cortex_host}" systemctl is-active --quiet cortex.service; then
|
||||||
|
echo "[${cortex_host}] cortex service is active"
|
||||||
|
elif ssh "${cortex_host}" sudo systemctl start cortex.service; then
|
||||||
|
echo "[${cortex_host}] started cortex service"
|
||||||
|
else
|
||||||
|
echo "[${cortex_host}] failed to start cortex service"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# neuron (per-host, flavour from manifest)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
for entry in "${neuron_entries[@]}"; do
|
||||||
|
IFS=$'\t' read -r neuron_host neuron_flavour <<< "${entry}"
|
||||||
|
package="helexa-neuron-${neuron_flavour}"
|
||||||
|
|
||||||
|
ensure_lair_repo "${neuron_host}"
|
||||||
|
ensure_cudnn_runtime "${neuron_host}"
|
||||||
|
neuron_nvr=$(installed_nvr "${neuron_host}" "${package}")
|
||||||
|
if needs_update "${neuron_host}" "${package}"; then
|
||||||
|
echo "[${neuron_host}] ${package} update available (current: ${neuron_nvr})"
|
||||||
|
if ssh "${neuron_host}" "[ ! -f /usr/lib/systemd/system/neuron.service ] || sudo systemctl stop neuron.service"; then
|
||||||
|
echo "[${neuron_host}] stopped neuron service"
|
||||||
|
# --allowerasing lets dnf swap out a previously-installed
|
||||||
|
# bare helexa-neuron or a different flavour without manual
|
||||||
|
# intervention. The Conflicts: clauses in the spec ensure
|
||||||
|
# only one flavour is ever resident.
|
||||||
|
if install_or_upgrade "${neuron_host}" "${package}"; then
|
||||||
|
neuron_nvr=$(installed_nvr "${neuron_host}" "${package}")
|
||||||
|
echo "[${neuron_host}] installed/upgraded ${package} to ${neuron_nvr}"
|
||||||
|
# Ensure firewalld allows neuron port
|
||||||
|
ssh "${neuron_host}" "sudo firewall-cmd --query-service=helexa-neuron --quiet 2>/dev/null || sudo firewall-cmd --add-service=helexa-neuron --permanent && sudo firewall-cmd --reload" 2>/dev/null || true
|
||||||
|
if ssh "${neuron_host}" "sudo systemctl daemon-reload && sudo systemctl start neuron.service"; then
|
||||||
|
echo "[${neuron_host}] started neuron service"
|
||||||
|
else
|
||||||
|
echo "[${neuron_host}] failed to start neuron service"
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
echo "[${neuron_host}] failed to install ${package}:"
|
||||||
|
echo "${__DNF_OUTPUT__}" | sed "s/^/[${neuron_host}] /"
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
echo "[${neuron_host}] failed to stop neuron service"
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
echo "[${neuron_host}] ${package} is up to date (${neuron_nvr})"
|
||||||
|
if ssh "${neuron_host}" systemctl is-active --quiet neuron.service; then
|
||||||
|
echo "[${neuron_host}] neuron service is active"
|
||||||
|
elif ssh "${neuron_host}" sudo systemctl start neuron.service; then
|
||||||
|
echo "[${neuron_host}] started neuron service"
|
||||||
|
else
|
||||||
|
echo "[${neuron_host}] failed to start neuron service"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
done
|
||||||
154
script/generate-packages-json.py
Executable file
154
script/generate-packages-json.py
Executable file
@@ -0,0 +1,154 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Parse RPM repodata and emit a packages.json manifest for the UI."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import gzip
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import xml.etree.ElementTree as ET
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
RPM_NS = "http://linux.duke.edu/metadata/common"
|
||||||
|
OTHER_NS = "http://linux.duke.edu/metadata/other"
|
||||||
|
REPO_NS = "http://linux.duke.edu/metadata/repo"
|
||||||
|
|
||||||
|
|
||||||
|
def find_repodata_file(repodata_dir, data_type):
|
||||||
|
"""Read repomd.xml and return the path to a specific data type's file."""
|
||||||
|
repomd_path = os.path.join(repodata_dir, "repomd.xml")
|
||||||
|
tree = ET.parse(repomd_path)
|
||||||
|
root = tree.getroot()
|
||||||
|
|
||||||
|
for data in root.findall(f"{{{REPO_NS}}}data"):
|
||||||
|
if data.get("type") == data_type:
|
||||||
|
location = data.find(f"{{{REPO_NS}}}location")
|
||||||
|
if location is not None:
|
||||||
|
href = location.get("href", "")
|
||||||
|
return os.path.join(os.path.dirname(repodata_dir), href)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def open_compressed(path):
|
||||||
|
"""Open a gzip or zstd compressed file for reading."""
|
||||||
|
if path.endswith(".zst"):
|
||||||
|
result = subprocess.run(
|
||||||
|
["zstdcat", path], capture_output=True, check=True
|
||||||
|
)
|
||||||
|
import io
|
||||||
|
return io.BytesIO(result.stdout)
|
||||||
|
else:
|
||||||
|
return gzip.open(path, "rb")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_primary(repodata_dir):
|
||||||
|
"""Parse primary.xml.{gz,zst} and return package metadata."""
|
||||||
|
path = find_repodata_file(repodata_dir, "primary")
|
||||||
|
if not path:
|
||||||
|
print("error: primary metadata not found in repomd.xml", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
packages = {}
|
||||||
|
with open_compressed(path) as f:
|
||||||
|
tree = ET.parse(f)
|
||||||
|
|
||||||
|
for pkg in tree.getroot().findall(f"{{{RPM_NS}}}package"):
|
||||||
|
if pkg.get("type") != "rpm":
|
||||||
|
continue
|
||||||
|
|
||||||
|
name = pkg.findtext(f"{{{RPM_NS}}}name", "")
|
||||||
|
version_el = pkg.find(f"{{{RPM_NS}}}version")
|
||||||
|
ver = version_el.get("ver", "") if version_el is not None else ""
|
||||||
|
rel = version_el.get("rel", "") if version_el is not None else ""
|
||||||
|
arch = pkg.findtext(f"{{{RPM_NS}}}arch", "")
|
||||||
|
|
||||||
|
size_el = pkg.find(f"{{{RPM_NS}}}size")
|
||||||
|
size = int(size_el.get("package", "0")) if size_el is not None else 0
|
||||||
|
|
||||||
|
time_el = pkg.find(f"{{{RPM_NS}}}time")
|
||||||
|
build_time = int(time_el.get("build", "0")) if time_el is not None else 0
|
||||||
|
|
||||||
|
location_el = pkg.find(f"{{{RPM_NS}}}location")
|
||||||
|
filename = os.path.basename(location_el.get("href", "")) if location_el is not None else ""
|
||||||
|
|
||||||
|
key = f"{name}-{ver}-{rel}"
|
||||||
|
packages[key] = {
|
||||||
|
"name": name,
|
||||||
|
"version": ver,
|
||||||
|
"release": rel,
|
||||||
|
"arch": arch,
|
||||||
|
"summary": pkg.findtext(f"{{{RPM_NS}}}summary", ""),
|
||||||
|
"size": size,
|
||||||
|
"buildTime": build_time,
|
||||||
|
"rpmFilename": filename,
|
||||||
|
"changelog": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
return packages
|
||||||
|
|
||||||
|
|
||||||
|
def parse_other(repodata_dir, packages):
|
||||||
|
"""Parse other.xml.gz and attach changelog entries to packages."""
|
||||||
|
path = find_repodata_file(repodata_dir, "other")
|
||||||
|
if not path:
|
||||||
|
return
|
||||||
|
|
||||||
|
with open_compressed(path) as f:
|
||||||
|
tree = ET.parse(f)
|
||||||
|
|
||||||
|
for pkg in tree.getroot().findall(f"{{{OTHER_NS}}}package"):
|
||||||
|
name = pkg.get("name", "")
|
||||||
|
version_el = pkg.find(f"{{{OTHER_NS}}}version")
|
||||||
|
ver = version_el.get("ver", "") if version_el is not None else ""
|
||||||
|
rel = version_el.get("rel", "") if version_el is not None else ""
|
||||||
|
key = f"{name}-{ver}-{rel}"
|
||||||
|
|
||||||
|
if key not in packages:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for entry in pkg.findall(f"{{{OTHER_NS}}}changelog"):
|
||||||
|
packages[key]["changelog"].append({
|
||||||
|
"author": entry.get("author", ""),
|
||||||
|
"date": int(entry.get("date", "0")),
|
||||||
|
"text": (entry.text or "").strip(),
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
|
parser.add_argument(
|
||||||
|
"--repodata-dir",
|
||||||
|
required=True,
|
||||||
|
help="path to the repodata/ directory",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output",
|
||||||
|
required=True,
|
||||||
|
help="path to write packages.json",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--base-url",
|
||||||
|
required=True,
|
||||||
|
help="public base URL for the repo (e.g. https://rpm.lair.cafe/fedora/43/x86_64)",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
packages = parse_primary(args.repodata_dir)
|
||||||
|
parse_other(args.repodata_dir, packages)
|
||||||
|
|
||||||
|
manifest = {
|
||||||
|
"generated": datetime.now(timezone.utc).isoformat(),
|
||||||
|
"baseUrl": args.base_url,
|
||||||
|
"packages": list(packages.values()),
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(args.output, "w") as f:
|
||||||
|
json.dump(manifest, f, indent=2)
|
||||||
|
|
||||||
|
print(f"wrote {len(packages)} packages to {args.output}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
60
script/tp-smoke.sh
Executable file
60
script/tp-smoke.sh
Executable file
@@ -0,0 +1,60 @@
|
|||||||
|
#!/bin/env bash
|
||||||
|
#
|
||||||
|
# TP smoke test against a deployed neuron host.
|
||||||
|
#
|
||||||
|
# SSHes into the target host and runs `neuron --tp-smoke --tp-size N
|
||||||
|
# --cuda-devices ...` directly — no HTTP API involved. The smoke
|
||||||
|
# subcommand spawns N-1 worker subprocesses, joins them in an NCCL
|
||||||
|
# communicator, runs one AllReduce(Sum) of `1u32` across every rank, and
|
||||||
|
# verifies the observed sum equals world_size on every rank.
|
||||||
|
#
|
||||||
|
# This validates the lower-half of the TP stack (NCCL + IPC topology +
|
||||||
|
# subprocess lifecycle) without touching model load, inference, or HTTP.
|
||||||
|
# A failure here means the host cannot run any TP model and there is no
|
||||||
|
# point debugging the higher layers.
|
||||||
|
#
|
||||||
|
# Usage:
|
||||||
|
# script/tp-smoke.sh [host] [tp_size] [cuda_devices]
|
||||||
|
#
|
||||||
|
# Defaults:
|
||||||
|
# host = beast.hanzalova.internal (only fleet host with 2 GPUs)
|
||||||
|
# tp_size = 2
|
||||||
|
# cuda_devices = 0,1
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
HOST="${1:-beast.hanzalova.internal}"
|
||||||
|
TP_SIZE="${2:-2}"
|
||||||
|
CUDA_DEVICES="${3:-0,1}"
|
||||||
|
|
||||||
|
say() { printf '[%s] %s\n' "${HOST}" "$*" >&2; }
|
||||||
|
die() { say "FAIL: $*"; exit 1; }
|
||||||
|
|
||||||
|
say "running neuron --tp-smoke --tp-size ${TP_SIZE} --cuda-devices ${CUDA_DEVICES}"
|
||||||
|
|
||||||
|
# Run as root via sudo because:
|
||||||
|
# - cuda contexts under a user account require either the nvidia
|
||||||
|
# uvm/peer devices to be world-readable or the user to be in a
|
||||||
|
# priviliged group (neither is true on stock fc43);
|
||||||
|
# - the installed binary lives at /usr/bin/neuron with no setuid;
|
||||||
|
# Running through root is the simplest path that matches how
|
||||||
|
# systemd-managed neuron sees the GPUs in production.
|
||||||
|
#
|
||||||
|
# The smoke command is read-only — it allocates a transient NCCL comm
|
||||||
|
# and a 1u32 buffer per rank, then tears it all down.
|
||||||
|
if ! ssh -o BatchMode=yes "${HOST}" \
|
||||||
|
sudo /usr/bin/neuron \
|
||||||
|
--tp-smoke \
|
||||||
|
--tp-size "${TP_SIZE}" \
|
||||||
|
--cuda-devices "${CUDA_DEVICES}" 2>&1 | tee /tmp/tp-smoke-"${HOST}".log
|
||||||
|
then
|
||||||
|
die "tp-smoke exited non-zero (see /tmp/tp-smoke-${HOST}.log)"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Final stdout line is `status=ok` on success.
|
||||||
|
if grep -q '^status=ok$' /tmp/tp-smoke-"${HOST}".log; then
|
||||||
|
say "PASS — NCCL handshake + AllReduce sanity check OK across ${TP_SIZE} ranks"
|
||||||
|
exit 0
|
||||||
|
else
|
||||||
|
die "no status=ok line in output"
|
||||||
|
fi
|
||||||
188
script/validate-neuron.sh
Executable file
188
script/validate-neuron.sh
Executable file
@@ -0,0 +1,188 @@
|
|||||||
|
#!/bin/env bash
|
||||||
|
#
|
||||||
|
# End-to-end smoke test for a deployed neuron.
|
||||||
|
#
|
||||||
|
# Confirms the daemon is reachable, loads a small public Qwen3 GGUF,
|
||||||
|
# fires a reasoning probe at /v1/chat/completions, and prints the
|
||||||
|
# answer. Used to validate the candle harness on a real GPU host
|
||||||
|
# before trusting it for production traffic, and as a regression test
|
||||||
|
# after pushing new neuron builds.
|
||||||
|
#
|
||||||
|
# Usage:
|
||||||
|
# script/validate-neuron.sh [host] [model_id] [quant] [tp_size]
|
||||||
|
#
|
||||||
|
# Defaults:
|
||||||
|
# host = beast.hanzalova.internal
|
||||||
|
# model_id = unsloth/Qwen3-0.6B-GGUF (official Qwen3-*-GGUF repos
|
||||||
|
# ship Q8_0 only; unsloth's mirror ships the full Q-spectrum
|
||||||
|
# including Q4_K_M)
|
||||||
|
# quant = Q4_K_M (empty = dense safetensors path)
|
||||||
|
# tp_size = unset (= 1 = single-GPU; pass 2 to drive the TP path)
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
HOST="${1:-beast.hanzalova.internal}"
|
||||||
|
MODEL_ID="${2:-unsloth/Qwen3-0.6B-GGUF}"
|
||||||
|
# `${3-Q4_K_M}` (no colon) only uses the default when the arg is
|
||||||
|
# UNSET — passing an explicit empty string drives the dense path.
|
||||||
|
QUANT="${3-Q4_K_M}"
|
||||||
|
# tp_size > 1 forces the dense path (TP requires safetensors) and adds
|
||||||
|
# `tensor_parallel: N` to the load payload. The harness picks device
|
||||||
|
# indices 0..N-1 by default; override by passing NEURON_DEVICES="0,1,..."
|
||||||
|
# in the environment.
|
||||||
|
TP_SIZE="${4-1}"
|
||||||
|
PORT="${NEURON_PORT:-13131}"
|
||||||
|
BASE="http://${HOST}:${PORT}"
|
||||||
|
|
||||||
|
# Reasoning probe — concrete, low-temperature answer that small models
|
||||||
|
# can still get right. "Paris" is a strong signal of basic competence
|
||||||
|
# beyond gibberish.
|
||||||
|
PROBE_PROMPT='What is the capital of Georgia (Caucasus)? Respond with the city name only, no punctuation.'
|
||||||
|
EXPECT_SUBSTR='Tbilisi'
|
||||||
|
# Qwen3 prepends <think>...</think> reasoning before the answer when the
|
||||||
|
# chat template enables thinking mode, which eats most of a small token
|
||||||
|
# budget. 256 leaves enough room for thinking + final answer.
|
||||||
|
MAX_TOKENS=256
|
||||||
|
|
||||||
|
# /models/load is synchronous — neuron blocks the response until the
|
||||||
|
# hf-hub download + (GGUF parse or safetensors mmap) + tensor
|
||||||
|
# materialisation is done. Small GGUF (0.6B-Q4_K_M, ~400 MB) is
|
||||||
|
# typically a minute on a warm cache, several on a cold one. A
|
||||||
|
# Qwen3.6-class dense model is ~54 GB and can easily take an hour to
|
||||||
|
# download cold over a residential link, so default high. Override
|
||||||
|
# with NEURON_LOAD_TIMEOUT=N (seconds) for smaller targets if you'd
|
||||||
|
# rather fail fast.
|
||||||
|
LOAD_TIMEOUT="${NEURON_LOAD_TIMEOUT:-3600}"
|
||||||
|
INFER_TIMEOUT="${NEURON_INFER_TIMEOUT:-120}"
|
||||||
|
|
||||||
|
# Status messages go to stderr so command substitutions like
|
||||||
|
# `raw=$(run_probe)` capture only the function's intended return value
|
||||||
|
# (an HTTP body), not the progress chatter.
|
||||||
|
say() { printf '[%s] %s\n' "${HOST}" "$*" >&2; }
|
||||||
|
die() { say "FAIL: $*"; exit 1; }
|
||||||
|
|
||||||
|
probe_health() {
|
||||||
|
curl --silent --fail --max-time 5 "${BASE}/health" >/dev/null \
|
||||||
|
|| die "neuron not reachable at ${BASE}/health"
|
||||||
|
}
|
||||||
|
|
||||||
|
list_loaded_ids() {
|
||||||
|
# The manifest is YAML and uses yq; HTTP responses are JSON and use
|
||||||
|
# jq directly. pip-yq parses input as YAML by default, which trips
|
||||||
|
# on JSON content that happens to look like YAML aliases (chatcmpl
|
||||||
|
# ids, escaped quotes inside `<think>...</think>` blocks, etc.).
|
||||||
|
curl --silent --fail "${BASE}/models" | jq -r '.[].id'
|
||||||
|
}
|
||||||
|
|
||||||
|
is_loaded() {
|
||||||
|
list_loaded_ids 2>/dev/null | grep -Fxq "${MODEL_ID}"
|
||||||
|
}
|
||||||
|
|
||||||
|
trigger_load() {
|
||||||
|
# Build the per-rank CUDA device list as a JSON array. Either
|
||||||
|
# honour NEURON_DEVICES (`0,1,2`) verbatim or default to
|
||||||
|
# `[0, 1, ..., tp_size - 1]`.
|
||||||
|
local devices_json
|
||||||
|
if [[ -n "${NEURON_DEVICES:-}" ]]; then
|
||||||
|
devices_json=$(jq -n -c --arg s "${NEURON_DEVICES}" \
|
||||||
|
'$s | split(",") | map(tonumber)')
|
||||||
|
else
|
||||||
|
devices_json=$(jq -n -c --argjson n "${TP_SIZE}" '[range(0; $n)]')
|
||||||
|
fi
|
||||||
|
say "POST /models/load ${MODEL_ID} (quant=${QUANT:-<dense>}, tp=${TP_SIZE}, devices=${devices_json})"
|
||||||
|
say " (synchronous; may take a minute on first run while HF downloads)"
|
||||||
|
# Build the payload via jq so optional fields are omitted entirely
|
||||||
|
# when not in use. `tensor_parallel` is dropped when tp_size == 1;
|
||||||
|
# `quant` is dropped when empty. Both can coexist: tp_size > 1 +
|
||||||
|
# ISQ quant (q5k/q8_0/etc.) loads safetensors and quantizes the
|
||||||
|
# per-rank shard at load time. GGUF quants (Q4_K_M) are incompatible
|
||||||
|
# with TP — but the harness rejects that combination at load time
|
||||||
|
# rather than here.
|
||||||
|
local payload
|
||||||
|
local base
|
||||||
|
base=$(jq -n -c \
|
||||||
|
--arg id "${MODEL_ID}" \
|
||||||
|
--argjson devices "${devices_json}" \
|
||||||
|
'{model_id: $id, harness: "candle", devices: $devices}')
|
||||||
|
if [[ -n "${QUANT}" ]]; then
|
||||||
|
base=$(echo "${base}" | jq -c --arg q "${QUANT}" '. + {quant: $q}')
|
||||||
|
fi
|
||||||
|
if (( TP_SIZE > 1 )); then
|
||||||
|
base=$(echo "${base}" | jq -c --argjson tp "${TP_SIZE}" '. + {tensor_parallel: $tp}')
|
||||||
|
fi
|
||||||
|
payload="${base}"
|
||||||
|
# --write-out captures the response code on a separate line so we
|
||||||
|
# can surface a real diagnostic instead of relying on --fail.
|
||||||
|
local resp http_code body
|
||||||
|
resp=$(curl --silent --show-error --max-time "${LOAD_TIMEOUT}" \
|
||||||
|
--write-out '\n__HTTP__%{http_code}' \
|
||||||
|
-X POST "${BASE}/models/load" \
|
||||||
|
-H 'content-type: application/json' \
|
||||||
|
--data "${payload}") || die "curl /models/load failed: $?"
|
||||||
|
http_code=$(echo "${resp}" | grep -oP '(?<=__HTTP__)\d+$' | tail -1)
|
||||||
|
body=$(echo "${resp}" | sed '$ s/__HTTP__.*$//')
|
||||||
|
if [[ "${http_code}" != "200" ]]; then
|
||||||
|
die "load returned HTTP ${http_code}: ${body}"
|
||||||
|
fi
|
||||||
|
say "load returned ${http_code}: ${body}"
|
||||||
|
}
|
||||||
|
|
||||||
|
run_probe() {
|
||||||
|
say "POST /v1/chat/completions (probe: ${PROBE_PROMPT})"
|
||||||
|
local payload
|
||||||
|
payload=$(jq -n -c \
|
||||||
|
--arg model "${MODEL_ID}" \
|
||||||
|
--arg content "${PROBE_PROMPT}" \
|
||||||
|
--argjson tokens "${MAX_TOKENS}" \
|
||||||
|
'{
|
||||||
|
model: $model,
|
||||||
|
messages: [{role: "user", content: $content}],
|
||||||
|
temperature: 0.1,
|
||||||
|
max_tokens: $tokens
|
||||||
|
}')
|
||||||
|
local resp http_code body
|
||||||
|
resp=$(curl --silent --show-error --max-time "${INFER_TIMEOUT}" \
|
||||||
|
--write-out '\n__HTTP__%{http_code}' \
|
||||||
|
-X POST "${BASE}/v1/chat/completions" \
|
||||||
|
-H 'content-type: application/json' \
|
||||||
|
--data "${payload}") || die "curl /v1/chat/completions failed: $?"
|
||||||
|
http_code=$(echo "${resp}" | grep -oP '(?<=__HTTP__)\d+$' | tail -1)
|
||||||
|
body=$(echo "${resp}" | sed '$ s/__HTTP__.*$//')
|
||||||
|
if [[ "${http_code}" != "200" ]]; then
|
||||||
|
die "inference returned HTTP ${http_code}: ${body}"
|
||||||
|
fi
|
||||||
|
echo "${body}"
|
||||||
|
}
|
||||||
|
|
||||||
|
say "validating neuron at ${BASE}"
|
||||||
|
probe_health
|
||||||
|
say "/health OK"
|
||||||
|
|
||||||
|
if is_loaded; then
|
||||||
|
say "${MODEL_ID} already loaded"
|
||||||
|
else
|
||||||
|
trigger_load
|
||||||
|
fi
|
||||||
|
|
||||||
|
raw=$(run_probe)
|
||||||
|
echo "---"
|
||||||
|
# Dump the raw JSON. Don't pipe through `yq -r '.'` — yq's default
|
||||||
|
# YAML output mode chokes on JSON strings that contain `<` (and the
|
||||||
|
# `<think>` markers Qwen3 emits during reasoning are a perfect
|
||||||
|
# example). The targeted `yq -r '.path'` calls below work fine
|
||||||
|
# because jq's path filter mode bypasses the YAML re-emit.
|
||||||
|
echo "${raw}"
|
||||||
|
echo "---"
|
||||||
|
|
||||||
|
content=$(echo "${raw}" | jq -r '.choices[0].message.content // empty')
|
||||||
|
if [[ -z "${content}" ]]; then
|
||||||
|
die "no content in chat completion response"
|
||||||
|
fi
|
||||||
|
say "assistant said: ${content}"
|
||||||
|
|
||||||
|
if echo "${content}" | grep -qiF "${EXPECT_SUBSTR}"; then
|
||||||
|
say "PASS — response contains expected substring '${EXPECT_SUBSTR}'"
|
||||||
|
exit 0
|
||||||
|
else
|
||||||
|
die "response did not contain '${EXPECT_SUBSTR}'"
|
||||||
|
fi
|
||||||
Reference in New Issue
Block a user