Compare commits
55 Commits
v0.1.9
...
12549c9aed
| Author | SHA1 | Date | |
|---|---|---|---|
|
12549c9aed
|
|||
|
46527d7804
|
|||
|
8d3194f992
|
|||
|
5436af9c73
|
|||
|
8e882c0757
|
|||
|
93421f48e2
|
|||
|
05e15f3597
|
|||
|
da068ded6d
|
|||
|
2a7ede0232
|
|||
|
18ae3c30ee
|
|||
|
1a0400131e
|
|||
|
1866b99a89
|
|||
|
60176e7c2e
|
|||
|
602e8e1471
|
|||
|
e9d0a75dd5
|
|||
|
6cf87e328f
|
|||
|
f9f5fa41b6
|
|||
|
ed4d71db09
|
|||
|
39010c779f
|
|||
|
57d7ef8d3c
|
|||
|
0e9671dd7d
|
|||
|
e29c9e35f0
|
|||
|
8a2334eacb
|
|||
|
aad314cdfa
|
|||
|
6779b7526a
|
|||
|
84f5662df1
|
|||
|
249c9442e8
|
|||
|
5e17081fb4
|
|||
|
03bed93fee
|
|||
|
4a5211d830
|
|||
|
6d2dc5ff1a
|
|||
|
b713dbe669
|
|||
|
5c957d08ec
|
|||
|
729317d1ef
|
|||
|
5c2bd1a1da
|
|||
|
3cccc2c56b
|
|||
|
7f797b0265
|
|||
|
5a0360c1d5
|
|||
|
472c0e8737
|
|||
|
|
b9d8e30058 | ||
|
25f75fe552
|
|||
|
3f94c50817
|
|||
|
3e1fb60076
|
|||
|
|
9bf987888c | ||
|
abe4ff7ccc
|
|||
|
7c3390a4e1
|
|||
|
2ff062da0e
|
|||
|
|
357f858a29 | ||
|
556e5293dc
|
|||
|
1d90238b01
|
|||
|
d99b25fb8a
|
|||
|
034da319f1
|
|||
|
|
7ece281617 | ||
|
3bb5b3c425
|
|||
|
|
9fa51ad874 |
342
.gitea/workflows/build-prerelease.yml
Normal file
342
.gitea/workflows/build-prerelease.yml
Normal file
@@ -0,0 +1,342 @@
|
|||||||
|
name: build-prerelease
|
||||||
|
|
||||||
|
# Manually-dispatched workflow that builds CUDA-flavoured neuron binaries
|
||||||
|
# (and a single cortex binary), packages each as a Fedora RPM, signs
|
||||||
|
# them, and publishes to the `unstable` channel at rpm.lair.cafe.
|
||||||
|
#
|
||||||
|
# Trigger from the Gitea UI: Actions → build-prerelease → Run workflow.
|
||||||
|
# Optionally provide a `ref` to build from a non-default branch.
|
||||||
|
#
|
||||||
|
# The published packages are versioned as e.g.
|
||||||
|
# helexa-neuron-blackwell-0.1.16-0.1.20260518T140530.gitabcdef0.fc43.x86_64
|
||||||
|
# ^^^^^^^^^^^^^^^^^^ ^^^^^^^^
|
||||||
|
# commit time (s) commit sha
|
||||||
|
# so they sort BELOW the eventual 0.1.16-1 stable release, and so two
|
||||||
|
# commits on the same day are still strictly ordered by their commit
|
||||||
|
# timestamps (rather than by RPM-vercmp's alpha-vs-digit precedence
|
||||||
|
# on the SHA fragment).
|
||||||
|
|
||||||
|
on:
|
||||||
|
# Auto-build on every push to main so the unstable channel tracks
|
||||||
|
# head without a manual dispatch step.
|
||||||
|
push:
|
||||||
|
branches: [main]
|
||||||
|
# Manual dispatch still available to build from a non-main ref.
|
||||||
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
ref:
|
||||||
|
description: "Git ref to build (branch / tag / commit). Defaults to the workflow's branch."
|
||||||
|
required: false
|
||||||
|
default: ""
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
# Share the group with ci.yml so the two workflows can't run
|
||||||
|
# concurrently on the same `rust` runner (act reuses the workspace
|
||||||
|
# cache and races destroy each other's build files mid-compile).
|
||||||
|
# cancel-in-progress=false → workflows queue; if a newer push lands,
|
||||||
|
# the older run is still picked up by ci.yml's own ref-keyed
|
||||||
|
# concurrency (same group, queued).
|
||||||
|
group: cortex-runner-pool-${{ github.ref }}
|
||||||
|
cancel-in-progress: false
|
||||||
|
|
||||||
|
env:
|
||||||
|
CARGO_INCREMENTAL: "0"
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
prepare:
|
||||||
|
name: Resolve version stamps
|
||||||
|
runs-on: rust
|
||||||
|
outputs:
|
||||||
|
version: ${{ steps.info.outputs.version }}
|
||||||
|
release: ${{ steps.info.outputs.release }}
|
||||||
|
short_sha: ${{ steps.info.outputs.short_sha }}
|
||||||
|
commit_timestamp: ${{ steps.info.outputs.commit_timestamp }}
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ inputs.ref }}
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- id: info
|
||||||
|
run: |
|
||||||
|
set -eux
|
||||||
|
VERSION=$(awk -F\" '/^version[[:space:]]*=/ { print $2; exit }' Cargo.toml)
|
||||||
|
SHORT_SHA=$(git rev-parse --short=7 HEAD)
|
||||||
|
# Second-precise commit timestamp gives the release stamp a
|
||||||
|
# strictly monotonic numeric prefix. The earlier %Y%m%d-only
|
||||||
|
# form let same-day builds be ordered by RPM's rpmvercmp
|
||||||
|
# rules over the SHA, which is non-chronological — e.g.
|
||||||
|
# "git602e8e1" sorts newer than "gitf9f5fa4" purely because
|
||||||
|
# rpmvercmp ranks digit-prefixed segments above alpha ones.
|
||||||
|
# The SHA stays only as a debug identifier; sort order is
|
||||||
|
# decided entirely by the timestamp.
|
||||||
|
COMMIT_TIMESTAMP=$(git log -1 --format=%cd --date=format:%Y%m%d%H%M%S HEAD)
|
||||||
|
RELEASE="0.1.${COMMIT_TIMESTAMP}.git${SHORT_SHA}"
|
||||||
|
echo "version=${VERSION}" >> "$GITHUB_OUTPUT"
|
||||||
|
echo "release=${RELEASE}" >> "$GITHUB_OUTPUT"
|
||||||
|
echo "short_sha=${SHORT_SHA}" >> "$GITHUB_OUTPUT"
|
||||||
|
echo "commit_timestamp=${COMMIT_TIMESTAMP}" >> "$GITHUB_OUTPUT"
|
||||||
|
|
||||||
|
build-cortex:
|
||||||
|
name: Build cortex binary
|
||||||
|
needs: prepare
|
||||||
|
# runner-rust image already provides rust/cargo/clippy/rustfmt via
|
||||||
|
# dnf — no rustup install step needed.
|
||||||
|
runs-on: rust
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ inputs.ref }}
|
||||||
|
|
||||||
|
- name: Build cortex (release)
|
||||||
|
run: cargo build --release -p cortex-cli
|
||||||
|
|
||||||
|
- name: Stage binary
|
||||||
|
run: |
|
||||||
|
mkdir --parents artifacts
|
||||||
|
cp target/release/cortex artifacts/cortex
|
||||||
|
./artifacts/cortex --version || true
|
||||||
|
|
||||||
|
- uses: actions/upload-artifact@v3
|
||||||
|
with:
|
||||||
|
name: cortex-fc43
|
||||||
|
path: artifacts/cortex
|
||||||
|
retention-days: 1
|
||||||
|
|
||||||
|
build-neuron:
|
||||||
|
name: Build neuron-${{ matrix.flavour }}
|
||||||
|
needs: prepare
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- flavour: ampere
|
||||||
|
compute_cap: "86"
|
||||||
|
runner: cuda-13.0
|
||||||
|
cuda_home: /usr/local/cuda-13.0
|
||||||
|
build_jobs: 8
|
||||||
|
nvcc_threads: 4
|
||||||
|
cargo_features: "cuda cudnn flash-attn"
|
||||||
|
- flavour: ada
|
||||||
|
compute_cap: "89"
|
||||||
|
runner: cuda-13.0
|
||||||
|
cuda_home: /usr/local/cuda-13.0
|
||||||
|
build_jobs: 8
|
||||||
|
nvcc_threads: 4
|
||||||
|
cargo_features: "cuda cudnn flash-attn"
|
||||||
|
- flavour: blackwell
|
||||||
|
compute_cap: "120"
|
||||||
|
runner: cuda-13.0
|
||||||
|
cuda_home: /usr/local/cuda-13.0
|
||||||
|
build_jobs: 8
|
||||||
|
nvcc_threads: 4
|
||||||
|
cargo_features: "cuda cudnn flash-attn"
|
||||||
|
runs-on: ${{ matrix.runner }}
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ inputs.ref }}
|
||||||
|
|
||||||
|
- name: Build neuron with CUDA (${{ matrix.flavour }})
|
||||||
|
run: |
|
||||||
|
set -eux
|
||||||
|
export PATH="${{ matrix.cuda_home }}/bin:${PATH}"
|
||||||
|
export LD_LIBRARY_PATH="${{ matrix.cuda_home }}/targets/x86_64-linux/lib:${{ matrix.cuda_home }}/lib64:${LD_LIBRARY_PATH:-}"
|
||||||
|
export LIBRARY_PATH="${{ matrix.cuda_home }}/targets/x86_64-linux/lib:${{ matrix.cuda_home }}/lib64:${LIBRARY_PATH:-}"
|
||||||
|
cargo build --release -p neuron --features "${{ matrix.cargo_features }}"
|
||||||
|
env:
|
||||||
|
CUDA_COMPUTE_CAP: ${{ matrix.compute_cap }}
|
||||||
|
CARGO_BUILD_JOBS: ${{ matrix.build_jobs }}
|
||||||
|
NVCC_THREADS: ${{ matrix.nvcc_threads }}
|
||||||
|
|
||||||
|
- name: Stage binary
|
||||||
|
run: |
|
||||||
|
mkdir --parents artifacts
|
||||||
|
cp target/release/neuron artifacts/neuron-${{ matrix.flavour }}
|
||||||
|
file "artifacts/neuron-${{ matrix.flavour }}"
|
||||||
|
|
||||||
|
- uses: actions/upload-artifact@v3
|
||||||
|
with:
|
||||||
|
name: neuron-${{ matrix.flavour }}-fc43
|
||||||
|
path: artifacts/neuron-${{ matrix.flavour }}
|
||||||
|
retention-days: 1
|
||||||
|
|
||||||
|
package-cortex:
|
||||||
|
name: Package cortex RPM
|
||||||
|
needs: [prepare, build-cortex]
|
||||||
|
runs-on: rpm
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ inputs.ref }}
|
||||||
|
|
||||||
|
- uses: actions/download-artifact@v3
|
||||||
|
with:
|
||||||
|
name: cortex-fc43
|
||||||
|
path: artifacts/
|
||||||
|
|
||||||
|
- name: Build RPM
|
||||||
|
run: |
|
||||||
|
set -eux
|
||||||
|
rm -f ~/.rpmmacros
|
||||||
|
rpmdev-setuptree
|
||||||
|
cp artifacts/cortex ~/rpmbuild/SOURCES/
|
||||||
|
cp data/cortex.service ~/rpmbuild/SOURCES/
|
||||||
|
cp data/cortex-sysusers.conf ~/rpmbuild/SOURCES/
|
||||||
|
cp data/cortex-firewalld.xml ~/rpmbuild/SOURCES/
|
||||||
|
cp cortex.example.toml ~/rpmbuild/SOURCES/
|
||||||
|
cp models.example.toml ~/rpmbuild/SOURCES/
|
||||||
|
cp LICENSE ~/rpmbuild/SOURCES/
|
||||||
|
rpmbuild -bb rpm/cortex-prerelease.spec \
|
||||||
|
--define "cortex_version ${{ needs.prepare.outputs.version }}" \
|
||||||
|
--define "cortex_prerelease ${{ needs.prepare.outputs.release }}" \
|
||||||
|
--undefine dist \
|
||||||
|
--define "dist .fc43"
|
||||||
|
|
||||||
|
- uses: actions/upload-artifact@v3
|
||||||
|
with:
|
||||||
|
name: rpm-cortex-fc43
|
||||||
|
path: ~/rpmbuild/RPMS/x86_64/*.rpm
|
||||||
|
retention-days: 7
|
||||||
|
|
||||||
|
package-neuron:
|
||||||
|
name: Package helexa-neuron-${{ matrix.flavour }} RPM
|
||||||
|
needs: [prepare, build-neuron]
|
||||||
|
runs-on: rpm
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- flavour: ampere
|
||||||
|
- flavour: ada
|
||||||
|
- flavour: blackwell
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ inputs.ref }}
|
||||||
|
|
||||||
|
- uses: actions/download-artifact@v3
|
||||||
|
with:
|
||||||
|
name: neuron-${{ matrix.flavour }}-fc43
|
||||||
|
path: artifacts/
|
||||||
|
|
||||||
|
- name: Build RPM
|
||||||
|
run: |
|
||||||
|
set -eux
|
||||||
|
rm -f ~/.rpmmacros
|
||||||
|
rpmdev-setuptree
|
||||||
|
cp artifacts/neuron-${{ matrix.flavour }} ~/rpmbuild/SOURCES/
|
||||||
|
cp data/neuron.service ~/rpmbuild/SOURCES/
|
||||||
|
cp data/neuron-sysusers.conf ~/rpmbuild/SOURCES/
|
||||||
|
cp data/neuron-firewalld.xml ~/rpmbuild/SOURCES/
|
||||||
|
cp neuron.example.toml ~/rpmbuild/SOURCES/
|
||||||
|
cp LICENSE ~/rpmbuild/SOURCES/
|
||||||
|
rpmbuild -bb rpm/helexa-neuron-prerelease.spec \
|
||||||
|
--define "neuron_version ${{ needs.prepare.outputs.version }}" \
|
||||||
|
--define "neuron_flavour ${{ matrix.flavour }}" \
|
||||||
|
--define "neuron_prerelease ${{ needs.prepare.outputs.release }}" \
|
||||||
|
--undefine dist \
|
||||||
|
--define "dist .fc43"
|
||||||
|
|
||||||
|
- uses: actions/upload-artifact@v3
|
||||||
|
with:
|
||||||
|
name: rpm-neuron-${{ matrix.flavour }}-fc43
|
||||||
|
path: ~/rpmbuild/RPMS/x86_64/*.rpm
|
||||||
|
retention-days: 7
|
||||||
|
|
||||||
|
publish:
|
||||||
|
name: Publish to rpm.lair.cafe (unstable)
|
||||||
|
needs: [package-cortex, package-neuron]
|
||||||
|
runs-on: rpm
|
||||||
|
concurrency:
|
||||||
|
group: rpm-publish
|
||||||
|
cancel-in-progress: false
|
||||||
|
env:
|
||||||
|
RPM_REPO_HOST: oolon.kosherinata.internal
|
||||||
|
FEDORA_VERSION: "43"
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ inputs.ref }}
|
||||||
|
|
||||||
|
- name: Download all built RPMs
|
||||||
|
uses: actions/download-artifact@v3
|
||||||
|
with:
|
||||||
|
path: rpms/
|
||||||
|
pattern: rpm-*-fc43
|
||||||
|
|
||||||
|
- name: Flatten RPM artifacts
|
||||||
|
run: |
|
||||||
|
set -eux
|
||||||
|
find rpms/ -name '*.rpm' -exec mv --target-directory=rpms/ {} +
|
||||||
|
find rpms/ -mindepth 1 -type d -empty -delete
|
||||||
|
ls -la rpms/
|
||||||
|
|
||||||
|
- name: Check for sequoia-sq
|
||||||
|
run: |
|
||||||
|
if ! command -v sq &> /dev/null; then
|
||||||
|
echo "ERROR: sequoia-sq is not installed. Install with: sudo dnf install sequoia-sq"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Import signing key
|
||||||
|
env:
|
||||||
|
# Pass secrets via env so values stay out of the rendered shell
|
||||||
|
# script (which Gitea includes in step logs). Template
|
||||||
|
# expansion of ${{ secrets.X }} inside `run:` writes the literal
|
||||||
|
# value into the script and depends on Gitea's log masker to
|
||||||
|
# scrub it — fragile for multi-line keys.
|
||||||
|
RPM_SIGNING_KEY: ${{ secrets.RPM_SIGNING_KEY }}
|
||||||
|
RPM_SIGNING_KEY_ID: ${{ secrets.RPM_SIGNING_KEY_ID }}
|
||||||
|
run: |
|
||||||
|
echo "$RPM_SIGNING_KEY" | gpg --batch --import
|
||||||
|
fpr=$(gpg --batch --with-colons --list-keys "$RPM_SIGNING_KEY_ID" | awk -F: '/^fpr:/ { print $10; exit }')
|
||||||
|
echo "${fpr}:6:" | gpg --batch --import-ownertrust
|
||||||
|
sed "s/@GPG_NAME@/$RPM_SIGNING_KEY_ID/" rpm/rpmmacros > ~/.rpmmacros
|
||||||
|
|
||||||
|
- name: Sign RPMs
|
||||||
|
run: |
|
||||||
|
set -eux
|
||||||
|
for rpm in rpms/*.rpm; do
|
||||||
|
echo "signing ${rpm}..."
|
||||||
|
rpm --addsign "${rpm}"
|
||||||
|
done
|
||||||
|
|
||||||
|
- name: Set up SSH for rsync
|
||||||
|
run: |
|
||||||
|
install --directory --mode 700 ~/.ssh
|
||||||
|
echo "${RSYNC_SSH_KEY}" | install --mode 600 /dev/stdin ~/.ssh/id_ed25519
|
||||||
|
env:
|
||||||
|
RSYNC_SSH_KEY: ${{ secrets.RSYNC_SSH_KEY }}
|
||||||
|
|
||||||
|
- name: Test SSH connectivity
|
||||||
|
run: |
|
||||||
|
ssh -o StrictHostKeyChecking=accept-new "gitea_ci@${RPM_REPO_HOST}" exit
|
||||||
|
|
||||||
|
- name: Ensure unstable repo directory exists
|
||||||
|
run: |
|
||||||
|
ssh "gitea_ci@${RPM_REPO_HOST}" \
|
||||||
|
"mkdir --parents /var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable"
|
||||||
|
|
||||||
|
- name: Sync RPMs to unstable repo
|
||||||
|
run: |
|
||||||
|
rsync \
|
||||||
|
--archive \
|
||||||
|
--verbose \
|
||||||
|
--chmod D755,F644 \
|
||||||
|
rpms/*.rpm \
|
||||||
|
"gitea_ci@${RPM_REPO_HOST}:/var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable/"
|
||||||
|
|
||||||
|
- name: Update unstable repo metadata
|
||||||
|
run: |
|
||||||
|
ssh "gitea_ci@${RPM_REPO_HOST}" \
|
||||||
|
"cd /var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable && createrepo_c --update ."
|
||||||
|
|
||||||
|
- name: Generate packages.json manifest
|
||||||
|
run: |
|
||||||
|
scp script/generate-packages-json.py "gitea_ci@${RPM_REPO_HOST}:/tmp/"
|
||||||
|
ssh "gitea_ci@${RPM_REPO_HOST}" \
|
||||||
|
"python3 /tmp/generate-packages-json.py \
|
||||||
|
--repodata-dir /var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable/repodata \
|
||||||
|
--output /var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable/packages.json \
|
||||||
|
--base-url https://rpm.lair.cafe/fedora/${FEDORA_VERSION}/x86_64/unstable"
|
||||||
@@ -7,6 +7,16 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
branches: [main]
|
branches: [main]
|
||||||
|
|
||||||
|
# Share a concurrency group with build-prerelease.yml so the two
|
||||||
|
# workflows don't race on the same `rust` runner workspace (act's
|
||||||
|
# /root/.cache/act/<hash>/hostexecutor/ is shared across concurrent
|
||||||
|
# jobs and one job's checkout step nukes another's in-flight build
|
||||||
|
# files). cancel-in-progress=false → they queue; same-ref pushes
|
||||||
|
# coalesce per workflow via cancel-in-progress on each.
|
||||||
|
concurrency:
|
||||||
|
group: cortex-runner-pool-${{ github.ref }}
|
||||||
|
cancel-in-progress: false
|
||||||
|
|
||||||
env:
|
env:
|
||||||
CARGO_INCREMENTAL: "0"
|
CARGO_INCREMENTAL: "0"
|
||||||
RUSTC_WRAPPER: sccache
|
RUSTC_WRAPPER: sccache
|
||||||
@@ -16,56 +26,47 @@ env:
|
|||||||
SCCACHE_S3_USE_SSL: "false"
|
SCCACHE_S3_USE_SSL: "false"
|
||||||
AWS_ACCESS_KEY_ID: ${{ secrets.SCCACHE_S3_ACCESS_KEY }}
|
AWS_ACCESS_KEY_ID: ${{ secrets.SCCACHE_S3_ACCESS_KEY }}
|
||||||
AWS_SECRET_ACCESS_KEY: ${{ secrets.SCCACHE_S3_SECRET_KEY }}
|
AWS_SECRET_ACCESS_KEY: ${{ secrets.SCCACHE_S3_SECRET_KEY }}
|
||||||
|
# fmt, clippy, and test all run in parallel on the same `rust` runner
|
||||||
|
# and would otherwise share /root/.cache/act/<hash>/hostexecutor/target/,
|
||||||
|
# racing each other's cargo temp files (.tmpXXXXXX) and failing builds
|
||||||
|
# mid-compile. Give each job its own target directory so the invocations
|
||||||
|
# don't collide. sccache still backs the actual rustc cache, so the
|
||||||
|
# rebuild penalty is small.
|
||||||
|
CARGO_TARGET_DIR: target-${{ github.job }}
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
check:
|
fmt:
|
||||||
name: Format, lint, build, test
|
name: Format
|
||||||
runs-on: fedora
|
runs-on: rust
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
- run: cargo fmt --check --all
|
||||||
|
|
||||||
- name: Cache cargo registry and target
|
clippy:
|
||||||
uses: actions/cache@v4
|
name: Clippy
|
||||||
with:
|
runs-on: rust
|
||||||
path: |
|
steps:
|
||||||
~/.cargo/bin
|
- uses: actions/checkout@v4
|
||||||
~/.cargo/registry/index
|
- run: cargo clippy --workspace -- -D warnings
|
||||||
~/.cargo/registry/cache
|
- run: sccache --show-stats
|
||||||
~/.cargo/git/db
|
|
||||||
target
|
|
||||||
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
|
|
||||||
restore-keys: |
|
|
||||||
${{ runner.os }}-cargo-
|
|
||||||
|
|
||||||
- name: Ensure sccache with S3 support
|
test:
|
||||||
env:
|
name: Test
|
||||||
RUSTC_WRAPPER: ""
|
runs-on: rust
|
||||||
run: |
|
steps:
|
||||||
if sccache --version 2>/dev/null && sccache --show-stats 2>/dev/null; then
|
- uses: actions/checkout@v4
|
||||||
echo "sccache with S3 support already installed"
|
- run: cargo test --workspace
|
||||||
else
|
- run: sccache --show-stats
|
||||||
cargo install sccache --features s3 --locked
|
|
||||||
fi
|
|
||||||
|
|
||||||
- name: Check formatting
|
|
||||||
run: cargo fmt --check --all
|
|
||||||
|
|
||||||
- name: Clippy
|
|
||||||
run: cargo clippy --workspace -- -D warnings
|
|
||||||
|
|
||||||
- name: Test
|
|
||||||
run: cargo test --workspace
|
|
||||||
|
|
||||||
- name: Show sccache stats
|
|
||||||
run: sccache --show-stats
|
|
||||||
|
|
||||||
srpm-cortex:
|
srpm-cortex:
|
||||||
name: Build cortex SRPM
|
name: Build cortex SRPM
|
||||||
runs-on: fedora
|
runs-on: rpm
|
||||||
needs: check
|
needs: [fmt, clippy, test]
|
||||||
if: startsWith(github.ref, 'refs/tags/v')
|
if: startsWith(github.ref, 'refs/tags/v')
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
- name: Determine version
|
- name: Determine version
|
||||||
id: version
|
id: version
|
||||||
@@ -79,6 +80,12 @@ jobs:
|
|||||||
sed -i '/\[workspace\.package\]/,/\[/{ s/^version = ".*"/version = "'"${VERSION}"'"/ }' Cargo.toml
|
sed -i '/\[workspace\.package\]/,/\[/{ s/^version = ".*"/version = "'"${VERSION}"'"/ }' Cargo.toml
|
||||||
sed -i "s/^Version:.*/Version: ${VERSION}/" cortex.spec
|
sed -i "s/^Version:.*/Version: ${VERSION}/" cortex.spec
|
||||||
|
|
||||||
|
- name: Generate changelog entry
|
||||||
|
uses: https://git.lair.cafe/actions/rpm-changelog@v1
|
||||||
|
with:
|
||||||
|
spec: cortex.spec
|
||||||
|
version: ${{ steps.version.outputs.VERSION }}
|
||||||
|
|
||||||
- name: Generate source tarball
|
- name: Generate source tarball
|
||||||
run: |
|
run: |
|
||||||
set -ex
|
set -ex
|
||||||
@@ -113,11 +120,13 @@ jobs:
|
|||||||
|
|
||||||
srpm-neuron:
|
srpm-neuron:
|
||||||
name: Build neuron SRPM
|
name: Build neuron SRPM
|
||||||
runs-on: fedora
|
runs-on: rpm
|
||||||
needs: check
|
needs: [fmt, clippy, test]
|
||||||
if: startsWith(github.ref, 'refs/tags/v')
|
if: startsWith(github.ref, 'refs/tags/v')
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
- name: Determine version
|
- name: Determine version
|
||||||
id: version
|
id: version
|
||||||
@@ -129,31 +138,37 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
VERSION="${{ steps.version.outputs.VERSION }}"
|
VERSION="${{ steps.version.outputs.VERSION }}"
|
||||||
sed -i '/\[workspace\.package\]/,/\[/{ s/^version = ".*"/version = "'"${VERSION}"'"/ }' Cargo.toml
|
sed -i '/\[workspace\.package\]/,/\[/{ s/^version = ".*"/version = "'"${VERSION}"'"/ }' Cargo.toml
|
||||||
sed -i "s/^Version:.*/Version: ${VERSION}/" neuron.spec
|
sed -i "s/^Version:.*/Version: ${VERSION}/" helexa-neuron.spec
|
||||||
|
|
||||||
|
- name: Generate changelog entry
|
||||||
|
uses: https://git.lair.cafe/actions/rpm-changelog@v1
|
||||||
|
with:
|
||||||
|
spec: helexa-neuron.spec
|
||||||
|
version: ${{ steps.version.outputs.VERSION }}
|
||||||
|
|
||||||
- name: Generate source tarball
|
- name: Generate source tarball
|
||||||
run: |
|
run: |
|
||||||
set -ex
|
set -ex
|
||||||
VERSION="${{ steps.version.outputs.VERSION }}"
|
VERSION="${{ steps.version.outputs.VERSION }}"
|
||||||
tar czf /tmp/neuron-${VERSION}.tar.gz \
|
tar czf /tmp/helexa-neuron-${VERSION}.tar.gz \
|
||||||
--transform "s,^\.,neuron-${VERSION}," \
|
--transform "s,^\.,helexa-neuron-${VERSION}," \
|
||||||
--exclude='./target' \
|
--exclude='./target' \
|
||||||
--exclude='./.git' \
|
--exclude='./.git' \
|
||||||
--exclude='*.tar.gz' \
|
--exclude='*.tar.gz' \
|
||||||
--exclude='*.src.rpm' \
|
--exclude='*.src.rpm' \
|
||||||
.
|
.
|
||||||
mv /tmp/neuron-${VERSION}.tar.gz .
|
mv /tmp/helexa-neuron-${VERSION}.tar.gz .
|
||||||
|
|
||||||
- name: Vendor Rust dependencies
|
- name: Vendor Rust dependencies
|
||||||
run: |
|
run: |
|
||||||
VERSION="${{ steps.version.outputs.VERSION }}"
|
VERSION="${{ steps.version.outputs.VERSION }}"
|
||||||
cargo vendor vendor/
|
cargo vendor vendor/
|
||||||
tar czf neuron-${VERSION}-vendor.tar.gz vendor/
|
tar czf helexa-neuron-${VERSION}-vendor.tar.gz vendor/
|
||||||
rm -rf vendor/
|
rm -rf vendor/
|
||||||
|
|
||||||
- name: Build SRPM
|
- name: Build SRPM
|
||||||
run: |
|
run: |
|
||||||
rpmbuild -bs neuron.spec \
|
rpmbuild -bs helexa-neuron.spec \
|
||||||
--define "_sourcedir $(pwd)" \
|
--define "_sourcedir $(pwd)" \
|
||||||
--define "_srcrpmdir $(pwd)"
|
--define "_srcrpmdir $(pwd)"
|
||||||
|
|
||||||
@@ -165,7 +180,7 @@ jobs:
|
|||||||
|
|
||||||
copr-cortex:
|
copr-cortex:
|
||||||
name: Publish cortex to COPR
|
name: Publish cortex to COPR
|
||||||
runs-on: fedora
|
runs-on: fedora-43
|
||||||
needs: srpm-cortex
|
needs: srpm-cortex
|
||||||
steps:
|
steps:
|
||||||
- name: Download SRPM
|
- name: Download SRPM
|
||||||
@@ -176,13 +191,13 @@ jobs:
|
|||||||
- name: Publish to COPR
|
- name: Publish to COPR
|
||||||
uses: https://git.lair.cafe/actions/copr-publish@v1
|
uses: https://git.lair.cafe/actions/copr-publish@v1
|
||||||
with:
|
with:
|
||||||
project: helexa/cortex
|
project: helexa/helexa
|
||||||
srpm: "*.src.rpm"
|
srpm: "*.src.rpm"
|
||||||
copr-config: ${{ secrets.COPR_CONFIG }}
|
copr-config: ${{ secrets.COPR_CONFIG }}
|
||||||
|
|
||||||
copr-neuron:
|
copr-neuron:
|
||||||
name: Publish neuron to COPR
|
name: Publish neuron to COPR
|
||||||
runs-on: fedora
|
runs-on: fedora-43
|
||||||
needs: srpm-neuron
|
needs: srpm-neuron
|
||||||
steps:
|
steps:
|
||||||
- name: Download SRPM
|
- name: Download SRPM
|
||||||
@@ -193,31 +208,53 @@ jobs:
|
|||||||
- name: Publish to COPR
|
- name: Publish to COPR
|
||||||
uses: https://git.lair.cafe/actions/copr-publish@v1
|
uses: https://git.lair.cafe/actions/copr-publish@v1
|
||||||
with:
|
with:
|
||||||
project: helexa/neuron
|
project: helexa/helexa
|
||||||
srpm: "*.src.rpm"
|
srpm: "*.src.rpm"
|
||||||
copr-config: ${{ secrets.COPR_CONFIG }}
|
copr-config: ${{ secrets.COPR_CONFIG }}
|
||||||
|
|
||||||
bump-version:
|
bump-version:
|
||||||
name: Bump version in source
|
name: Bump version in source
|
||||||
runs-on: fedora
|
runs-on: rust
|
||||||
needs: [copr-cortex, copr-neuron]
|
needs: [copr-cortex, copr-neuron]
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
- name: Stamp version and push
|
- name: Determine version
|
||||||
|
id: version
|
||||||
|
run: echo "VERSION=${GITHUB_REF#refs/tags/v}" >> "$GITHUB_OUTPUT"
|
||||||
|
|
||||||
|
- name: Stamp version
|
||||||
|
run: |
|
||||||
|
VERSION="${{ steps.version.outputs.VERSION }}"
|
||||||
|
sed -i '/\[workspace\.package\]/,/\[/{ s/^version = ".*"/version = "'"${VERSION}"'"/ }' Cargo.toml
|
||||||
|
sed -i "s/^Version:.*/Version: ${VERSION}/" cortex.spec
|
||||||
|
sed -i "s/^Version:.*/Version: ${VERSION}/" helexa-neuron.spec
|
||||||
|
cargo check --workspace 2>/dev/null || true
|
||||||
|
|
||||||
|
- name: Generate cortex changelog entry
|
||||||
|
uses: https://git.lair.cafe/actions/rpm-changelog@v1
|
||||||
|
with:
|
||||||
|
spec: cortex.spec
|
||||||
|
version: ${{ steps.version.outputs.VERSION }}
|
||||||
|
|
||||||
|
- name: Generate helexa-neuron changelog entry
|
||||||
|
uses: https://git.lair.cafe/actions/rpm-changelog@v1
|
||||||
|
with:
|
||||||
|
spec: helexa-neuron.spec
|
||||||
|
version: ${{ steps.version.outputs.VERSION }}
|
||||||
|
|
||||||
|
- name: Commit and push
|
||||||
env:
|
env:
|
||||||
GITEA_TOKEN: ${{ secrets.GITEA_TOKEN }}
|
GITEA_TOKEN: ${{ secrets.GITEA_TOKEN }}
|
||||||
run: |
|
run: |
|
||||||
VERSION="${GITHUB_REF#refs/tags/v}"
|
VERSION="${{ steps.version.outputs.VERSION }}"
|
||||||
sed -i '/\[workspace\.package\]/,/\[/{ s/^version = ".*"/version = "'"${VERSION}"'"/ }' Cargo.toml
|
|
||||||
sed -i "s/^Version:.*/Version: ${VERSION}/" cortex.spec
|
|
||||||
sed -i "s/^Version:.*/Version: ${VERSION}/" neuron.spec
|
|
||||||
cargo check --workspace 2>/dev/null || true
|
|
||||||
git config user.name "Gitea Actions"
|
git config user.name "Gitea Actions"
|
||||||
git config user.email "actions@git.lair.cafe"
|
git config user.email "actions@git.lair.cafe"
|
||||||
git add Cargo.toml Cargo.lock cortex.spec neuron.spec
|
git add Cargo.toml Cargo.lock cortex.spec helexa-neuron.spec
|
||||||
if git diff --cached --quiet; then
|
if git diff --cached --quiet; then
|
||||||
echo "Version already at ${VERSION}"
|
echo "Nothing to commit for ${VERSION}"
|
||||||
else
|
else
|
||||||
git commit -m "chore: bump version to ${VERSION}"
|
git commit -m "chore: bump version to ${VERSION}"
|
||||||
git remote set-url origin "https://gitea-actions:${GITEA_TOKEN}@git.lair.cafe/helexa/cortex.git"
|
git remote set-url origin "https://gitea-actions:${GITEA_TOKEN}@git.lair.cafe/helexa/cortex.git"
|
||||||
|
|||||||
116
CLAUDE.md
116
CLAUDE.md
@@ -125,7 +125,8 @@ automatically. Clippy warnings must be resolved, not suppressed with
|
|||||||
- One or more GPU nodes running mistral.rs on port 8080
|
- One or more GPU nodes running mistral.rs on port 8080
|
||||||
- Optionally a metrics-only node (no GPU) for Prometheus/Grafana
|
- Optionally a metrics-only node (no GPU) for Prometheus/Grafana
|
||||||
- Each node runs `mistralrs serve` on port 8080
|
- Each node runs `mistralrs serve` on port 8080
|
||||||
- Gateway listens on port 8000 (API) and 9100 (metrics)
|
- Gateway listens on port 31313 (API) and 31314 (metrics)
|
||||||
|
- neuron listens on port 13131 on each GPU host
|
||||||
- TLS terminated at gateway or via nginx; internal traffic is plaintext over WireGuard
|
- TLS terminated at gateway or via nginx; internal traffic is plaintext over WireGuard
|
||||||
|
|
||||||
## Conventions
|
## Conventions
|
||||||
@@ -380,7 +381,7 @@ processes (one process per loaded model, each on its own port).
|
|||||||
|
|
||||||
## neuron API
|
## neuron API
|
||||||
|
|
||||||
neuron exposes an HTTP API on port 9090 that cortex polls and calls.
|
neuron exposes an HTTP API on port 13131 that cortex polls and calls.
|
||||||
|
|
||||||
```
|
```
|
||||||
GET /discovery
|
GET /discovery
|
||||||
@@ -424,8 +425,8 @@ endpoint. cortex.toml shrinks to:
|
|||||||
|
|
||||||
```toml
|
```toml
|
||||||
[gateway]
|
[gateway]
|
||||||
listen = "0.0.0.0:8000"
|
listen = "0.0.0.0:31313"
|
||||||
metrics_listen = "0.0.0.0:9100"
|
metrics_listen = "0.0.0.0:31314"
|
||||||
|
|
||||||
[eviction]
|
[eviction]
|
||||||
strategy = "lru"
|
strategy = "lru"
|
||||||
@@ -433,15 +434,15 @@ defrag_after_cycles = 50
|
|||||||
|
|
||||||
[[neurons]]
|
[[neurons]]
|
||||||
name = "beast"
|
name = "beast"
|
||||||
endpoint = "http://beast.hanzalova.internal:9090"
|
endpoint = "http://beast.hanzalova.internal:13131"
|
||||||
|
|
||||||
[[neurons]]
|
[[neurons]]
|
||||||
name = "benjy"
|
name = "benjy"
|
||||||
endpoint = "http://benjy.kosherinata.internal:9090"
|
endpoint = "http://benjy.hanzalova.internal:13131"
|
||||||
|
|
||||||
[[neurons]]
|
[[neurons]]
|
||||||
name = "quadbrat"
|
name = "quadbrat"
|
||||||
endpoint = "http://quadbrat.hanzalova.internal:9090"
|
endpoint = "http://quadbrat.hanzalova.internal:13131"
|
||||||
```
|
```
|
||||||
|
|
||||||
On startup and periodically, cortex calls `GET /discovery` and
|
On startup and periodically, cortex calls `GET /discovery` and
|
||||||
@@ -521,7 +522,7 @@ cortex/
|
|||||||
│ │ └── metrics.rs # prometheus exporter (unchanged)
|
│ │ └── metrics.rs # prometheus exporter (unchanged)
|
||||||
│ ├── neuron/ # node plane (replaces cortex-agent)
|
│ ├── neuron/ # node plane (replaces cortex-agent)
|
||||||
│ │ └── src/
|
│ │ └── src/
|
||||||
│ │ ├── main.rs # binary entrypoint, axum server on :9090
|
│ │ ├── main.rs # binary entrypoint, axum server on :13131
|
||||||
│ │ ├── discovery.rs # nvidia-smi, device enumeration
|
│ │ ├── discovery.rs # nvidia-smi, device enumeration
|
||||||
│ │ ├── health.rs # runtime GPU polling
|
│ │ ├── health.rs # runtime GPU polling
|
||||||
│ │ ├── api.rs # HTTP handlers for /discovery, /models, etc.
|
│ │ ├── api.rs # HTTP handlers for /discovery, /models, etc.
|
||||||
@@ -595,70 +596,65 @@ placement matching can be added incrementally.
|
|||||||
Completed. Both packages have RPM specs, systemd units, and example configs.
|
Completed. Both packages have RPM specs, systemd units, and example configs.
|
||||||
CI builds parallel SRPMs on tag push and publishes to separate COPR repos.
|
CI builds parallel SRPMs on tag push and publishes to separate COPR repos.
|
||||||
|
|
||||||
- `cortex.spec` → `helexa/cortex` COPR: binary, systemd unit, config files
|
- `cortex.spec` — installs the `cortex` binary. Package name keeps the
|
||||||
- `neuron.spec` → `helexa/neuron` COPR: binary, systemd unit, config
|
short `cortex` because no Fedora package collides with it.
|
||||||
|
- `helexa-neuron.spec` — installs the `neuron` binary under package name
|
||||||
|
`helexa-neuron`. Renamed from bare `neuron` to avoid collision with
|
||||||
|
Fedora's NEURON neural-simulation package
|
||||||
|
(https://src.fedoraproject.org/rpms/neuron); binary, systemd unit,
|
||||||
|
system user, and config dir all stay named `neuron` since those are
|
||||||
|
project-local contexts.
|
||||||
- `data/cortex.service`, `data/neuron.service` — systemd units
|
- `data/cortex.service`, `data/neuron.service` — systemd units
|
||||||
- `cortex.example.toml`, `neuron.example.toml`, `models.example.toml`
|
- `cortex.example.toml`, `neuron.example.toml`, `models.example.toml`
|
||||||
- CI: parallel `srpm-cortex` + `srpm-neuron` jobs, then parallel COPR publish
|
- CI: parallel `srpm-cortex` + `srpm-neuron` jobs, then parallel COPR
|
||||||
|
publish to a single project `helexa/helexa` hosting both packages.
|
||||||
|
|
||||||
Install:
|
Install:
|
||||||
```sh
|
```sh
|
||||||
dnf copr enable helexa/cortex && dnf install cortex # gateway host
|
dnf copr enable helexa/helexa
|
||||||
dnf copr enable helexa/neuron && dnf install neuron # GPU nodes
|
dnf install cortex # gateway host
|
||||||
|
dnf install helexa-neuron # GPU nodes
|
||||||
```
|
```
|
||||||
|
|
||||||
### Phase 11: llama.cpp harness stub
|
## 2026-05-18 addendum: candle-native pivot
|
||||||
|
|
||||||
**Goal:** Prove the harness abstraction works with a second engine.
|
Phases 11 (llama.cpp harness) and 12 (mistral.rs COPR) below are
|
||||||
|
**superseded**. The project no longer treats mistral.rs or llama.cpp as
|
||||||
|
dependencies — both are conceptually out of scope. neuron becomes a
|
||||||
|
candle-native inference daemon, with `Harness` retained as an
|
||||||
|
internal seam for adding future engines (vision/audio/diffusion) but
|
||||||
|
its only implementation being in-process candle.
|
||||||
|
|
||||||
**Steps:**
|
The full staged plan for this pivot lives at
|
||||||
1. `crates/neuron/src/harness/llamacpp.rs` — implement the `Harness`
|
`~/.claude/plans/create-a-more-aggressive-calm-naur.md`. Summary:
|
||||||
trait for llama.cpp's `llama-server`.
|
|
||||||
- `start()` — launch `llama-server` with the correct model path,
|
|
||||||
`--port`, `--n-gpu-layers`, `--tensor-split` args. Track the
|
|
||||||
child process.
|
|
||||||
- `stop()` — send SIGTERM to the child process.
|
|
||||||
- `list_models()` — llama-server serves one model per process, so
|
|
||||||
return a single-element list.
|
|
||||||
- `load_model()` — start a new llama-server process for this model.
|
|
||||||
- `unload_model()` — stop the process.
|
|
||||||
- `inference_endpoint()` — return `http://localhost:{assigned_port}`.
|
|
||||||
2. Port allocation: neuron assigns ports from a range (e.g. 8100-8199)
|
|
||||||
to llama-server instances.
|
|
||||||
3. Register in `HarnessRegistry` when configured:
|
|
||||||
```toml
|
|
||||||
[[harnesses]]
|
|
||||||
name = "llamacpp"
|
|
||||||
binary = "/usr/local/bin/llama-server"
|
|
||||||
port_range = [8100, 8199]
|
|
||||||
```
|
|
||||||
4. Tests: mock llama-server (simple HTTP server returning canned
|
|
||||||
responses), test load/unload/endpoint lifecycle.
|
|
||||||
|
|
||||||
**Done when:** A model with `harness = "llamacpp"` in `models.toml` can
|
- **Stage 1 (this commit):** delete `mistralrs.rs` and `llamacpp.rs`,
|
||||||
be loaded and served through cortex. Tests pass with mock llama-server.
|
scaffold inert `CandleHarness`, drop `endpoint`/`systemd_unit` from
|
||||||
|
`HarnessConfig`, default no-op `start`/`stop` on the `Harness` trait.
|
||||||
|
- **Stages 2–4:** wire up candle model load/unload (quantized Qwen3
|
||||||
|
first), add OpenAI-compatible inference endpoint in neuron, then SSE
|
||||||
|
streaming.
|
||||||
|
- **Stages 5–6:** load-on-activation (default models in config) and
|
||||||
|
unload-on-deactivation (graceful shutdown).
|
||||||
|
- **Stages 7–8:** multi-GPU tensor parallelism and broader model/quant
|
||||||
|
coverage.
|
||||||
|
|
||||||
### Phase 12 (lower priority): mistral.rs COPR packaging
|
Sections of this document that describe mistral.rs HTTP behaviour
|
||||||
|
("mistral.rs API gotchas") are retained as historical context for
|
||||||
|
Phases 1–10 — they document what was true while the project depended
|
||||||
|
on mistral.rs. They do not describe current behaviour.
|
||||||
|
|
||||||
**Goal:** Fedora RPMs for mistral.rs built against specific CUDA versions.
|
---
|
||||||
|
|
||||||
**Steps:**
|
### Phase 11 (superseded): llama.cpp harness stub
|
||||||
1. `mistralrs-cuda.spec` — RPM spec that clones a pinned mistral.rs git
|
|
||||||
tag, builds with `--features cuda`, links against the system CUDA
|
|
||||||
toolkit. Produces `mistralrs-cuda13-server` (CUDA 13.x / sm_120) and
|
|
||||||
`mistralrs-cuda12-server` (CUDA 12.x / sm_89). Install binary to
|
|
||||||
`/usr/local/bin/mistralrs`.
|
|
||||||
2. COPR build config: enable the NVIDIA CUDA repo as a build dependency.
|
|
||||||
Pin the CUDA toolkit version in `BuildRequires`.
|
|
||||||
3. Gitea Actions or manual workflow: bump the mistral.rs tag in the spec,
|
|
||||||
trigger COPR rebuild.
|
|
||||||
4. neuron's mistralrs harness config references which binary/package
|
|
||||||
provides the mistral.rs binary. neuron could warn at startup if the
|
|
||||||
installed mistral.rs CUDA version doesn't match the discovered driver.
|
|
||||||
|
|
||||||
**Done when:** `dnf install mistralrs-cuda13-server` on beast provides a
|
~~Originally planned as a second engine to prove the harness
|
||||||
working `mistralrs` binary built for Blackwell GPUs. `dnf install
|
abstraction.~~ Replaced by the candle harness work in the 2026-05-18
|
||||||
mistralrs-cuda12-server` on benjy provides one built for Ada GPUs.
|
addendum above. llama.cpp's any-model/any-hardware breadth is no
|
||||||
|
longer in scope for helexa.
|
||||||
|
|
||||||
This is a separate repo/spec — not part of the cortex workspace — but
|
### Phase 12 (superseded): mistral.rs COPR packaging
|
||||||
tightly coupled operationally. Track it as a sibling project.
|
|
||||||
|
~~Originally planned to ship CUDA-versioned mistral.rs RPMs.~~ Replaced
|
||||||
|
by the candle harness work in the 2026-05-18 addendum above. With
|
||||||
|
mistral.rs out of the dependency tree, there is nothing to package.
|
||||||
|
|||||||
1612
Cargo.lock
generated
1612
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -8,7 +8,7 @@ members = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[workspace.package]
|
[workspace.package]
|
||||||
version = "0.1.7"
|
version = "0.1.16"
|
||||||
edition = "2024"
|
edition = "2024"
|
||||||
license = "GPL-3.0-or-later"
|
license = "GPL-3.0-or-later"
|
||||||
repository = "https://git.lair.cafe/helexa/cortex"
|
repository = "https://git.lair.cafe/helexa/cortex"
|
||||||
@@ -27,7 +27,7 @@ serde = { version = "1", features = ["derive"] }
|
|||||||
serde_json = "1"
|
serde_json = "1"
|
||||||
toml = "0.8"
|
toml = "0.8"
|
||||||
|
|
||||||
# http client (for proxying to mistralrs backends)
|
# http client (for proxying to neuron backends)
|
||||||
reqwest = { version = "0.12", features = ["json", "stream"] }
|
reqwest = { version = "0.12", features = ["json", "stream"] }
|
||||||
|
|
||||||
# observability
|
# observability
|
||||||
|
|||||||
105
README.md
105
README.md
@@ -1,22 +1,23 @@
|
|||||||
# cortex
|
# cortex
|
||||||
|
|
||||||
A Rust reverse-proxy and fleet management layer for multi-node
|
A Rust reverse-proxy and fleet management layer for multi-node GPU inference
|
||||||
[mistral.rs](https://github.com/EricLBuehler/mistral.rs) inference clusters.
|
clusters. Cortex sits in front of one or more `neuron` daemons (each running
|
||||||
|
candle-based inference on a local GPU host) and presents a unified OpenAI +
|
||||||
|
Anthropic compatible API surface.
|
||||||
|
|
||||||
## Problem
|
## Problem
|
||||||
|
|
||||||
Running local LLMs across multiple GPU nodes (different VRAM tiers, different
|
Running local LLMs across multiple GPU nodes (different VRAM tiers, different
|
||||||
model affinities) requires a unified API surface that:
|
model affinities) requires a unified API surface that:
|
||||||
|
|
||||||
- Presents a **single `/v1/models` catalogue** merging every model across every
|
- Presents a **single `/v1/models` catalogue** merging every model that can be
|
||||||
node.
|
served by any neuron in the fleet.
|
||||||
- **Routes requests** to the correct node based on where a model is loaded (or
|
- **Routes requests** to the correct node based on where a model is loaded
|
||||||
*can* be loaded).
|
(or can be loaded), handling cold-load and eviction transparently.
|
||||||
- Manages **model lifecycle** — unload cold models, reload on demand, pin
|
- Manages **model lifecycle** — load on demand, unload cold models, pin
|
||||||
critical ones — using the mistral.rs
|
critical ones — by calling each neuron's `/models/{load,unload}` API.
|
||||||
`/v1/models/{unload,reload,status}` HTTP API (PR #1828+).
|
|
||||||
- Translates between **OpenAI and Anthropic** request/response envelopes so
|
- Translates between **OpenAI and Anthropic** request/response envelopes so
|
||||||
every client in the homelab speaks whichever dialect it prefers.
|
every client speaks whichever dialect it prefers.
|
||||||
- Captures **per-request metrics** (tokens, tok/s, TTFT, latency) and exposes
|
- Captures **per-request metrics** (tokens, tok/s, TTFT, latency) and exposes
|
||||||
them as Prometheus counters/histograms.
|
them as Prometheus counters/histograms.
|
||||||
|
|
||||||
@@ -30,18 +31,17 @@ model affinities) requires a unified API surface that:
|
|||||||
└────────────────┴──────┬───────┴───────────────┘
|
└────────────────┴──────┬───────┴───────────────┘
|
||||||
│
|
│
|
||||||
┌──────────▼──────────┐
|
┌──────────▼──────────┐
|
||||||
│ cortex │
|
│ cortex │
|
||||||
│ (cortex-gateway) │
|
│ (cortex-gateway) │
|
||||||
│ │
|
│ │
|
||||||
│ Router · Metrics │
|
│ Router · Metrics │
|
||||||
│ Evictor · Translate│
|
│ Evictor · Translate│
|
||||||
└──┬──────┬────────┬──┘
|
└──┬──────┬────────┬──┘
|
||||||
│ │ │
|
│ │ │
|
||||||
┌──────────▼┐ ┌──▼─────┐ ┌▼──────────┐
|
┌──────────▼┐ ┌──▼─────┐ ┌▼──────────┐
|
||||||
│ gpu-large │ │gpu-med │ │ gpu-small │
|
│ neuron │ │ neuron │ │ neuron │
|
||||||
│ mistralrs │ │mistral │ │ mistralrs │
|
│ :13131 │ │ :13131 │ │ :13131 │
|
||||||
│ serve │ │rs serve│ │ serve │
|
│ candle │ │ candle │ │ candle │
|
||||||
│ :8080 │ │ :8080 │ │ :8080 │
|
|
||||||
└───────────┘ └────────┘ └───────────┘
|
└───────────┘ └────────┘ └───────────┘
|
||||||
private network (.internal)
|
private network (.internal)
|
||||||
```
|
```
|
||||||
@@ -50,70 +50,48 @@ model affinities) requires a unified API surface that:
|
|||||||
|
|
||||||
| Crate | Purpose |
|
| Crate | Purpose |
|
||||||
|---|---|
|
|---|---|
|
||||||
| `cortex-core` | Shared types: config, node/model state, metrics, OpenAI/Anthropic request/response envelopes |
|
| `cortex-core` | Shared types: config, node/model state, metrics, OpenAI/Anthropic envelopes, harness trait, discovery types |
|
||||||
| `cortex-gateway` | Axum HTTP server: proxy, router, evictor, metrics exporter |
|
| `cortex-gateway` | Axum HTTP server: proxy, router, evictor, poller, metrics exporter |
|
||||||
| `cortex-agent` | Per-node sidecar: polls local mistralrs, reports to gateway, handles restart/defrag |
|
| `neuron` | Per-node daemon: GPU discovery, in-process candle inference, model lifecycle API |
|
||||||
| `cortex-cli` | CLI entrypoint (`cortex serve`, `cortex status`, etc.) |
|
| `cortex-cli` | CLI entrypoint (`cortex serve`, `cortex status`, etc.) |
|
||||||
|
|
||||||
## Node setup
|
## Node setup
|
||||||
|
|
||||||
Each GPU node runs `mistralrs serve` with a multi-model config. Models are
|
Each GPU node runs `neuron` (listening on `:13131`). Neuron uses
|
||||||
declared but start **unloaded** — mistral.rs lazy-loads on first request and
|
huggingface/candle for in-process inference — there is no external
|
||||||
the gateway can explicitly unload/reload via the HTTP API.
|
inference subprocess to manage.
|
||||||
|
|
||||||
Example node systemd unit:
|
The neuron RPM (`helexa-neuron`) ships a systemd unit:
|
||||||
|
|
||||||
```ini
|
```sh
|
||||||
# /etc/systemd/system/mistralrs.service
|
dnf copr enable helexa/helexa
|
||||||
[Unit]
|
dnf install helexa-neuron
|
||||||
Description=mistral.rs inference server
|
systemctl enable --now neuron
|
||||||
After=network-online.target
|
|
||||||
Wants=network-online.target
|
|
||||||
|
|
||||||
[Service]
|
|
||||||
Type=simple
|
|
||||||
ExecStart=/usr/local/bin/mistralrs serve \
|
|
||||||
--from-config /etc/mistralrs/config.toml \
|
|
||||||
--port 8080
|
|
||||||
Restart=on-failure
|
|
||||||
RestartSec=5
|
|
||||||
Environment=CUDA_VISIBLE_DEVICES=0,1
|
|
||||||
|
|
||||||
[Install]
|
|
||||||
WantedBy=multi-user.target
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## Gateway config
|
## Gateway config
|
||||||
|
|
||||||
```toml
|
```toml
|
||||||
# cortex.toml
|
# /etc/cortex/cortex.toml
|
||||||
[gateway]
|
[gateway]
|
||||||
listen = "0.0.0.0:8000"
|
listen = "0.0.0.0:31313"
|
||||||
metrics_listen = "0.0.0.0:9100"
|
metrics_listen = "0.0.0.0:31314"
|
||||||
|
|
||||||
[eviction]
|
[eviction]
|
||||||
strategy = "lru" # lru | priority
|
strategy = "lru" # lru | priority
|
||||||
defrag_after_cycles = 50
|
defrag_after_cycles = 50
|
||||||
|
|
||||||
[[nodes]]
|
[[neurons]]
|
||||||
name = "gpu-large"
|
name = "beast"
|
||||||
endpoint = "http://gpu-large.internal:8080"
|
endpoint = "http://beast.internal:13131"
|
||||||
vram_mb = 49_152 # e.g. 2x RTX 4090
|
|
||||||
pinned = ["your-org/large-model"]
|
|
||||||
|
|
||||||
[[nodes]]
|
[[neurons]]
|
||||||
name = "gpu-medium"
|
name = "benjy"
|
||||||
endpoint = "http://gpu-medium.internal:8080"
|
endpoint = "http://benjy.internal:13131"
|
||||||
vram_mb = 24_576 # e.g. RTX 4090
|
|
||||||
pinned = ["your-org/medium-model"]
|
|
||||||
|
|
||||||
[[nodes]]
|
|
||||||
name = "gpu-small"
|
|
||||||
endpoint = "http://gpu-small.internal:8080"
|
|
||||||
vram_mb = 12_288 # e.g. RTX 3060
|
|
||||||
pinned = ["your-org/embedding-model"]
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Model placement profiles live in `models.toml` — see `models.example.toml`.
|
||||||
|
|
||||||
## Building
|
## Building
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
@@ -131,19 +109,20 @@ cargo clippy --workspace -- -D warnings # warnings are errors
|
|||||||
cargo test --workspace # all tests must pass
|
cargo test --workspace # all tests must pass
|
||||||
```
|
```
|
||||||
|
|
||||||
Tagged releases (`v*`) additionally build an SRPM and publish to COPR.
|
Tagged releases (`v*`) additionally build SRPMs for both `cortex` and
|
||||||
|
`helexa-neuron` and publish to COPR.
|
||||||
|
|
||||||
## Running
|
## Running
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
# start the gateway
|
# start the gateway
|
||||||
cortex serve --config cortex.toml
|
cortex serve --config /etc/cortex/cortex.toml
|
||||||
|
|
||||||
# check fleet status
|
# check fleet status
|
||||||
cortex status
|
cortex status
|
||||||
|
|
||||||
# list all models across nodes
|
# list all models across nodes
|
||||||
curl http://localhost:8000/v1/models
|
curl http://localhost:31313/v1/models
|
||||||
```
|
```
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|||||||
30
asset/manifest.yml
Normal file
30
asset/manifest.yml
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
# Helexa fleet manifest.
|
||||||
|
#
|
||||||
|
# Drives rolling deploys via script/deploy.sh and serves as the source
|
||||||
|
# of truth for which hosts run cortex vs neuron, and which CUDA
|
||||||
|
# compute-capability flavour each neuron host needs.
|
||||||
|
#
|
||||||
|
# Flavour ↔ NVIDIA generation ↔ compute cap:
|
||||||
|
# ampere sm_86 (RTX 30 series — e.g. 3060)
|
||||||
|
# ada sm_89 (RTX 40 series — e.g. 4090)
|
||||||
|
# blackwell sm_120 (RTX 50 series — e.g. 5090)
|
||||||
|
#
|
||||||
|
# The flavour determines which RPM is installed on a given neuron host:
|
||||||
|
# helexa-neuron-<flavour>. Only one flavour may be installed at a time
|
||||||
|
# (the packages Conflict: with each other).
|
||||||
|
|
||||||
|
cortex:
|
||||||
|
host: hanzalova.internal
|
||||||
|
|
||||||
|
neurons:
|
||||||
|
- host: beast.hanzalova.internal
|
||||||
|
flavour: blackwell
|
||||||
|
gpu: "2x RTX 5090"
|
||||||
|
|
||||||
|
- host: benjy.hanzalova.internal
|
||||||
|
flavour: ada
|
||||||
|
gpu: "RTX 4090"
|
||||||
|
|
||||||
|
- host: quadbrat.hanzalova.internal
|
||||||
|
flavour: ampere
|
||||||
|
gpu: "RTX 3060"
|
||||||
@@ -3,22 +3,22 @@
|
|||||||
# Copy to cortex.toml and adjust for your environment.
|
# Copy to cortex.toml and adjust for your environment.
|
||||||
#
|
#
|
||||||
# Environment variable overrides use CORTEX_ prefix with __ separators:
|
# Environment variable overrides use CORTEX_ prefix with __ separators:
|
||||||
# CORTEX_GATEWAY__LISTEN=0.0.0.0:9000
|
# CORTEX_GATEWAY__LISTEN=0.0.0.0:31313
|
||||||
|
|
||||||
[gateway]
|
[gateway]
|
||||||
listen = "0.0.0.0:8000"
|
listen = "0.0.0.0:31313"
|
||||||
metrics_listen = "0.0.0.0:9100"
|
metrics_listen = "0.0.0.0:31314"
|
||||||
|
|
||||||
[eviction]
|
[eviction]
|
||||||
strategy = "lru"
|
strategy = "lru"
|
||||||
# Restart mistralrs after this many load/unload cycles to defragment VRAM.
|
# Restart neurons after this many load/unload cycles to defragment VRAM.
|
||||||
# Set to 0 to disable.
|
# Set to 0 to disable.
|
||||||
defrag_after_cycles = 50
|
defrag_after_cycles = 50
|
||||||
|
|
||||||
# -- Nodes ---------------------------------------------------------------
|
# -- Nodes ---------------------------------------------------------------
|
||||||
# Each [[nodes]] entry declares a mistral.rs instance in the fleet.
|
# Each [[nodes]] entry declares a neuron daemon in the fleet.
|
||||||
# Models are discovered by polling the node's /v1/models endpoint.
|
# Models are discovered by polling the neuron's /models endpoint.
|
||||||
# Pinned models are never evicted.
|
# Pinned models (see models.toml) are never evicted.
|
||||||
|
|
||||||
[[nodes]]
|
[[nodes]]
|
||||||
name = "gpu-large"
|
name = "gpu-large"
|
||||||
|
|||||||
40
cortex.spec
40
cortex.spec
@@ -1,5 +1,5 @@
|
|||||||
Name: cortex
|
Name: cortex
|
||||||
Version: 0.1.7
|
Version: 0.1.16
|
||||||
Release: 1%{?dist}
|
Release: 1%{?dist}
|
||||||
Summary: Inference gateway for multi-node GPU clusters
|
Summary: Inference gateway for multi-node GPU clusters
|
||||||
|
|
||||||
@@ -21,12 +21,16 @@ BuildRequires: systemd-rpm-macros
|
|||||||
|
|
||||||
Requires(pre): shadow-utils
|
Requires(pre): shadow-utils
|
||||||
Requires: systemd
|
Requires: systemd
|
||||||
|
Requires: firewalld-filesystem
|
||||||
|
|
||||||
# rpm's sysusers provides-generator only emits versioned user(cortex) when
|
# systemd-rpm-macros ships a unit dep generator that parses User=/Group=
|
||||||
# the u-line has GECOS/home/shell fields. %attr(,,cortex) in %files emits
|
# from our .service file and emits Requires: user(cortex)/group(cortex).
|
||||||
# an unversioned Requires: user(cortex), so we provide it explicitly.
|
# rpm's sysusers provides-generator emits the unversioned form for groups
|
||||||
|
# but only a versioned user(cortex) = <base64> for users with GECOS/home/
|
||||||
|
# shell. Provide the unversioned user(cortex) explicitly so dnf can resolve
|
||||||
|
# the auto-generated Requires. Without this, dnf5 silently filters the
|
||||||
|
# package and reports "Nothing to do".
|
||||||
Provides: user(cortex)
|
Provides: user(cortex)
|
||||||
Provides: group(cortex)
|
|
||||||
|
|
||||||
%description
|
%description
|
||||||
Cortex is a Rust reverse-proxy that sits in front of multiple inference
|
Cortex is a Rust reverse-proxy that sits in front of multiple inference
|
||||||
@@ -53,9 +57,10 @@ cargo build --release -p cortex-cli
|
|||||||
install -Dm755 target/release/cortex %{buildroot}%{_bindir}/cortex
|
install -Dm755 target/release/cortex %{buildroot}%{_bindir}/cortex
|
||||||
install -Dm644 data/cortex.service %{buildroot}%{_unitdir}/cortex.service
|
install -Dm644 data/cortex.service %{buildroot}%{_unitdir}/cortex.service
|
||||||
install -Dm644 data/cortex-sysusers.conf %{buildroot}%{_sysusersdir}/cortex.conf
|
install -Dm644 data/cortex-sysusers.conf %{buildroot}%{_sysusersdir}/cortex.conf
|
||||||
install -dm750 %{buildroot}%{_sysconfdir}/cortex
|
install -Dm644 data/cortex-firewalld.xml %{buildroot}%{_prefix}/lib/firewalld/services/cortex.xml
|
||||||
install -Dm640 cortex.example.toml %{buildroot}%{_sysconfdir}/cortex/cortex.toml
|
install -dm755 %{buildroot}%{_sysconfdir}/cortex
|
||||||
install -Dm640 models.example.toml %{buildroot}%{_sysconfdir}/cortex/models.toml
|
install -Dm644 cortex.example.toml %{buildroot}%{_sysconfdir}/cortex/cortex.toml
|
||||||
|
install -Dm644 models.example.toml %{buildroot}%{_sysconfdir}/cortex/models.toml
|
||||||
|
|
||||||
%pre
|
%pre
|
||||||
%sysusers_create_compat %{_builddir}/%{name}-%{version}/data/cortex-sysusers.conf
|
%sysusers_create_compat %{_builddir}/%{name}-%{version}/data/cortex-sysusers.conf
|
||||||
@@ -75,10 +80,21 @@ install -Dm640 models.example.toml %{buildroot}%{_sysconfdir}/cortex/models.toml
|
|||||||
%{_bindir}/cortex
|
%{_bindir}/cortex
|
||||||
%{_unitdir}/cortex.service
|
%{_unitdir}/cortex.service
|
||||||
%{_sysusersdir}/cortex.conf
|
%{_sysusersdir}/cortex.conf
|
||||||
%dir %attr(750,root,cortex) %{_sysconfdir}/cortex
|
%{_prefix}/lib/firewalld/services/cortex.xml
|
||||||
%config(noreplace) %attr(640,root,cortex) %{_sysconfdir}/cortex/cortex.toml
|
%dir %{_sysconfdir}/cortex
|
||||||
%config(noreplace) %attr(640,root,cortex) %{_sysconfdir}/cortex/models.toml
|
%config(noreplace) %{_sysconfdir}/cortex/cortex.toml
|
||||||
|
%config(noreplace) %{_sysconfdir}/cortex/models.toml
|
||||||
|
|
||||||
%changelog
|
%changelog
|
||||||
* Tue Apr 15 2026 Rob Thijssen <grenade@rob.tn> - 0.1.0-1
|
* Thu Apr 16 2026 Gitea Actions <actions@git.lair.cafe> - 0.1.16-1
|
||||||
|
- chore: ignore local deploy script
|
||||||
|
- chore: move default ports out of common-collision ranges
|
||||||
|
- ci: drop actions/cache for cargo registry and target
|
||||||
|
|
||||||
|
* Thu Apr 16 2026 Gitea Actions <actions@git.lair.cafe> - 0.1.14-1
|
||||||
|
- ci: publish both packages to a single helexa/helexa COPR project
|
||||||
|
- fix(rpm): rename neuron package to helexa-neuron
|
||||||
|
- ci: commit generated %changelog entries back to main
|
||||||
|
|
||||||
|
* Wed Apr 15 2026 Rob Thijssen <grenade@rob.tn> - 0.1.0-1
|
||||||
- Initial package
|
- Initial package
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ use tracing_subscriber::EnvFilter;
|
|||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
#[command(name = "cortex")]
|
#[command(name = "cortex")]
|
||||||
#[command(about = "Unified inference gateway for multi-node mistral.rs clusters")]
|
#[command(about = "Unified inference gateway for multi-node GPU clusters")]
|
||||||
#[command(version)]
|
#[command(version)]
|
||||||
struct Cli {
|
struct Cli {
|
||||||
#[command(subcommand)]
|
#[command(subcommand)]
|
||||||
@@ -23,7 +23,7 @@ enum Commands {
|
|||||||
/// Print the fleet status (models, nodes, health).
|
/// Print the fleet status (models, nodes, health).
|
||||||
Status {
|
Status {
|
||||||
/// Gateway API endpoint to query.
|
/// Gateway API endpoint to query.
|
||||||
#[arg(short, long, default_value = "http://localhost:8000")]
|
#[arg(short, long, default_value = "http://localhost:31313")]
|
||||||
endpoint: String,
|
endpoint: String,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
//!
|
//!
|
||||||
//! These mirror the `/v1/messages` format used by the Anthropic API.
|
//! These mirror the `/v1/messages` format used by the Anthropic API.
|
||||||
//! The gateway accepts these, translates to OpenAI format, proxies to
|
//! The gateway accepts these, translates to OpenAI format, proxies to
|
||||||
//! mistral.rs, then translates the response back.
|
//! the inference backend (neuron), then translates the response back.
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
|||||||
@@ -22,9 +22,9 @@ fn default_models_path() -> String {
|
|||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct GatewaySettings {
|
pub struct GatewaySettings {
|
||||||
/// Address to listen on for API requests (e.g. "0.0.0.0:8000")
|
/// Address to listen on for API requests (e.g. "0.0.0.0:31313")
|
||||||
pub listen: String,
|
pub listen: String,
|
||||||
/// Address to listen on for Prometheus metrics (e.g. "0.0.0.0:9100")
|
/// Address to listen on for Prometheus metrics (e.g. "0.0.0.0:31314")
|
||||||
pub metrics_listen: String,
|
pub metrics_listen: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -50,7 +50,7 @@ pub enum EvictionStrategy {
|
|||||||
pub struct NeuronEndpoint {
|
pub struct NeuronEndpoint {
|
||||||
/// Human-readable node name (e.g. "beast")
|
/// Human-readable node name (e.g. "beast")
|
||||||
pub name: String,
|
pub name: String,
|
||||||
/// Base URL of the neuron daemon (e.g. "http://beast.internal:9090")
|
/// Base URL of the neuron daemon (e.g. "http://beast.internal:13131")
|
||||||
pub endpoint: String,
|
pub endpoint: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -70,8 +70,8 @@ impl Default for GatewayConfig {
|
|||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
gateway: GatewaySettings {
|
gateway: GatewaySettings {
|
||||||
listen: "0.0.0.0:8000".into(),
|
listen: "0.0.0.0:31313".into(),
|
||||||
metrics_listen: "0.0.0.0:9100".into(),
|
metrics_listen: "0.0.0.0:31314".into(),
|
||||||
},
|
},
|
||||||
eviction: EvictionSettings {
|
eviction: EvictionSettings {
|
||||||
strategy: EvictionStrategy::Lru,
|
strategy: EvictionStrategy::Lru,
|
||||||
|
|||||||
@@ -9,13 +9,13 @@ use async_trait::async_trait;
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// Configuration for a harness instance on a neuron.
|
/// Configuration for a harness instance on a neuron.
|
||||||
|
///
|
||||||
|
/// All current harnesses are in-process (candle); per-harness tuning
|
||||||
|
/// (cache paths, device policies, etc.) lives in dedicated config
|
||||||
|
/// blocks rather than on this struct.
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct HarnessConfig {
|
pub struct HarnessConfig {
|
||||||
pub name: String,
|
pub name: String,
|
||||||
/// Base URL of the harness (e.g. "http://localhost:8080" for mistral.rs).
|
|
||||||
pub endpoint: Option<String>,
|
|
||||||
/// Systemd unit name, if the harness is managed via systemd.
|
|
||||||
pub systemd_unit: Option<String>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Health status of a harness process.
|
/// Health status of a harness process.
|
||||||
@@ -47,16 +47,24 @@ pub struct ModelInfo {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// What an inference harness must do, from neuron's perspective.
|
/// What an inference harness must do, from neuron's perspective.
|
||||||
|
///
|
||||||
|
/// All current harnesses are in-process — they share neuron's address
|
||||||
|
/// space and lifecycle. `start`/`stop` therefore default to no-ops; a
|
||||||
|
/// future process-supervising harness would override them.
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait Harness: Send + Sync {
|
pub trait Harness: Send + Sync {
|
||||||
/// Human-readable name (e.g. "mistralrs", "llamacpp", "comfyui").
|
/// Human-readable name (e.g. "candle").
|
||||||
fn name(&self) -> &str;
|
fn name(&self) -> &str;
|
||||||
|
|
||||||
/// Start the harness process if it is not already running.
|
/// Start the harness. Default no-op for in-process harnesses.
|
||||||
async fn start(&self, config: &HarnessConfig) -> Result<()>;
|
async fn start(&self, _config: &HarnessConfig) -> Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
/// Stop the harness process gracefully.
|
/// Stop the harness. Default no-op for in-process harnesses.
|
||||||
async fn stop(&self) -> Result<()>;
|
async fn stop(&self) -> Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
/// Health check. Returns the harness process status.
|
/// Health check. Returns the harness process status.
|
||||||
async fn health(&self) -> HarnessHealth;
|
async fn health(&self) -> HarnessHealth;
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ use std::collections::HashMap;
|
|||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct NodeState {
|
pub struct NodeState {
|
||||||
pub name: String,
|
pub name: String,
|
||||||
/// Base URL of the neuron daemon (e.g. "http://beast.internal:9090").
|
/// Base URL of the neuron daemon (e.g. "http://beast.internal:13131").
|
||||||
pub endpoint: String,
|
pub endpoint: String,
|
||||||
pub healthy: bool,
|
pub healthy: bool,
|
||||||
pub models: HashMap<String, ModelEntry>,
|
pub models: HashMap<String, ModelEntry>,
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
//! These are a subset sufficient for chat completions (streaming + non-streaming).
|
//! These are a subset sufficient for chat completions (streaming + non-streaming).
|
||||||
//! Fields not relevant to proxying are captured as `serde_json::Value` via
|
//! Fields not relevant to proxying are captured as `serde_json::Value` via
|
||||||
//! `#[serde(flatten)]` so we forward them without needing to enumerate every
|
//! `#[serde(flatten)]` so we forward them without needing to enumerate every
|
||||||
//! extension field mistral.rs supports.
|
//! extension field a backend might support.
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
@@ -22,7 +22,7 @@ pub struct ChatCompletionRequest {
|
|||||||
pub max_tokens: Option<u64>,
|
pub max_tokens: Option<u64>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub stream: Option<bool>,
|
pub stream: Option<bool>,
|
||||||
/// All other fields (tools, response_format, mistral.rs extensions, etc.)
|
/// All other fields (tools, response_format, backend extensions, etc.)
|
||||||
#[serde(flatten)]
|
#[serde(flatten)]
|
||||||
pub extra: Value,
|
pub extra: Value,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//! Streaming HTTP reverse proxy to mistral.rs backends.
|
//! Streaming HTTP reverse proxy to neuron backends.
|
||||||
//!
|
//!
|
||||||
//! For streaming requests, SSE chunks are forwarded as they arrive.
|
//! For streaming requests, SSE chunks are forwarded as they arrive.
|
||||||
//! The proxy captures timing information for metrics but does not
|
//! The proxy captures timing information for metrics but does not
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ use tokio::net::TcpListener;
|
|||||||
/// - GET /models/:id/endpoint (returns the inference URL)
|
/// - GET /models/:id/endpoint (returns the inference URL)
|
||||||
/// - POST /models/unload (accepts unload requests)
|
/// - POST /models/unload (accepts unload requests)
|
||||||
/// - GET /v1/chat/completions + POST /v1/chat/completions (inference)
|
/// - GET /v1/chat/completions + POST /v1/chat/completions (inference)
|
||||||
|
///
|
||||||
/// Returns the neuron base URL.
|
/// Returns the neuron base URL.
|
||||||
pub async fn spawn_mock_neuron() -> String {
|
pub async fn spawn_mock_neuron() -> String {
|
||||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
@@ -54,7 +55,7 @@ pub async fn spawn_mock_neuron() -> String {
|
|||||||
|
|
||||||
async fn mock_neuron_list_models() -> Json<Value> {
|
async fn mock_neuron_list_models() -> Json<Value> {
|
||||||
Json(json!([
|
Json(json!([
|
||||||
{"id": "test-model", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": 8000}
|
{"id": "test-model", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": 8000}
|
||||||
]))
|
]))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ use std::sync::Arc;
|
|||||||
async fn test_poller_discovers_models() {
|
async fn test_poller_discovers_models() {
|
||||||
// Mock neuron reports 2 models via /models endpoint (neuron format).
|
// Mock neuron reports 2 models via /models endpoint (neuron format).
|
||||||
let mock_url = common::spawn_mock_neuron_with_models(json!([
|
let mock_url = common::spawn_mock_neuron_with_models(json!([
|
||||||
{"id": "model-a", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": 8000},
|
{"id": "model-a", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": 8000},
|
||||||
{"id": "model-b", "harness": "mistralrs", "status": "unloaded", "devices": [], "vram_used_mb": null}
|
{"id": "model-b", "harness": "candle", "status": "unloaded", "devices": [], "vram_used_mb": null}
|
||||||
]))
|
]))
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
@@ -63,8 +63,8 @@ async fn test_poller_discovers_models() {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_poller_updates_gateway_models_endpoint() {
|
async fn test_poller_updates_gateway_models_endpoint() {
|
||||||
let mock_url = common::spawn_mock_neuron_with_models(json!([
|
let mock_url = common::spawn_mock_neuron_with_models(json!([
|
||||||
{"id": "model-x", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": null},
|
{"id": "model-x", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": null},
|
||||||
{"id": "model-y", "harness": "mistralrs", "status": "loaded", "devices": [1], "vram_used_mb": null}
|
{"id": "model-y", "harness": "candle", "status": "loaded", "devices": [1], "vram_used_mb": null}
|
||||||
]))
|
]))
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
@@ -152,8 +152,8 @@ async fn test_poller_marks_unreachable_node_unhealthy() {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_poller_removes_stale_models() {
|
async fn test_poller_removes_stale_models() {
|
||||||
let mock_url = common::spawn_mock_neuron_with_models(json!([
|
let mock_url = common::spawn_mock_neuron_with_models(json!([
|
||||||
{"id": "keep-me", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": null},
|
{"id": "keep-me", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": null},
|
||||||
{"id": "drop-me", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": null}
|
{"id": "drop-me", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": null}
|
||||||
]))
|
]))
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
@@ -183,7 +183,7 @@ async fn test_poller_removes_stale_models() {
|
|||||||
|
|
||||||
// New mock with only one model.
|
// New mock with only one model.
|
||||||
let new_mock_url = common::spawn_mock_neuron_with_models(json!([
|
let new_mock_url = common::spawn_mock_neuron_with_models(json!([
|
||||||
{"id": "keep-me", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": null}
|
{"id": "keep-me", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": null}
|
||||||
]))
|
]))
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
|
|||||||
@@ -51,18 +51,18 @@ async fn test_streaming_sse_passthrough() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
assert!(
|
assert!(
|
||||||
chunks.len() >= chunk_count + 1,
|
chunks.len() > chunk_count,
|
||||||
"expected at least {} chunks (got {}): {:?}",
|
"expected more than {} chunks (got {}): {:?}",
|
||||||
chunk_count + 1,
|
chunk_count,
|
||||||
chunks.len(),
|
chunks.len(),
|
||||||
chunks,
|
chunks,
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_eq!(chunks.last().unwrap(), "[DONE]");
|
assert_eq!(chunks.last().unwrap(), "[DONE]");
|
||||||
|
|
||||||
for i in 0..chunk_count {
|
for (i, chunk) in chunks.iter().enumerate().take(chunk_count) {
|
||||||
let chunk_json: serde_json::Value =
|
let chunk_json: serde_json::Value =
|
||||||
serde_json::from_str(&chunks[i]).expect("chunk should be valid JSON");
|
serde_json::from_str(chunk).expect("chunk should be valid JSON");
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
chunk_json["choices"][0]["delta"]["content"],
|
chunk_json["choices"][0]["delta"]["content"],
|
||||||
format!("token{i}")
|
format!("token{i}")
|
||||||
|
|||||||
@@ -12,6 +12,34 @@ path = "src/lib.rs"
|
|||||||
name = "neuron"
|
name = "neuron"
|
||||||
path = "src/main.rs"
|
path = "src/main.rs"
|
||||||
|
|
||||||
|
[features]
|
||||||
|
default = []
|
||||||
|
# Enables CUDA acceleration in candle and the cudarc/nccl bindings the
|
||||||
|
# TP worker pool uses. Without this feature, candle compiles for CPU
|
||||||
|
# only, Device::new_cuda calls fall back to CPU, and TP Init/sanity
|
||||||
|
# requests return Error{kind="cuda_feature_not_enabled"}.
|
||||||
|
cuda = [
|
||||||
|
"candle-core/cuda",
|
||||||
|
"candle-core/nccl",
|
||||||
|
"candle-nn/cuda",
|
||||||
|
"candle-transformers/cuda",
|
||||||
|
"dep:cudarc",
|
||||||
|
]
|
||||||
|
# Use cuDNN for convolution / attention kernels. Requires CUDA.
|
||||||
|
cudnn = [
|
||||||
|
"cuda",
|
||||||
|
"candle-core/cudnn",
|
||||||
|
"candle-nn/cudnn",
|
||||||
|
"candle-transformers/cudnn",
|
||||||
|
]
|
||||||
|
# FlashAttention kernels. Requires CUDA.
|
||||||
|
flash-attn = [
|
||||||
|
"cuda",
|
||||||
|
"candle-transformers/flash-attn",
|
||||||
|
]
|
||||||
|
# Reserved for GPU-only integration tests in later stages.
|
||||||
|
cuda-integration = ["cuda"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
cortex-core.workspace = true
|
cortex-core.workspace = true
|
||||||
tokio.workspace = true
|
tokio.workspace = true
|
||||||
@@ -24,9 +52,25 @@ tracing-subscriber.workspace = true
|
|||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
async-trait.workspace = true
|
async-trait.workspace = true
|
||||||
clap.workspace = true
|
clap.workspace = true
|
||||||
|
thiserror.workspace = true
|
||||||
|
futures.workspace = true
|
||||||
|
tokio-stream.workspace = true
|
||||||
figment.workspace = true
|
figment.workspace = true
|
||||||
toml.workspace = true
|
toml.workspace = true
|
||||||
|
|
||||||
|
# candle for in-process inference. CUDA support is gated behind the
|
||||||
|
# crate's `cuda` feature (default off) so the workspace builds on
|
||||||
|
# non-CUDA hosts and CI runners.
|
||||||
|
candle-core = "0.10.2"
|
||||||
|
candle-nn = "0.10.2"
|
||||||
|
candle-transformers = "0.10.2"
|
||||||
|
# Direct dep on cudarc (matching candle's transitive version) so the
|
||||||
|
# TP worker pool can call cudarc::nccl::{Comm, Id} directly. Gated on
|
||||||
|
# the `cuda` feature; same toolchain requirement as candle's CUDA path.
|
||||||
|
cudarc = { version = "0.19", optional = true, default-features = false, features = ["nccl", "cuda-version-from-build-system"] }
|
||||||
|
tokenizers = { version = "0.22", default-features = false, features = ["onig"] }
|
||||||
|
hf-hub = { version = "0.4", features = ["tokio"] }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tokio = { workspace = true, features = ["test-util"] }
|
tokio = { workspace = true, features = ["test-util"] }
|
||||||
reqwest.workspace = true
|
reqwest.workspace = true
|
||||||
|
|||||||
@@ -1,23 +1,33 @@
|
|||||||
//! HTTP API handlers for the neuron daemon.
|
//! HTTP API handlers for the neuron daemon.
|
||||||
|
|
||||||
use crate::harness::HarnessRegistry;
|
use crate::harness::HarnessRegistry;
|
||||||
|
use crate::harness::candle::{CandleHarness, InferenceError};
|
||||||
use crate::health::HealthCache;
|
use crate::health::HealthCache;
|
||||||
use axum::Router;
|
use axum::Router;
|
||||||
use axum::extract::{Path, State};
|
use axum::extract::{Path, State};
|
||||||
use axum::http::StatusCode;
|
use axum::http::StatusCode;
|
||||||
|
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||||
use axum::response::{IntoResponse, Json};
|
use axum::response::{IntoResponse, Json};
|
||||||
use axum::routing::{get, post};
|
use axum::routing::{get, post};
|
||||||
use cortex_core::discovery::{DiscoveryResponse, HealthResponse};
|
use cortex_core::discovery::{DiscoveryResponse, HealthResponse};
|
||||||
use cortex_core::harness::ModelSpec;
|
use cortex_core::harness::ModelSpec;
|
||||||
|
use cortex_core::openai::ChatCompletionRequest;
|
||||||
|
use futures::stream::{self, StreamExt};
|
||||||
use serde_json::{Value, json};
|
use serde_json::{Value, json};
|
||||||
|
use std::convert::Infallible;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
|
use tokio_stream::wrappers::ReceiverStream;
|
||||||
|
|
||||||
/// Shared state for the neuron HTTP server.
|
/// Shared state for the neuron HTTP server.
|
||||||
pub struct NeuronState {
|
pub struct NeuronState {
|
||||||
pub discovery: DiscoveryResponse,
|
pub discovery: DiscoveryResponse,
|
||||||
pub health_cache: Arc<HealthCache>,
|
pub health_cache: Arc<HealthCache>,
|
||||||
pub registry: RwLock<HarnessRegistry>,
|
pub registry: RwLock<HarnessRegistry>,
|
||||||
|
/// Typed handle to the candle harness for inference routes. Cached at
|
||||||
|
/// startup so `/v1/chat/completions` doesn't have to hold the registry
|
||||||
|
/// read lock or perform dyn-Trait dispatch per request.
|
||||||
|
pub candle: Option<Arc<CandleHarness>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Build the neuron API router.
|
/// Build the neuron API router.
|
||||||
@@ -29,6 +39,7 @@ pub fn neuron_routes() -> Router<Arc<NeuronState>> {
|
|||||||
.route("/models/load", post(load_model))
|
.route("/models/load", post(load_model))
|
||||||
.route("/models/unload", post(unload_model))
|
.route("/models/unload", post(unload_model))
|
||||||
.route("/models/{model_id}/endpoint", get(model_endpoint))
|
.route("/models/{model_id}/endpoint", get(model_endpoint))
|
||||||
|
.route("/v1/chat/completions", post(chat_completions))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn discovery_handler(State(state): State<Arc<NeuronState>>) -> Json<DiscoveryResponse> {
|
async fn discovery_handler(State(state): State<Arc<NeuronState>>) -> Json<DiscoveryResponse> {
|
||||||
@@ -45,7 +56,7 @@ async fn list_models(State(state): State<Arc<NeuronState>>) -> impl IntoResponse
|
|||||||
Ok(models) => Json(json!(models)).into_response(),
|
Ok(models) => Json(json!(models)).into_response(),
|
||||||
Err(e) => (
|
Err(e) => (
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
Json(json!({"error": e.to_string()})),
|
Json(json!({"error": format!("{e:#}")})),
|
||||||
)
|
)
|
||||||
.into_response(),
|
.into_response(),
|
||||||
}
|
}
|
||||||
@@ -58,11 +69,22 @@ async fn load_model(
|
|||||||
let registry = state.registry.read().await;
|
let registry = state.registry.read().await;
|
||||||
match registry.load_model(&spec).await {
|
match registry.load_model(&spec).await {
|
||||||
Ok(()) => Json(json!({"status": "loaded"})).into_response(),
|
Ok(()) => Json(json!({"status": "loaded"})).into_response(),
|
||||||
Err(e) => (
|
Err(e) => {
|
||||||
StatusCode::BAD_REQUEST,
|
// Log the full anyhow chain server-side so journalctl shows
|
||||||
Json(json!({"error": e.to_string()})),
|
// the underlying failure (hf-hub timeout, permission denied,
|
||||||
)
|
// disk full, etc.) without needing to inspect the HTTP
|
||||||
.into_response(),
|
// response body separately.
|
||||||
|
tracing::warn!(
|
||||||
|
model = %spec.model_id,
|
||||||
|
error = %format!("{e:#}"),
|
||||||
|
"load_model failed"
|
||||||
|
);
|
||||||
|
(
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
Json(json!({"error": format!("{e:#}")})),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -84,7 +106,11 @@ async fn unload_model(
|
|||||||
let registry = state.registry.read().await;
|
let registry = state.registry.read().await;
|
||||||
match registry.unload_model(&model_id).await {
|
match registry.unload_model(&model_id).await {
|
||||||
Ok(()) => Json(json!({"status": "unloaded"})).into_response(),
|
Ok(()) => Json(json!({"status": "unloaded"})).into_response(),
|
||||||
Err(e) => (StatusCode::NOT_FOUND, Json(json!({"error": e.to_string()}))).into_response(),
|
Err(e) => (
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
Json(json!({"error": format!("{e:#}")})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -102,3 +128,61 @@ async fn model_endpoint(
|
|||||||
.into_response(),
|
.into_response(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// OpenAI-compatible chat completions. Dispatches to streaming SSE when
|
||||||
|
/// `stream: true` is set on the request; otherwise returns a single
|
||||||
|
/// `ChatCompletionResponse`.
|
||||||
|
async fn chat_completions(
|
||||||
|
State(state): State<Arc<NeuronState>>,
|
||||||
|
Json(req): Json<ChatCompletionRequest>,
|
||||||
|
) -> impl IntoResponse {
|
||||||
|
let Some(candle) = state.candle.as_ref().map(Arc::clone) else {
|
||||||
|
return (
|
||||||
|
StatusCode::SERVICE_UNAVAILABLE,
|
||||||
|
Json(json!({"error": "candle harness not enabled on this neuron"})),
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
|
};
|
||||||
|
|
||||||
|
if req.stream.unwrap_or(false) {
|
||||||
|
match candle.chat_completion_stream(req).await {
|
||||||
|
Ok(rx) => {
|
||||||
|
// Each chunk → one SSE `data: {json}` line. After the
|
||||||
|
// channel closes, append the OpenAI [DONE] terminator.
|
||||||
|
let body_stream = ReceiverStream::new(rx).map(|chunk| {
|
||||||
|
let body = serde_json::to_string(&chunk).unwrap_or_default();
|
||||||
|
Ok::<_, Infallible>(Event::default().data(body))
|
||||||
|
});
|
||||||
|
let done_stream =
|
||||||
|
stream::once(async { Ok::<_, Infallible>(Event::default().data("[DONE]")) });
|
||||||
|
Sse::new(body_stream.chain(done_stream))
|
||||||
|
.keep_alive(KeepAlive::default())
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
Err(InferenceError::ModelNotLoaded(id)) => (
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
Json(json!({"error": format!("model '{id}' not loaded on this neuron")})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
Err(InferenceError::Other(e)) => (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(json!({"error": format!("{e:#}")})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
match candle.chat_completion(req).await {
|
||||||
|
Ok(resp) => Json(resp).into_response(),
|
||||||
|
Err(InferenceError::ModelNotLoaded(id)) => (
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
Json(json!({"error": format!("model '{id}' not loaded on this neuron")})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
Err(InferenceError::Other(e)) => (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(json!({"error": format!("{e:#}")})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
//! Neuron configuration loaded from neuron.toml.
|
//! Neuron configuration loaded from neuron.toml.
|
||||||
|
|
||||||
use cortex_core::harness::HarnessConfig;
|
use cortex_core::harness::{HarnessConfig, ModelSpec};
|
||||||
use figment::{
|
use figment::{
|
||||||
Figment,
|
Figment,
|
||||||
providers::{Env, Format, Toml},
|
providers::{Env, Format, Toml},
|
||||||
};
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::path::Path;
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct NeuronConfig {
|
pub struct NeuronConfig {
|
||||||
@@ -14,10 +14,35 @@ pub struct NeuronConfig {
|
|||||||
pub port: u16,
|
pub port: u16,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub harnesses: Vec<HarnessConfig>,
|
pub harnesses: Vec<HarnessConfig>,
|
||||||
|
/// Per-harness configuration. Currently only `candle` is recognised.
|
||||||
|
#[serde(default)]
|
||||||
|
pub harness: HarnessSettings,
|
||||||
|
/// Models to auto-load when the neuron service activates. Each entry
|
||||||
|
/// is loaded sequentially before the HTTP listener binds. A failure
|
||||||
|
/// on any single entry logs a warning and proceeds — broken entries
|
||||||
|
/// don't prevent the rest of the fleet from starting.
|
||||||
|
#[serde(default)]
|
||||||
|
pub default_models: Vec<ModelSpec>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Settings for individual harness implementations. Each harness owns
|
||||||
|
/// its own sub-table so users only configure the harnesses they enable.
|
||||||
|
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||||
|
pub struct HarnessSettings {
|
||||||
|
#[serde(default)]
|
||||||
|
pub candle: CandleHarnessConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||||
|
pub struct CandleHarnessConfig {
|
||||||
|
/// HuggingFace cache directory for model weights.
|
||||||
|
/// When unset, defers to hf-hub's default (~/.cache/huggingface).
|
||||||
|
#[serde(default)]
|
||||||
|
pub hf_cache: Option<PathBuf>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_port() -> u16 {
|
fn default_port() -> u16 {
|
||||||
9090
|
13131
|
||||||
}
|
}
|
||||||
|
|
||||||
impl NeuronConfig {
|
impl NeuronConfig {
|
||||||
@@ -33,8 +58,10 @@ impl NeuronConfig {
|
|||||||
impl Default for NeuronConfig {
|
impl Default for NeuronConfig {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
port: 9090,
|
port: 13131,
|
||||||
harnesses: vec![],
|
harnesses: vec![],
|
||||||
|
harness: HarnessSettings::default(),
|
||||||
|
default_models: vec![],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
910
crates/neuron/src/harness/candle.rs
Normal file
910
crates/neuron/src/harness/candle.rs
Normal file
@@ -0,0 +1,910 @@
|
|||||||
|
//! Candle harness — in-process inference using huggingface/candle.
|
||||||
|
//!
|
||||||
|
//! This is the sole `Harness` implementation. Inference runs inside
|
||||||
|
//! the neuron process; there is no external subprocess.
|
||||||
|
//!
|
||||||
|
//! - Stage 2 wired GGUF (Qwen3 only) load/unload via `quantized_qwen3`.
|
||||||
|
//! - Stage 3 (this) adds `chat_completion` — a non-streaming OpenAI
|
||||||
|
//! compatible chat completion routed to the loaded model's forward
|
||||||
|
//! pass on a per-model serialised generation loop.
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use candle_core::quantized::gguf_file;
|
||||||
|
use candle_core::{DType, Device, Tensor};
|
||||||
|
use candle_nn::VarBuilder;
|
||||||
|
use candle_transformers::generation::{LogitsProcessor, Sampling};
|
||||||
|
use candle_transformers::models::quantized_qwen3::ModelWeights as QuantizedQwen3Weights;
|
||||||
|
use candle_transformers::models::qwen3 as qwen3_dense;
|
||||||
|
use cortex_core::harness::{Harness, HarnessHealth, ModelInfo, ModelSpec};
|
||||||
|
use cortex_core::openai::{
|
||||||
|
ChatCompletionChoice, ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse,
|
||||||
|
ChatMessage, ChunkChoice, MessageContent, Usage,
|
||||||
|
};
|
||||||
|
use serde_json::json;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
use tokio::sync::{Mutex, RwLock, mpsc};
|
||||||
|
|
||||||
|
/// In-process candle harness. Owns the loaded model registry.
|
||||||
|
pub struct CandleHarness {
|
||||||
|
models: Arc<RwLock<HashMap<String, Arc<LoadedModel>>>>,
|
||||||
|
hf_cache: Option<PathBuf>,
|
||||||
|
bind_url: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A loaded model with its tokenizer, device placement, and architecture-
|
||||||
|
/// specific weights. The `arch` is `Arc<Mutex<>>` so the lock guard can be
|
||||||
|
/// moved into `spawn_blocking` for synchronous candle forward passes.
|
||||||
|
pub struct LoadedModel {
|
||||||
|
pub model_id: String,
|
||||||
|
pub arch: Arc<Mutex<ModelArch>>,
|
||||||
|
pub tokenizer: Tokenizer,
|
||||||
|
pub device: Device,
|
||||||
|
pub quant: Option<String>,
|
||||||
|
pub devices: Vec<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Architecture-specific weights.
|
||||||
|
///
|
||||||
|
/// - `Qwen3Quantized` — GGUF source, pre-quantized. Single-GPU only;
|
||||||
|
/// TP sharding pre-quantized super-blocks is intractable. Stays the
|
||||||
|
/// default for small models loaded via `Qwen/Qwen3-*-GGUF` and
|
||||||
|
/// `unsloth/Qwen3-*-GGUF` repos.
|
||||||
|
/// - `Qwen3Dense` — bf16 safetensors source. The path that supports
|
||||||
|
/// TP (Stage 7b-ii+) because slicing dense weights by row/column
|
||||||
|
/// under safetensors is mechanical. Used when `ModelSpec.quant` is
|
||||||
|
/// None; intended target for Qwen3.6-27B etc.
|
||||||
|
///
|
||||||
|
/// Stage 8 broadens this to additional families.
|
||||||
|
pub enum ModelArch {
|
||||||
|
Qwen3Quantized(QuantizedQwen3Weights),
|
||||||
|
Qwen3Dense(qwen3_dense::ModelForCausalLM),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Repetition penalty applied to recently-generated tokens before
|
||||||
|
/// sampling. 1.0 disables it; >1.0 makes recently-emitted tokens less
|
||||||
|
/// likely. mistral.rs and llama.cpp default to 1.1, which is enough to
|
||||||
|
/// stop small quantized models from degenerating into "Wait, no, no..."
|
||||||
|
/// loops without distorting normal output.
|
||||||
|
const REPEAT_PENALTY: f32 = 1.1;
|
||||||
|
|
||||||
|
/// Number of recently-generated tokens to feed into the repetition
|
||||||
|
/// penalty. Matches the candle quantized-qwen3 example default.
|
||||||
|
const REPEAT_LAST_N: usize = 64;
|
||||||
|
|
||||||
|
/// Apply the repetition penalty (if any) to the prediction logits and
|
||||||
|
/// then sample. Centralises the prefill / generation-loop call sites
|
||||||
|
/// so they share identical sampling behaviour.
|
||||||
|
fn sample_with_penalty(
|
||||||
|
logits: &Tensor,
|
||||||
|
history: &[u32],
|
||||||
|
logits_processor: &mut LogitsProcessor,
|
||||||
|
) -> Result<u32> {
|
||||||
|
let penalised = if (REPEAT_PENALTY - 1.0).abs() < f32::EPSILON || history.is_empty() {
|
||||||
|
logits.clone()
|
||||||
|
} else {
|
||||||
|
let start = history.len().saturating_sub(REPEAT_LAST_N);
|
||||||
|
candle_transformers::utils::apply_repeat_penalty(logits, REPEAT_PENALTY, &history[start..])?
|
||||||
|
};
|
||||||
|
Ok(logits_processor.sample(&penalised)?)
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CandleHarness {
|
||||||
|
pub fn new(bind_url: String, hf_cache: Option<PathBuf>) -> Self {
|
||||||
|
Self {
|
||||||
|
models: Arc::new(RwLock::new(HashMap::new())),
|
||||||
|
hf_cache,
|
||||||
|
bind_url,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Pick a candle `Device` for the requested indices. Without the
|
||||||
|
/// `cuda` feature, or if CUDA initialisation fails, falls back to CPU.
|
||||||
|
fn pick_device(devices: &[u32]) -> Result<Device> {
|
||||||
|
let _idx = devices.first().copied().unwrap_or(0) as usize;
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
{
|
||||||
|
match Device::new_cuda(_idx) {
|
||||||
|
Ok(d) => return Ok(d),
|
||||||
|
Err(e) => tracing::warn!(
|
||||||
|
device = _idx,
|
||||||
|
error = %e,
|
||||||
|
"CUDA device unavailable, falling back to CPU"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(Device::Cpu)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build an hf-hub API client pre-configured with the harness's
|
||||||
|
/// `hf_cache` (when one is set).
|
||||||
|
fn hf_api(&self) -> Result<hf_hub::api::tokio::Api> {
|
||||||
|
let mut builder = hf_hub::api::tokio::ApiBuilder::new();
|
||||||
|
if let Some(cache) = &self.hf_cache {
|
||||||
|
builder = builder.with_cache_dir(cache.clone());
|
||||||
|
}
|
||||||
|
builder.build().context("build hf-hub API")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Resolve a dense (bf16/fp16 safetensors) model to its local file
|
||||||
|
/// paths.
|
||||||
|
///
|
||||||
|
/// Handles both sharded repos (`model.safetensors.index.json` plus
|
||||||
|
/// several `model-*.safetensors`) and the single-file layout
|
||||||
|
/// (`model.safetensors`). Returns the safetensors paths in
|
||||||
|
/// arbitrary order — `VarBuilder` unifies them into one tensor view.
|
||||||
|
async fn resolve_dense_files(
|
||||||
|
&self,
|
||||||
|
spec: &ModelSpec,
|
||||||
|
) -> Result<(PathBuf, PathBuf, Vec<PathBuf>)> {
|
||||||
|
let api = self.hf_api()?;
|
||||||
|
let repo = api.model(spec.model_id.clone());
|
||||||
|
|
||||||
|
let config_path = repo
|
||||||
|
.get("config.json")
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("fetch config.json from {}", spec.model_id))?;
|
||||||
|
let tokenizer_path = repo
|
||||||
|
.get("tokenizer.json")
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("fetch tokenizer.json from {}", spec.model_id))?;
|
||||||
|
|
||||||
|
// Prefer the sharded layout (most HF dense models > 5B ship it).
|
||||||
|
let safetensors_paths = match repo.get("model.safetensors.index.json").await {
|
||||||
|
Ok(index_path) => {
|
||||||
|
let index_text = std::fs::read_to_string(&index_path)
|
||||||
|
.context("read model.safetensors.index.json")?;
|
||||||
|
let index: serde_json::Value = serde_json::from_str(&index_text)
|
||||||
|
.context("parse model.safetensors.index.json")?;
|
||||||
|
let weight_map = index
|
||||||
|
.get("weight_map")
|
||||||
|
.and_then(|v| v.as_object())
|
||||||
|
.ok_or_else(|| {
|
||||||
|
anyhow::anyhow!("safetensors index missing weight_map object")
|
||||||
|
})?;
|
||||||
|
let unique: std::collections::BTreeSet<String> = weight_map
|
||||||
|
.values()
|
||||||
|
.filter_map(|v| v.as_str().map(String::from))
|
||||||
|
.collect();
|
||||||
|
let mut paths = Vec::with_capacity(unique.len());
|
||||||
|
for fname in unique {
|
||||||
|
let p = repo
|
||||||
|
.get(&fname)
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("fetch sharded safetensors {fname}"))?;
|
||||||
|
paths.push(p);
|
||||||
|
}
|
||||||
|
paths
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
// Single-file fallback.
|
||||||
|
let p = repo
|
||||||
|
.get("model.safetensors")
|
||||||
|
.await
|
||||||
|
.context("fetch model.safetensors (single-file layout)")?;
|
||||||
|
vec![p]
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok((config_path, tokenizer_path, safetensors_paths))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Resolve + load a GGUF (pre-quantized) Qwen3. Returns the
|
||||||
|
/// tokenizer.json path so the caller can construct the Tokenizer
|
||||||
|
/// uniformly across source formats.
|
||||||
|
async fn load_arch_gguf(
|
||||||
|
&self,
|
||||||
|
spec: &ModelSpec,
|
||||||
|
device: &Device,
|
||||||
|
) -> Result<(PathBuf, ModelArch)> {
|
||||||
|
let (gguf_path, tokenizer_path) = self.resolve_files(spec).await?;
|
||||||
|
let device_for_load = device.clone();
|
||||||
|
let gguf_path_for_load = gguf_path.clone();
|
||||||
|
let model_id_for_log = spec.model_id.clone();
|
||||||
|
let arch = tokio::task::spawn_blocking(move || -> Result<ModelArch> {
|
||||||
|
tracing::info!(model = %model_id_for_log, path = ?gguf_path_for_load, "loading GGUF");
|
||||||
|
let mut file = std::fs::File::open(&gguf_path_for_load).context("open GGUF file")?;
|
||||||
|
let content = gguf_file::Content::read(&mut file)
|
||||||
|
.map_err(|e| anyhow::anyhow!("parse GGUF: {e}"))?;
|
||||||
|
|
||||||
|
let architecture = content
|
||||||
|
.metadata
|
||||||
|
.get("general.architecture")
|
||||||
|
.and_then(|v| v.to_string().ok().cloned())
|
||||||
|
.unwrap_or_default();
|
||||||
|
tracing::info!(architecture = %architecture, "GGUF architecture");
|
||||||
|
|
||||||
|
match architecture.as_str() {
|
||||||
|
"qwen3" => {
|
||||||
|
let weights =
|
||||||
|
QuantizedQwen3Weights::from_gguf(content, &mut file, &device_for_load)
|
||||||
|
.map_err(|e| anyhow::anyhow!("from_gguf qwen3: {e}"))?;
|
||||||
|
Ok(ModelArch::Qwen3Quantized(weights))
|
||||||
|
}
|
||||||
|
other => anyhow::bail!(
|
||||||
|
"unsupported GGUF architecture '{other}'; quantized path only supports qwen3"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.context("blocking GGUF load task panicked")??;
|
||||||
|
Ok((tokenizer_path, arch))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Resolve + load a dense Qwen3 from safetensors. Uses
|
||||||
|
/// `candle-transformers::models::qwen3::ModelForCausalLM` and
|
||||||
|
/// builds a VarBuilder over the mmap'd safetensors files. dtype
|
||||||
|
/// is bf16 by default to match the HF distribution dtype for
|
||||||
|
/// recent Qwen3 family models; fall back to f16 if the device
|
||||||
|
/// doesn't support bf16.
|
||||||
|
async fn load_arch_dense(
|
||||||
|
&self,
|
||||||
|
spec: &ModelSpec,
|
||||||
|
device: &Device,
|
||||||
|
) -> Result<(PathBuf, ModelArch)> {
|
||||||
|
let (config_path, tokenizer_path, safetensors_paths) =
|
||||||
|
self.resolve_dense_files(spec).await?;
|
||||||
|
let device_for_load = device.clone();
|
||||||
|
let model_id_for_log = spec.model_id.clone();
|
||||||
|
|
||||||
|
let arch = tokio::task::spawn_blocking(move || -> Result<ModelArch> {
|
||||||
|
tracing::info!(
|
||||||
|
model = %model_id_for_log,
|
||||||
|
shards = safetensors_paths.len(),
|
||||||
|
"loading dense Qwen3 from safetensors"
|
||||||
|
);
|
||||||
|
let cfg_text = std::fs::read_to_string(&config_path).context("read config.json")?;
|
||||||
|
let cfg: qwen3_dense::Config =
|
||||||
|
serde_json::from_str(&cfg_text).context("parse Qwen3 config.json")?;
|
||||||
|
|
||||||
|
// bf16 is the canonical Qwen3 distribution dtype. CUDA
|
||||||
|
// devices on Ada+ support it; Ampere also supports bf16
|
||||||
|
// natively. CPU candle handles bf16 via emulation.
|
||||||
|
let dtype = DType::BF16;
|
||||||
|
// SAFETY: VarBuilder::from_mmaped_safetensors mmaps the files;
|
||||||
|
// mutation of the underlying files by another process while
|
||||||
|
// we hold the mapping is UB. We trust that nothing else on
|
||||||
|
// the host modifies the HF cache files during a model's
|
||||||
|
// lifetime (hf-hub itself is immutable-by-design).
|
||||||
|
let vb = unsafe {
|
||||||
|
VarBuilder::from_mmaped_safetensors(&safetensors_paths, dtype, &device_for_load)
|
||||||
|
.context("build VarBuilder over safetensors")?
|
||||||
|
};
|
||||||
|
let model = qwen3_dense::ModelForCausalLM::new(&cfg, vb)
|
||||||
|
.map_err(|e| anyhow::anyhow!("build Qwen3 dense model: {e}"))?;
|
||||||
|
Ok(ModelArch::Qwen3Dense(model))
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.context("blocking dense load task panicked")??;
|
||||||
|
Ok((tokenizer_path, arch))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Resolve a model spec to local GGUF and tokenizer file paths via
|
||||||
|
/// hf-hub. Downloads on first use; subsequent calls are cached.
|
||||||
|
async fn resolve_files(&self, spec: &ModelSpec) -> Result<(PathBuf, PathBuf)> {
|
||||||
|
let api = self.hf_api()?;
|
||||||
|
let repo = api.model(spec.model_id.clone());
|
||||||
|
|
||||||
|
let info = repo
|
||||||
|
.info()
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("fetch HF repo info for {}", spec.model_id))?;
|
||||||
|
|
||||||
|
let quant = spec.quant.as_deref().unwrap_or("");
|
||||||
|
let quant_lc = quant.to_lowercase();
|
||||||
|
let gguf_filename = info
|
||||||
|
.siblings
|
||||||
|
.iter()
|
||||||
|
.map(|s| s.rfilename.as_str())
|
||||||
|
.filter(|name| name.to_lowercase().ends_with(".gguf"))
|
||||||
|
.find(|name| quant_lc.is_empty() || name.to_lowercase().contains(&quant_lc))
|
||||||
|
.ok_or_else(|| {
|
||||||
|
anyhow::anyhow!(
|
||||||
|
"no GGUF file matching quant {:?} in repo {}",
|
||||||
|
spec.quant,
|
||||||
|
spec.model_id
|
||||||
|
)
|
||||||
|
})?
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
model = %spec.model_id,
|
||||||
|
file = %gguf_filename,
|
||||||
|
"resolving GGUF (may be cached)"
|
||||||
|
);
|
||||||
|
let gguf_path = repo
|
||||||
|
.get(&gguf_filename)
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("fetch GGUF {gguf_filename}"))?;
|
||||||
|
|
||||||
|
// GGUF-only HF repos (unsloth/Qwen3-*-GGUF, Qwen/Qwen3-*-GGUF,
|
||||||
|
// etc.) ship the .gguf file but not tokenizer.json — the
|
||||||
|
// tokenizer.json lives in the base non-GGUF repo. Derive the
|
||||||
|
// base repo id by stripping a `-GGUF` / `-gguf` suffix; if
|
||||||
|
// there's no such suffix the same repo is used (works for
|
||||||
|
// non-GGUF model_ids).
|
||||||
|
let tokenizer_repo_id = spec
|
||||||
|
.model_id
|
||||||
|
.strip_suffix("-GGUF")
|
||||||
|
.or_else(|| spec.model_id.strip_suffix("-gguf"))
|
||||||
|
.unwrap_or(spec.model_id.as_str())
|
||||||
|
.to_string();
|
||||||
|
let tokenizer_repo = if tokenizer_repo_id == spec.model_id {
|
||||||
|
repo
|
||||||
|
} else {
|
||||||
|
tracing::debug!(
|
||||||
|
from = %spec.model_id,
|
||||||
|
to = %tokenizer_repo_id,
|
||||||
|
"tokenizer.json sourced from base repo (GGUF suffix stripped)"
|
||||||
|
);
|
||||||
|
api.model(tokenizer_repo_id.clone())
|
||||||
|
};
|
||||||
|
let tokenizer_path = tokenizer_repo
|
||||||
|
.get("tokenizer.json")
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("fetch tokenizer.json from {tokenizer_repo_id}"))?;
|
||||||
|
Ok((gguf_path, tokenizer_path))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run a non-streaming chat completion against a loaded model.
|
||||||
|
///
|
||||||
|
/// Returns a typed `InferenceError` when the model isn't loaded so the
|
||||||
|
/// handler can map to an appropriate HTTP status without string-matching.
|
||||||
|
pub async fn chat_completion(
|
||||||
|
&self,
|
||||||
|
request: ChatCompletionRequest,
|
||||||
|
) -> Result<ChatCompletionResponse, InferenceError> {
|
||||||
|
let loaded = {
|
||||||
|
let models = self.models.read().await;
|
||||||
|
models.get(&request.model).cloned()
|
||||||
|
};
|
||||||
|
let loaded = loaded.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?;
|
||||||
|
|
||||||
|
let prompt = format_qwen3_prompt(&request.messages);
|
||||||
|
|
||||||
|
let encoding = loaded
|
||||||
|
.tokenizer
|
||||||
|
.encode(prompt.as_str(), true)
|
||||||
|
.map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?;
|
||||||
|
let prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
|
||||||
|
let prompt_len = prompt_tokens.len();
|
||||||
|
|
||||||
|
let temperature = request.temperature.unwrap_or(0.7);
|
||||||
|
let top_p = request.top_p;
|
||||||
|
let max_new = request.max_tokens.unwrap_or(512) as usize;
|
||||||
|
let seed = unix_subsec_nanos();
|
||||||
|
|
||||||
|
let eos_id = loaded
|
||||||
|
.tokenizer
|
||||||
|
.token_to_id("<|im_end|>")
|
||||||
|
.or_else(|| loaded.tokenizer.token_to_id("<|endoftext|>"));
|
||||||
|
|
||||||
|
let arch_arc = Arc::clone(&loaded.arch);
|
||||||
|
let device = loaded.device.clone();
|
||||||
|
let model_id = request.model.clone();
|
||||||
|
|
||||||
|
let (generated_ids, finish_reason) =
|
||||||
|
tokio::task::spawn_blocking(move || -> Result<(Vec<u32>, String)> {
|
||||||
|
let mut guard = arch_arc.blocking_lock();
|
||||||
|
run_inference(
|
||||||
|
&mut guard,
|
||||||
|
&device,
|
||||||
|
&prompt_tokens,
|
||||||
|
max_new,
|
||||||
|
temperature,
|
||||||
|
top_p,
|
||||||
|
seed,
|
||||||
|
eos_id,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.map_err(|e| InferenceError::Other(anyhow::anyhow!("inference task panicked: {e}")))?
|
||||||
|
.map_err(InferenceError::Other)?;
|
||||||
|
|
||||||
|
let completion_text = loaded
|
||||||
|
.tokenizer
|
||||||
|
.decode(&generated_ids, true)
|
||||||
|
.map_err(|e| InferenceError::Other(anyhow::anyhow!("detokenize: {e}")))?;
|
||||||
|
|
||||||
|
let usage = Usage {
|
||||||
|
prompt_tokens: prompt_len as u64,
|
||||||
|
completion_tokens: generated_ids.len() as u64,
|
||||||
|
total_tokens: (prompt_len + generated_ids.len()) as u64,
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(ChatCompletionResponse {
|
||||||
|
id: format!("chatcmpl-{:x}", unix_subsec_nanos()),
|
||||||
|
object: "chat.completion".into(),
|
||||||
|
created: unix_now_secs(),
|
||||||
|
model: model_id,
|
||||||
|
choices: vec![ChatCompletionChoice {
|
||||||
|
index: 0,
|
||||||
|
message: ChatMessage {
|
||||||
|
role: "assistant".into(),
|
||||||
|
content: MessageContent::Text(completion_text),
|
||||||
|
extra: serde_json::Value::Object(Default::default()),
|
||||||
|
},
|
||||||
|
finish_reason: Some(finish_reason),
|
||||||
|
extra: serde_json::Value::Object(Default::default()),
|
||||||
|
}],
|
||||||
|
usage: Some(usage),
|
||||||
|
extra: serde_json::Value::Object(Default::default()),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run a streaming chat completion against a loaded model.
|
||||||
|
///
|
||||||
|
/// Returns an `mpsc::Receiver` that yields `ChatCompletionChunk`s in
|
||||||
|
/// OpenAI SSE format. The first chunk carries the assistant role;
|
||||||
|
/// subsequent chunks carry incremental `content` deltas; the final
|
||||||
|
/// chunk carries `finish_reason`. The handler is responsible for
|
||||||
|
/// wrapping these into an SSE response and appending the `[DONE]`
|
||||||
|
/// terminator.
|
||||||
|
///
|
||||||
|
/// Token-by-token decoding tracks the cumulative decoded prefix so
|
||||||
|
/// BPE byte-fallback boundaries don't split a UTF-8 char across
|
||||||
|
/// chunks.
|
||||||
|
pub async fn chat_completion_stream(
|
||||||
|
&self,
|
||||||
|
request: ChatCompletionRequest,
|
||||||
|
) -> Result<mpsc::Receiver<ChatCompletionChunk>, InferenceError> {
|
||||||
|
let loaded = {
|
||||||
|
let models = self.models.read().await;
|
||||||
|
models.get(&request.model).cloned()
|
||||||
|
};
|
||||||
|
let loaded = loaded.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?;
|
||||||
|
|
||||||
|
let prompt = format_qwen3_prompt(&request.messages);
|
||||||
|
let encoding = loaded
|
||||||
|
.tokenizer
|
||||||
|
.encode(prompt.as_str(), true)
|
||||||
|
.map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?;
|
||||||
|
let prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
|
||||||
|
|
||||||
|
let temperature = request.temperature.unwrap_or(0.7);
|
||||||
|
let top_p = request.top_p;
|
||||||
|
let max_new = request.max_tokens.unwrap_or(512) as usize;
|
||||||
|
let seed = unix_subsec_nanos();
|
||||||
|
|
||||||
|
let eos_id = loaded
|
||||||
|
.tokenizer
|
||||||
|
.token_to_id("<|im_end|>")
|
||||||
|
.or_else(|| loaded.tokenizer.token_to_id("<|endoftext|>"));
|
||||||
|
|
||||||
|
let arch_arc = Arc::clone(&loaded.arch);
|
||||||
|
let device = loaded.device.clone();
|
||||||
|
let tokenizer = loaded.tokenizer.clone();
|
||||||
|
let model_id = request.model.clone();
|
||||||
|
let id = format!("chatcmpl-{:x}", unix_subsec_nanos());
|
||||||
|
let created = unix_now_secs();
|
||||||
|
|
||||||
|
// Bounded channel so the producer (blocking inference) is back-
|
||||||
|
// pressured by the consumer (SSE writer). 32 is generous —
|
||||||
|
// tokens arrive one at a time and the SSE writer is async.
|
||||||
|
let (tx, rx) = mpsc::channel::<ChatCompletionChunk>(32);
|
||||||
|
|
||||||
|
// Lead chunk: announce the assistant role per OpenAI streaming
|
||||||
|
// conventions. Tools that auto-detect a streaming reply expect
|
||||||
|
// this before any content delta.
|
||||||
|
let role_chunk = ChatCompletionChunk {
|
||||||
|
id: id.clone(),
|
||||||
|
object: "chat.completion.chunk".into(),
|
||||||
|
created,
|
||||||
|
model: model_id.clone(),
|
||||||
|
choices: vec![ChunkChoice {
|
||||||
|
index: 0,
|
||||||
|
delta: json!({"role": "assistant"}),
|
||||||
|
finish_reason: None,
|
||||||
|
extra: serde_json::Value::Object(Default::default()),
|
||||||
|
}],
|
||||||
|
usage: None,
|
||||||
|
extra: serde_json::Value::Object(Default::default()),
|
||||||
|
};
|
||||||
|
// If sending the role chunk fails the receiver is already gone;
|
||||||
|
// bail before kicking off the heavy blocking work.
|
||||||
|
tx.send(role_chunk)
|
||||||
|
.await
|
||||||
|
.map_err(|_| InferenceError::Other(anyhow::anyhow!("client disconnected")))?;
|
||||||
|
|
||||||
|
tokio::task::spawn_blocking(move || {
|
||||||
|
let mut guard = arch_arc.blocking_lock();
|
||||||
|
if let Err(e) = run_inference_streaming(
|
||||||
|
&mut guard,
|
||||||
|
&device,
|
||||||
|
&tokenizer,
|
||||||
|
&prompt_tokens,
|
||||||
|
max_new,
|
||||||
|
temperature,
|
||||||
|
top_p,
|
||||||
|
seed,
|
||||||
|
eos_id,
|
||||||
|
&id,
|
||||||
|
created,
|
||||||
|
&model_id,
|
||||||
|
&tx,
|
||||||
|
) {
|
||||||
|
tracing::warn!(model = %model_id, error = %e, "streaming inference failed");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(rx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Harness for CandleHarness {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"candle"
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn health(&self) -> HarnessHealth {
|
||||||
|
HarnessHealth {
|
||||||
|
name: "candle".into(),
|
||||||
|
running: true,
|
||||||
|
uptime_secs: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn list_models(&self) -> Result<Vec<ModelInfo>> {
|
||||||
|
let models = self.models.read().await;
|
||||||
|
Ok(models
|
||||||
|
.values()
|
||||||
|
.map(|m| ModelInfo {
|
||||||
|
id: m.model_id.clone(),
|
||||||
|
harness: "candle".into(),
|
||||||
|
status: "loaded".into(),
|
||||||
|
devices: m.devices.clone(),
|
||||||
|
vram_used_mb: None,
|
||||||
|
})
|
||||||
|
.collect())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn load_model(&self, spec: &ModelSpec) -> Result<()> {
|
||||||
|
if spec.harness != "candle" {
|
||||||
|
anyhow::bail!("expected harness=candle, got harness={}", spec.harness);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
let models = self.models.read().await;
|
||||||
|
if models.contains_key(&spec.model_id) {
|
||||||
|
anyhow::bail!("model '{}' already loaded", spec.model_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stage 7a-i scaffolds tensor-parallel worker subprocesses but
|
||||||
|
// does not yet route inference through them. Refuse TP loads
|
||||||
|
// for now with a clear marker so the request surface is honest;
|
||||||
|
// Stage 7b-iv replaces this bail with the TP dispatch.
|
||||||
|
let tp_size = spec.tensor_parallel.unwrap_or(1);
|
||||||
|
if tp_size > 1 {
|
||||||
|
anyhow::bail!(
|
||||||
|
"tensor_parallel={tp_size} requested for '{}': TP worker \
|
||||||
|
lifecycle + NCCL handshake are in place (Stage 7a) but \
|
||||||
|
TP-aware Qwen3 inference orchestration lands in Stage \
|
||||||
|
7b-iv; single-GPU loads only for now",
|
||||||
|
spec.model_id
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let devices = spec.devices.clone().unwrap_or_else(|| vec![0]);
|
||||||
|
let device = Self::pick_device(&devices)?;
|
||||||
|
|
||||||
|
// Dispatch by source format: GGUF (pre-quantized, single-GPU
|
||||||
|
// only path) vs safetensors dense (bf16/fp16; the path that
|
||||||
|
// grows TP support). `spec.quant` is the signal — Some means
|
||||||
|
// the operator picked a quantized GGUF; None means dense.
|
||||||
|
let (tokenizer_path, arch) = if spec.quant.is_some() {
|
||||||
|
self.load_arch_gguf(spec, &device).await?
|
||||||
|
} else {
|
||||||
|
self.load_arch_dense(spec, &device).await?
|
||||||
|
};
|
||||||
|
|
||||||
|
let tokenizer = Tokenizer::from_file(&tokenizer_path)
|
||||||
|
.map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?;
|
||||||
|
|
||||||
|
let loaded = Arc::new(LoadedModel {
|
||||||
|
model_id: spec.model_id.clone(),
|
||||||
|
arch: Arc::new(Mutex::new(arch)),
|
||||||
|
tokenizer,
|
||||||
|
device,
|
||||||
|
quant: spec.quant.clone(),
|
||||||
|
devices,
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut models = self.models.write().await;
|
||||||
|
models.insert(spec.model_id.clone(), loaded);
|
||||||
|
tracing::info!(model = %spec.model_id, "model loaded");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn unload_model(&self, model_id: &str) -> Result<()> {
|
||||||
|
let mut models = self.models.write().await;
|
||||||
|
if models.remove(model_id).is_none() {
|
||||||
|
anyhow::bail!("model '{model_id}' not loaded");
|
||||||
|
}
|
||||||
|
tracing::info!(model = %model_id, "model unloaded");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn inference_endpoint(&self, model_id: &str) -> Option<String> {
|
||||||
|
let models = self.models.read().await;
|
||||||
|
models.contains_key(model_id).then(|| self.bind_url.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Errors returned by `CandleHarness::chat_completion`. The
|
||||||
|
/// `ModelNotLoaded` variant lets the HTTP handler map cleanly to 404
|
||||||
|
/// without string-matching on anyhow messages.
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum InferenceError {
|
||||||
|
#[error("model '{0}' not loaded on this neuron")]
|
||||||
|
ModelNotLoaded(String),
|
||||||
|
#[error(transparent)]
|
||||||
|
Other(#[from] anyhow::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Apply the Qwen3 chat template:
|
||||||
|
///
|
||||||
|
/// ```text
|
||||||
|
/// <|im_start|>{role}\n{content}<|im_end|>\n
|
||||||
|
/// ...
|
||||||
|
/// <|im_start|>assistant\n
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// The trailing `<|im_start|>assistant\n` cues the model to begin a turn.
|
||||||
|
/// Non-text content parts (vision blocks) are joined as text only; full
|
||||||
|
/// multimodal handling is out of scope for Stage 3.
|
||||||
|
fn format_qwen3_prompt(messages: &[ChatMessage]) -> String {
|
||||||
|
let mut prompt = String::new();
|
||||||
|
for msg in messages {
|
||||||
|
let content = match &msg.content {
|
||||||
|
MessageContent::Text(s) => s.clone(),
|
||||||
|
MessageContent::Parts(parts) => parts
|
||||||
|
.iter()
|
||||||
|
.filter_map(|p| p.get("text").and_then(|v| v.as_str()))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(""),
|
||||||
|
};
|
||||||
|
prompt.push_str("<|im_start|>");
|
||||||
|
prompt.push_str(&msg.role);
|
||||||
|
prompt.push('\n');
|
||||||
|
prompt.push_str(&content);
|
||||||
|
prompt.push_str("<|im_end|>\n");
|
||||||
|
}
|
||||||
|
prompt.push_str("<|im_start|>assistant\n");
|
||||||
|
prompt
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
fn run_inference(
|
||||||
|
arch: &mut ModelArch,
|
||||||
|
device: &Device,
|
||||||
|
prompt_tokens: &[u32],
|
||||||
|
max_new: usize,
|
||||||
|
temperature: f64,
|
||||||
|
top_p: Option<f64>,
|
||||||
|
seed: u64,
|
||||||
|
eos_id: Option<u32>,
|
||||||
|
) -> Result<(Vec<u32>, String)> {
|
||||||
|
let mut logits_processor = {
|
||||||
|
let sampling = if temperature <= 0.0 {
|
||||||
|
Sampling::ArgMax
|
||||||
|
} else {
|
||||||
|
match top_p {
|
||||||
|
Some(p) => Sampling::TopP { p, temperature },
|
||||||
|
None => Sampling::All { temperature },
|
||||||
|
}
|
||||||
|
};
|
||||||
|
LogitsProcessor::from_sampling(seed, sampling)
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut generated: Vec<u32> = Vec::new();
|
||||||
|
|
||||||
|
let mut next_token = match arch {
|
||||||
|
ModelArch::Qwen3Quantized(model) => {
|
||||||
|
model.clear_kv_cache();
|
||||||
|
let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?;
|
||||||
|
let logits = model.forward(&input, 0)?;
|
||||||
|
let logits = logits.squeeze(0)?;
|
||||||
|
sample_with_penalty(&logits, &generated, &mut logits_processor)?
|
||||||
|
}
|
||||||
|
ModelArch::Qwen3Dense(model) => {
|
||||||
|
model.clear_kv_cache();
|
||||||
|
let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?;
|
||||||
|
// qwen3::ModelForCausalLM::forward returns [B, 1, V] —
|
||||||
|
// no final squeeze on the dense path, unlike the quantized
|
||||||
|
// variant which returns [B, V]. Strip both batch and seq
|
||||||
|
// dims to get the rank-1 logits LogitsProcessor expects.
|
||||||
|
let logits = model.forward(&input, 0)?.squeeze(0)?.squeeze(0)?;
|
||||||
|
sample_with_penalty(&logits, &generated, &mut logits_processor)?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if Some(next_token) == eos_id {
|
||||||
|
return Ok((generated, "stop".into()));
|
||||||
|
}
|
||||||
|
generated.push(next_token);
|
||||||
|
|
||||||
|
for index in 0..max_new.saturating_sub(1) {
|
||||||
|
next_token = match arch {
|
||||||
|
ModelArch::Qwen3Quantized(model) => {
|
||||||
|
let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?;
|
||||||
|
let logits = model.forward(&input, prompt_tokens.len() + index)?;
|
||||||
|
let logits = logits.squeeze(0)?;
|
||||||
|
sample_with_penalty(&logits, &generated, &mut logits_processor)?
|
||||||
|
}
|
||||||
|
ModelArch::Qwen3Dense(model) => {
|
||||||
|
let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?;
|
||||||
|
// Dense returns [B, 1, V]; strip both leading dims.
|
||||||
|
let logits = model
|
||||||
|
.forward(&input, prompt_tokens.len() + index)?
|
||||||
|
.squeeze(0)?
|
||||||
|
.squeeze(0)?;
|
||||||
|
sample_with_penalty(&logits, &generated, &mut logits_processor)?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
if Some(next_token) == eos_id {
|
||||||
|
return Ok((generated, "stop".into()));
|
||||||
|
}
|
||||||
|
generated.push(next_token);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok((generated, "length".into()))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Streaming counterpart to `run_inference`. Emits chunks via `tx` as
|
||||||
|
/// tokens are generated and exits on EOS, max_new, or receiver drop.
|
||||||
|
///
|
||||||
|
/// Detokenization tracks the cumulative decoded prefix so each chunk's
|
||||||
|
/// `content` delta is the substring appended since the last chunk —
|
||||||
|
/// safe across BPE byte-fallback boundaries.
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
fn run_inference_streaming(
|
||||||
|
arch: &mut ModelArch,
|
||||||
|
device: &Device,
|
||||||
|
tokenizer: &Tokenizer,
|
||||||
|
prompt_tokens: &[u32],
|
||||||
|
max_new: usize,
|
||||||
|
temperature: f64,
|
||||||
|
top_p: Option<f64>,
|
||||||
|
seed: u64,
|
||||||
|
eos_id: Option<u32>,
|
||||||
|
id: &str,
|
||||||
|
created: u64,
|
||||||
|
model_id: &str,
|
||||||
|
tx: &mpsc::Sender<ChatCompletionChunk>,
|
||||||
|
) -> Result<()> {
|
||||||
|
let mut logits_processor = {
|
||||||
|
let sampling = if temperature <= 0.0 {
|
||||||
|
Sampling::ArgMax
|
||||||
|
} else {
|
||||||
|
match top_p {
|
||||||
|
Some(p) => Sampling::TopP { p, temperature },
|
||||||
|
None => Sampling::All { temperature },
|
||||||
|
}
|
||||||
|
};
|
||||||
|
LogitsProcessor::from_sampling(seed, sampling)
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut all_tokens: Vec<u32> = Vec::new();
|
||||||
|
let mut decoded_prefix = String::new();
|
||||||
|
let mut finish_reason = "length".to_string();
|
||||||
|
|
||||||
|
let mut next_token = match arch {
|
||||||
|
ModelArch::Qwen3Quantized(model) => {
|
||||||
|
model.clear_kv_cache();
|
||||||
|
let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?;
|
||||||
|
let logits = model.forward(&input, 0)?;
|
||||||
|
let logits = logits.squeeze(0)?;
|
||||||
|
sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?
|
||||||
|
}
|
||||||
|
ModelArch::Qwen3Dense(model) => {
|
||||||
|
model.clear_kv_cache();
|
||||||
|
let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?;
|
||||||
|
let logits = model.forward(&input, 0)?;
|
||||||
|
let logits = logits.squeeze(0)?;
|
||||||
|
sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let emit_token = |all_tokens: &[u32], decoded_prefix: &mut String| -> Result<bool> {
|
||||||
|
let full = tokenizer
|
||||||
|
.decode(all_tokens, true)
|
||||||
|
.map_err(|e| anyhow::anyhow!("decode: {e}"))?;
|
||||||
|
if full.len() > decoded_prefix.len() {
|
||||||
|
let delta = full[decoded_prefix.len()..].to_string();
|
||||||
|
*decoded_prefix = full;
|
||||||
|
let chunk = ChatCompletionChunk {
|
||||||
|
id: id.into(),
|
||||||
|
object: "chat.completion.chunk".into(),
|
||||||
|
created,
|
||||||
|
model: model_id.into(),
|
||||||
|
choices: vec![ChunkChoice {
|
||||||
|
index: 0,
|
||||||
|
delta: json!({ "content": delta }),
|
||||||
|
finish_reason: None,
|
||||||
|
extra: serde_json::Value::Object(Default::default()),
|
||||||
|
}],
|
||||||
|
usage: None,
|
||||||
|
extra: serde_json::Value::Object(Default::default()),
|
||||||
|
};
|
||||||
|
// blocking_send returns Err if the consumer hung up — signal
|
||||||
|
// the caller to stop generating.
|
||||||
|
if tx.blocking_send(chunk).is_err() {
|
||||||
|
return Ok(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(true)
|
||||||
|
};
|
||||||
|
|
||||||
|
if Some(next_token) == eos_id {
|
||||||
|
finish_reason = "stop".into();
|
||||||
|
} else {
|
||||||
|
all_tokens.push(next_token);
|
||||||
|
if !emit_token(&all_tokens, &mut decoded_prefix)? {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
for index in 0..max_new.saturating_sub(1) {
|
||||||
|
next_token = match arch {
|
||||||
|
ModelArch::Qwen3Quantized(model) => {
|
||||||
|
let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?;
|
||||||
|
let logits = model.forward(&input, prompt_tokens.len() + index)?;
|
||||||
|
let logits = logits.squeeze(0)?;
|
||||||
|
sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?
|
||||||
|
}
|
||||||
|
ModelArch::Qwen3Dense(model) => {
|
||||||
|
let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?;
|
||||||
|
// Dense returns [B, 1, V]; strip both leading dims.
|
||||||
|
let logits = model
|
||||||
|
.forward(&input, prompt_tokens.len() + index)?
|
||||||
|
.squeeze(0)?
|
||||||
|
.squeeze(0)?;
|
||||||
|
sample_with_penalty(&logits, &all_tokens, &mut logits_processor)?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
if Some(next_token) == eos_id {
|
||||||
|
finish_reason = "stop".into();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
all_tokens.push(next_token);
|
||||||
|
if !emit_token(&all_tokens, &mut decoded_prefix)? {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let final_chunk = ChatCompletionChunk {
|
||||||
|
id: id.into(),
|
||||||
|
object: "chat.completion.chunk".into(),
|
||||||
|
created,
|
||||||
|
model: model_id.into(),
|
||||||
|
choices: vec![ChunkChoice {
|
||||||
|
index: 0,
|
||||||
|
delta: serde_json::Value::Object(Default::default()),
|
||||||
|
finish_reason: Some(finish_reason),
|
||||||
|
extra: serde_json::Value::Object(Default::default()),
|
||||||
|
}],
|
||||||
|
usage: None,
|
||||||
|
extra: serde_json::Value::Object(Default::default()),
|
||||||
|
};
|
||||||
|
let _ = tx.blocking_send(final_chunk);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unix_now_secs() -> u64 {
|
||||||
|
SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.map(|d| d.as_secs())
|
||||||
|
.unwrap_or(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unix_subsec_nanos() -> u64 {
|
||||||
|
SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.map(|d| d.as_nanos() as u64)
|
||||||
|
.unwrap_or(0)
|
||||||
|
}
|
||||||
@@ -1 +0,0 @@
|
|||||||
// llama.cpp harness implementation — Phase 11.
|
|
||||||
@@ -1,163 +0,0 @@
|
|||||||
//! mistral.rs harness implementation.
|
|
||||||
//!
|
|
||||||
//! Wraps the mistral.rs HTTP API for model lifecycle management
|
|
||||||
//! and optionally manages the process via systemd.
|
|
||||||
|
|
||||||
use anyhow::Result;
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use cortex_core::harness::{Harness, HarnessConfig, HarnessHealth, ModelInfo, ModelSpec};
|
|
||||||
use reqwest::Client;
|
|
||||||
use serde::Deserialize;
|
|
||||||
|
|
||||||
pub struct MistralRsHarness {
|
|
||||||
endpoint: String,
|
|
||||||
systemd_unit: Option<String>,
|
|
||||||
client: Client,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl MistralRsHarness {
|
|
||||||
pub fn new(endpoint: String, systemd_unit: Option<String>) -> Self {
|
|
||||||
Self {
|
|
||||||
endpoint,
|
|
||||||
systemd_unit,
|
|
||||||
client: Client::builder()
|
|
||||||
.timeout(std::time::Duration::from_secs(30))
|
|
||||||
.build()
|
|
||||||
.expect("failed to build HTTP client"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Response from mistral.rs `GET /v1/models`.
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
struct ModelsResponse {
|
|
||||||
data: Vec<ModelEntry>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
struct ModelEntry {
|
|
||||||
id: String,
|
|
||||||
#[serde(default)]
|
|
||||||
status: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl Harness for MistralRsHarness {
|
|
||||||
fn name(&self) -> &str {
|
|
||||||
"mistralrs"
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn start(&self, _config: &HarnessConfig) -> Result<()> {
|
|
||||||
let Some(unit) = &self.systemd_unit else {
|
|
||||||
anyhow::bail!("no systemd unit configured for mistralrs harness");
|
|
||||||
};
|
|
||||||
|
|
||||||
let output = tokio::process::Command::new("systemctl")
|
|
||||||
.args(["start", unit])
|
|
||||||
.output()
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
if !output.status.success() {
|
|
||||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
|
||||||
anyhow::bail!("systemctl start {unit} failed: {stderr}");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for the health endpoint to respond (up to 30s).
|
|
||||||
let url = format!("{}/health", self.endpoint);
|
|
||||||
for _ in 0..30 {
|
|
||||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
|
||||||
if self.client.get(&url).send().await.is_ok() {
|
|
||||||
tracing::info!(unit, "mistralrs started and healthy");
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
anyhow::bail!("mistralrs started but health endpoint did not respond within 30s");
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn stop(&self) -> Result<()> {
|
|
||||||
let Some(unit) = &self.systemd_unit else {
|
|
||||||
anyhow::bail!("no systemd unit configured for mistralrs harness");
|
|
||||||
};
|
|
||||||
|
|
||||||
let output = tokio::process::Command::new("systemctl")
|
|
||||||
.args(["stop", unit])
|
|
||||||
.output()
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
if !output.status.success() {
|
|
||||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
|
||||||
anyhow::bail!("systemctl stop {unit} failed: {stderr}");
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn health(&self) -> HarnessHealth {
|
|
||||||
let url = format!("{}/health", self.endpoint);
|
|
||||||
let running = self.client.get(&url).send().await.is_ok();
|
|
||||||
HarnessHealth {
|
|
||||||
name: "mistralrs".into(),
|
|
||||||
running,
|
|
||||||
uptime_secs: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn list_models(&self) -> Result<Vec<ModelInfo>> {
|
|
||||||
let url = format!("{}/v1/models", self.endpoint);
|
|
||||||
let resp = self.client.get(&url).send().await?;
|
|
||||||
|
|
||||||
if !resp.status().is_success() {
|
|
||||||
anyhow::bail!("GET /v1/models returned {}", resp.status());
|
|
||||||
}
|
|
||||||
|
|
||||||
let models_resp: ModelsResponse = resp.json().await?;
|
|
||||||
Ok(models_resp
|
|
||||||
.data
|
|
||||||
.into_iter()
|
|
||||||
.map(|m| ModelInfo {
|
|
||||||
id: m.id,
|
|
||||||
harness: "mistralrs".into(),
|
|
||||||
status: m.status.unwrap_or_else(|| "loaded".into()),
|
|
||||||
devices: vec![],
|
|
||||||
vram_used_mb: None,
|
|
||||||
})
|
|
||||||
.collect())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn load_model(&self, spec: &ModelSpec) -> Result<()> {
|
|
||||||
let url = format!("{}/v1/models/reload", self.endpoint);
|
|
||||||
let resp = self
|
|
||||||
.client
|
|
||||||
.post(&url)
|
|
||||||
.json(&serde_json::json!({ "model_id": spec.model_id }))
|
|
||||||
.send()
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
if !resp.status().is_success() {
|
|
||||||
let body = resp.text().await.unwrap_or_default();
|
|
||||||
anyhow::bail!("POST /v1/models/reload failed: {body}");
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn unload_model(&self, model_id: &str) -> Result<()> {
|
|
||||||
let url = format!("{}/v1/models/unload", self.endpoint);
|
|
||||||
let resp = self
|
|
||||||
.client
|
|
||||||
.post(&url)
|
|
||||||
.json(&serde_json::json!({ "model_id": model_id }))
|
|
||||||
.send()
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
if !resp.status().is_success() {
|
|
||||||
let body = resp.text().await.unwrap_or_default();
|
|
||||||
anyhow::bail!("POST /v1/models/unload failed: {body}");
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn inference_endpoint(&self, _model_id: &str) -> Option<String> {
|
|
||||||
// mistral.rs routes internally by model name in the request body,
|
|
||||||
// so the inference endpoint is always the base URL.
|
|
||||||
Some(self.endpoint.clone())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,15 +1,23 @@
|
|||||||
//! Harness registry — maps harness names to trait implementations.
|
//! Harness registry — maps harness names to trait implementations.
|
||||||
|
|
||||||
pub mod llamacpp;
|
pub mod candle;
|
||||||
pub mod mistralrs;
|
pub mod tp;
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use cortex_core::harness::{Harness, HarnessConfig, ModelInfo, ModelSpec};
|
use cortex_core::harness::{Harness, HarnessConfig, ModelInfo, ModelSpec};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
/// Registry of available harness implementations.
|
/// Registry of available harness implementations.
|
||||||
|
///
|
||||||
|
/// Holds an `Arc<dyn Harness>` per harness for generic lifecycle dispatch
|
||||||
|
/// (load/unload/list_models). When a candle harness is registered, a typed
|
||||||
|
/// `Arc<CandleHarness>` is also cached so inference routes can bypass the
|
||||||
|
/// dyn-Trait dispatch and reach harness-specific methods (chat completion,
|
||||||
|
/// streaming, etc.).
|
||||||
pub struct HarnessRegistry {
|
pub struct HarnessRegistry {
|
||||||
harnesses: HashMap<String, Box<dyn Harness>>,
|
harnesses: HashMap<String, Arc<dyn Harness>>,
|
||||||
|
candle: Option<Arc<candle::CandleHarness>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for HarnessRegistry {
|
impl Default for HarnessRegistry {
|
||||||
@@ -22,10 +30,11 @@ impl HarnessRegistry {
|
|||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
harnesses: HashMap::new(),
|
harnesses: HashMap::new(),
|
||||||
|
candle: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn register(&mut self, harness: Box<dyn Harness>) {
|
pub fn register(&mut self, harness: Arc<dyn Harness>) {
|
||||||
self.harnesses.insert(harness.name().to_string(), harness);
|
self.harnesses.insert(harness.name().to_string(), harness);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -34,6 +43,12 @@ impl HarnessRegistry {
|
|||||||
self.harnesses.keys().cloned().collect()
|
self.harnesses.keys().cloned().collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Typed handle to the candle harness, if registered. Used by inference
|
||||||
|
/// routes that need methods beyond the `Harness` trait surface.
|
||||||
|
pub fn candle(&self) -> Option<Arc<candle::CandleHarness>> {
|
||||||
|
self.candle.clone()
|
||||||
|
}
|
||||||
|
|
||||||
/// List models from all registered harnesses.
|
/// List models from all registered harnesses.
|
||||||
pub async fn list_all_models(&self) -> Result<Vec<ModelInfo>> {
|
pub async fn list_all_models(&self) -> Result<Vec<ModelInfo>> {
|
||||||
let mut all = Vec::new();
|
let mut all = Vec::new();
|
||||||
@@ -81,19 +96,25 @@ impl HarnessRegistry {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Build a registry from harness configs.
|
/// Build a registry from harness configs.
|
||||||
pub fn from_configs(configs: &[HarnessConfig]) -> Self {
|
///
|
||||||
|
/// `bind_url` is the URL where this neuron serves inference (its own
|
||||||
|
/// listen address). In-process harnesses (currently the only kind)
|
||||||
|
/// return this URL from `inference_endpoint`.
|
||||||
|
pub fn from_configs(
|
||||||
|
configs: &[HarnessConfig],
|
||||||
|
bind_url: &str,
|
||||||
|
settings: &crate::config::HarnessSettings,
|
||||||
|
) -> Self {
|
||||||
let mut registry = Self::new();
|
let mut registry = Self::new();
|
||||||
for config in configs {
|
for config in configs {
|
||||||
match config.name.as_str() {
|
match config.name.as_str() {
|
||||||
"mistralrs" => {
|
"candle" => {
|
||||||
if let Some(endpoint) = &config.endpoint {
|
let harness = Arc::new(candle::CandleHarness::new(
|
||||||
registry.register(Box::new(mistralrs::MistralRsHarness::new(
|
bind_url.to_string(),
|
||||||
endpoint.clone(),
|
settings.candle.hf_cache.clone(),
|
||||||
config.systemd_unit.clone(),
|
));
|
||||||
)));
|
registry.candle = Some(Arc::clone(&harness));
|
||||||
} else {
|
registry.harnesses.insert("candle".into(), harness);
|
||||||
tracing::warn!("mistralrs harness missing endpoint, skipping");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
other => {
|
other => {
|
||||||
tracing::warn!(harness = other, "unknown harness type, skipping");
|
tracing::warn!(harness = other, "unknown harness type, skipping");
|
||||||
|
|||||||
120
crates/neuron/src/harness/tp/all_reduce.rs
Normal file
120
crates/neuron/src/harness/tp/all_reduce.rs
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
//! `AllReduce` as a candle `CustomOp1` — the bridge between candle's
|
||||||
|
//! `Tensor` graph and `cudarc::nccl::Comm::all_reduce`.
|
||||||
|
//!
|
||||||
|
//! Ported from the canonical
|
||||||
|
//! `candle-examples/examples/llama_multiprocess/model.rs` pattern.
|
||||||
|
//! Row-parallel layers apply this op after their local matmul to sum
|
||||||
|
//! partial outputs across NCCL ranks.
|
||||||
|
//!
|
||||||
|
//! Available only under `--features cuda`; on CPU builds this module
|
||||||
|
//! is empty and row-parallel layers degenerate to local matmul only
|
||||||
|
//! (useful for compile-checking the model code; correctness requires
|
||||||
|
//! cuda).
|
||||||
|
//!
|
||||||
|
//! Thread-safety caveat: NCCL communicators are technically only
|
||||||
|
//! safe to use from a single thread at a time
|
||||||
|
//! (https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html).
|
||||||
|
//! We hold the `AllReduce` behind an `Arc<Comm>` and only issue ops
|
||||||
|
//! against it from the dedicated `spawn_blocking` thread the inference
|
||||||
|
//! pipeline already uses for candle's forward passes.
|
||||||
|
|
||||||
|
#![cfg(feature = "cuda")]
|
||||||
|
|
||||||
|
use candle_core::backend::BackendStorage;
|
||||||
|
use candle_core::cuda_backend::WrapErr;
|
||||||
|
use candle_core::{CpuStorage, CudaStorage, CustomOp1, DType, Layout, Result, Shape};
|
||||||
|
use cudarc::nccl::{Comm, ReduceOp};
|
||||||
|
use half::{bf16, f16};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
/// Wraps an NCCL `Comm` so it can be plugged into a candle forward
|
||||||
|
/// graph as a custom op. Each row-parallel layer holds one of these.
|
||||||
|
pub struct AllReduce {
|
||||||
|
comm: Arc<Comm>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// SAFETY: `Comm` contains a raw `ncclComm_t` pointer; NCCL's docs note
|
||||||
|
// that issuing ops against one comm from multiple threads concurrently
|
||||||
|
// is unsafe. We serialise via the single spawn_blocking thread that
|
||||||
|
// drives the model's forward pass. The Send/Sync impl is necessary
|
||||||
|
// because candle's CustomOp1 trait bounds require it; the correctness
|
||||||
|
// invariant is enforced at the call site, not the type level.
|
||||||
|
unsafe impl Send for AllReduce {}
|
||||||
|
unsafe impl Sync for AllReduce {}
|
||||||
|
|
||||||
|
impl AllReduce {
|
||||||
|
pub fn new(comm: Arc<Comm>) -> Self {
|
||||||
|
Self { comm }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn comm(&self) -> &Arc<Comm> {
|
||||||
|
&self.comm
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CustomOp1 for AllReduce {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
"neuron.tp.all_reduce"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
|
||||||
|
candle_core::bail!("AllReduce custom-op invoked on CPU storage; TP requires CUDA")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cuda_fwd(&self, s: &CudaStorage, l: &Layout) -> Result<(CudaStorage, Shape)> {
|
||||||
|
// Reject non-contiguous inputs explicitly — copying them
|
||||||
|
// server-side would mask shape bugs (a TP layer feeding a
|
||||||
|
// strided activation into all_reduce is almost certainly a
|
||||||
|
// model construction error).
|
||||||
|
fn require_contiguous<T: cudarc::driver::DeviceRepr>(
|
||||||
|
slice: &cudarc::driver::CudaSlice<T>,
|
||||||
|
l: &Layout,
|
||||||
|
) -> Result<()> {
|
||||||
|
match l.contiguous_offsets() {
|
||||||
|
Some((0, n)) if n == slice.len() => Ok(()),
|
||||||
|
_ => candle_core::bail!(
|
||||||
|
"AllReduce input is non-contiguous: layout={:?}, slice_len={}",
|
||||||
|
l,
|
||||||
|
slice.len()
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let elem_count = l.shape().elem_count();
|
||||||
|
let dev = s.device().clone();
|
||||||
|
|
||||||
|
let out = match s.dtype() {
|
||||||
|
DType::BF16 => {
|
||||||
|
let src = s.as_cuda_slice::<bf16>()?;
|
||||||
|
require_contiguous(src, l)?;
|
||||||
|
let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
|
||||||
|
self.comm
|
||||||
|
.all_reduce(src, &mut dst, &ReduceOp::Sum)
|
||||||
|
.map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce bf16: {e:?}")))?;
|
||||||
|
CudaStorage::wrap_cuda_slice(dst, dev)
|
||||||
|
}
|
||||||
|
DType::F16 => {
|
||||||
|
let src = s.as_cuda_slice::<f16>()?;
|
||||||
|
require_contiguous(src, l)?;
|
||||||
|
let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
|
||||||
|
self.comm
|
||||||
|
.all_reduce(src, &mut dst, &ReduceOp::Sum)
|
||||||
|
.map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce f16: {e:?}")))?;
|
||||||
|
CudaStorage::wrap_cuda_slice(dst, dev)
|
||||||
|
}
|
||||||
|
DType::F32 => {
|
||||||
|
let src = s.as_cuda_slice::<f32>()?;
|
||||||
|
require_contiguous(src, l)?;
|
||||||
|
let mut dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
|
||||||
|
self.comm
|
||||||
|
.all_reduce(src, &mut dst, &ReduceOp::Sum)
|
||||||
|
.map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce f32: {e:?}")))?;
|
||||||
|
CudaStorage::wrap_cuda_slice(dst, dev)
|
||||||
|
}
|
||||||
|
dtype => candle_core::bail!(
|
||||||
|
"AllReduce: unsupported dtype {dtype:?}; TP path expects bf16/f16/f32"
|
||||||
|
),
|
||||||
|
};
|
||||||
|
Ok((out, l.shape().clone()))
|
||||||
|
}
|
||||||
|
}
|
||||||
372
crates/neuron/src/harness/tp/mod.rs
Normal file
372
crates/neuron/src/harness/tp/mod.rs
Normal file
@@ -0,0 +1,372 @@
|
|||||||
|
//! Tensor-parallel inference plumbing.
|
||||||
|
//!
|
||||||
|
//! The leader process (the neuron daemon proper) drives one
|
||||||
|
//! subprocess per non-zero NCCL rank — `tokio::process::Command` on
|
||||||
|
//! `/proc/self/exe --worker --rank N --tp-size N --cuda-device N` —
|
||||||
|
//! and talks to each over a newline-delimited JSON RPC channel on
|
||||||
|
//! the worker's stdin/stdout (see `rpc.rs`).
|
||||||
|
//!
|
||||||
|
//! Sub-staging:
|
||||||
|
//!
|
||||||
|
//! - **7a-i (this commit):** process lifecycle. `WorkerPool::spawn`
|
||||||
|
//! forks N workers; `ping` round-trips every worker to confirm
|
||||||
|
//! they're alive; `shutdown` cleanly drains and reaps. `Init` /
|
||||||
|
//! `NcclSanityCheck` are stubbed.
|
||||||
|
//! - **7a-ii:** real NCCL `Comm` setup via `Init`, sanity check via
|
||||||
|
//! `NcclSanityCheck`. CUDA-gated.
|
||||||
|
//! - **7b:** TP-aware Qwen3 inference dispatched through the pool.
|
||||||
|
//! - **7c:** crash detection, streaming SSE, graceful unload.
|
||||||
|
|
||||||
|
pub mod all_reduce;
|
||||||
|
pub mod nccl_state;
|
||||||
|
pub mod rpc;
|
||||||
|
pub mod tp_linear;
|
||||||
|
pub mod tp_qwen3;
|
||||||
|
pub mod worker;
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
use std::process::Stdio;
|
||||||
|
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, Lines};
|
||||||
|
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
|
||||||
|
|
||||||
|
use rpc::{WorkerRequest, WorkerResponse};
|
||||||
|
|
||||||
|
/// One worker subprocess plus its bidirectional stdio handles.
|
||||||
|
struct Worker {
|
||||||
|
rank: u32,
|
||||||
|
/// Captured so the leader can log "spawned rank N on device M" and
|
||||||
|
/// future stages can re-issue Init after a CUDA reset. Unused in
|
||||||
|
/// the Stage 7a-i RPC paths themselves.
|
||||||
|
#[allow(dead_code)]
|
||||||
|
cuda_device: u32,
|
||||||
|
child: Child,
|
||||||
|
stdin: ChildStdin,
|
||||||
|
stdout: Lines<BufReader<ChildStdout>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Worker {
|
||||||
|
/// Send a request and wait for the response. Used for sequenced
|
||||||
|
/// ops like `Ping` / `Shutdown` where the caller doesn't need to
|
||||||
|
/// overlap the worker's execution with the leader's.
|
||||||
|
async fn request(&mut self, req: &WorkerRequest) -> Result<WorkerResponse> {
|
||||||
|
self.send_only(req).await?;
|
||||||
|
self.recv_only().await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Write a request without awaiting its response. Pair with
|
||||||
|
/// `recv_only` from the caller when leader and worker need to do
|
||||||
|
/// work concurrently — e.g. during `Init`, where the leader
|
||||||
|
/// itself calls `Comm::from_rank` on rank 0 in parallel with the
|
||||||
|
/// workers, then collects `InitOk` after NCCL completes.
|
||||||
|
async fn send_only(&mut self, req: &WorkerRequest) -> Result<()> {
|
||||||
|
let mut line = serde_json::to_string(req).context("serialise WorkerRequest")?;
|
||||||
|
line.push('\n');
|
||||||
|
self.stdin
|
||||||
|
.write_all(line.as_bytes())
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("write request to rank {}", self.rank))?;
|
||||||
|
self.stdin
|
||||||
|
.flush()
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("flush stdin to rank {}", self.rank))?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn recv_only(&mut self) -> Result<WorkerResponse> {
|
||||||
|
let reply = self
|
||||||
|
.stdout
|
||||||
|
.next_line()
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("read reply from rank {}", self.rank))?
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("rank {} stdout closed before reply", self.rank))?;
|
||||||
|
serde_json::from_str(&reply)
|
||||||
|
.with_context(|| format!("parse reply from rank {}: {reply:?}", self.rank))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A live pool of worker subprocesses. Owns the `Child` handles so
|
||||||
|
/// dropping the pool kills the children; explicit `shutdown()` is
|
||||||
|
/// the graceful path.
|
||||||
|
pub struct WorkerPool {
|
||||||
|
world_size: u32,
|
||||||
|
workers: Vec<Worker>,
|
||||||
|
/// Path to the neuron binary used to launch workers.
|
||||||
|
#[allow(dead_code)]
|
||||||
|
exe: PathBuf,
|
||||||
|
/// Leader's own NCCL rank-0 state. Defaults to empty; populated by
|
||||||
|
/// `init_nccl()`. Held here so the leader can participate in
|
||||||
|
/// collectives (rank 0) without spawning a fourth subprocess.
|
||||||
|
leader_nccl: nccl_state::NcclState,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WorkerPool {
|
||||||
|
/// Spawn `world_size - 1` worker subprocesses. Rank 0 is the
|
||||||
|
/// leader (in-process) and is *not* spawned here — the leader
|
||||||
|
/// holds rank 0's NCCL Comm and shard in its own address space.
|
||||||
|
///
|
||||||
|
/// `binary` is the path to the neuron executable to run for each
|
||||||
|
/// worker (production passes `/proc/self/exe`; tests pass the
|
||||||
|
/// sibling-binary path from `env!("CARGO_BIN_EXE_neuron")`).
|
||||||
|
/// `cuda_devices` is one entry per rank including rank 0. Worker
|
||||||
|
/// `i` (rank `i`) gets `cuda_devices[i]` as its `--cuda-device`.
|
||||||
|
pub async fn spawn(binary: &Path, world_size: u32, cuda_devices: &[u32]) -> Result<Self> {
|
||||||
|
if world_size < 2 {
|
||||||
|
anyhow::bail!(
|
||||||
|
"WorkerPool::spawn called with world_size={world_size}; \
|
||||||
|
use the single-process path for world_size < 2"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if cuda_devices.len() as u32 != world_size {
|
||||||
|
anyhow::bail!(
|
||||||
|
"expected {world_size} cuda_devices entries, got {}",
|
||||||
|
cuda_devices.len()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let exe = binary.to_path_buf();
|
||||||
|
|
||||||
|
let mut workers = Vec::with_capacity(world_size as usize - 1);
|
||||||
|
// Rank 0 stays in-process. Spawn ranks 1..world_size.
|
||||||
|
for rank in 1..world_size {
|
||||||
|
let cuda_device = cuda_devices[rank as usize];
|
||||||
|
let mut cmd = Command::new(&exe);
|
||||||
|
cmd.arg("--worker")
|
||||||
|
.arg("--rank")
|
||||||
|
.arg(rank.to_string())
|
||||||
|
.arg("--tp-size")
|
||||||
|
.arg(world_size.to_string())
|
||||||
|
.arg("--cuda-device")
|
||||||
|
.arg(cuda_device.to_string())
|
||||||
|
.stdin(Stdio::piped())
|
||||||
|
.stdout(Stdio::piped())
|
||||||
|
// Inherit stderr so worker tracing surfaces alongside
|
||||||
|
// the leader's journalctl stream.
|
||||||
|
.stderr(Stdio::inherit())
|
||||||
|
.kill_on_drop(true);
|
||||||
|
|
||||||
|
let mut child = cmd
|
||||||
|
.spawn()
|
||||||
|
.with_context(|| format!("spawn worker rank {rank}"))?;
|
||||||
|
let stdin = child
|
||||||
|
.stdin
|
||||||
|
.take()
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("rank {rank}: no stdin handle"))?;
|
||||||
|
let stdout = child
|
||||||
|
.stdout
|
||||||
|
.take()
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("rank {rank}: no stdout handle"))?;
|
||||||
|
let stdout = BufReader::new(stdout).lines();
|
||||||
|
|
||||||
|
workers.push(Worker {
|
||||||
|
rank,
|
||||||
|
cuda_device,
|
||||||
|
child,
|
||||||
|
stdin,
|
||||||
|
stdout,
|
||||||
|
});
|
||||||
|
tracing::info!(rank, cuda_device, "spawned tp worker");
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
world_size,
|
||||||
|
workers,
|
||||||
|
exe,
|
||||||
|
leader_nccl: nccl_state::NcclState::new(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Establish the NCCL communicator across the leader (rank 0) and
|
||||||
|
/// every worker subprocess. Rendezvous is via a freshly-generated
|
||||||
|
/// `Id` broadcast over the RPC stream; the actual handshake blocks
|
||||||
|
/// inside `Comm::from_rank` until all `world_size` ranks check in.
|
||||||
|
///
|
||||||
|
/// `leader_cuda_device` is the CUDA device the leader binds rank 0
|
||||||
|
/// to — typically the first entry of the `cuda_devices` slice
|
||||||
|
/// originally passed to `spawn()`.
|
||||||
|
///
|
||||||
|
/// On the non-cuda build this immediately fails because the leader
|
||||||
|
/// can't generate an `Id` without libnccl. The same call works in
|
||||||
|
/// the worker path (returning a no-cuda error response) so the
|
||||||
|
/// failure surface is uniform.
|
||||||
|
pub async fn init_nccl(&mut self, leader_cuda_device: u32) -> Result<()> {
|
||||||
|
let comm_id = nccl_state::generate_comm_id_hex()
|
||||||
|
.map_err(|m| anyhow::anyhow!("generate NCCL id: {m}"))?;
|
||||||
|
|
||||||
|
// 1. Write Init to every worker's stdin without awaiting the
|
||||||
|
// response. Workers will parse and call Comm::from_rank
|
||||||
|
// concurrently with the leader below.
|
||||||
|
for w in &mut self.workers {
|
||||||
|
let req = WorkerRequest::Init {
|
||||||
|
comm_id: comm_id.clone(),
|
||||||
|
};
|
||||||
|
w.send_only(&req).await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Leader rank 0 calls Comm::from_rank on its own device.
|
||||||
|
// Runs on spawn_blocking because NCCL's init blocks until
|
||||||
|
// every rank has called in — that's exactly the workers
|
||||||
|
// above. The leader's NcclState is moved through the
|
||||||
|
// blocking task and returned to the pool.
|
||||||
|
let leader_cfg = worker::WorkerConfig {
|
||||||
|
rank: 0,
|
||||||
|
world_size: self.world_size,
|
||||||
|
cuda_device: leader_cuda_device,
|
||||||
|
};
|
||||||
|
let comm_id_for_leader = comm_id.clone();
|
||||||
|
// Swap out the leader's NcclState into a fresh empty one so we
|
||||||
|
// can move it into spawn_blocking; restore after the task
|
||||||
|
// returns. (NcclState isn't Clone — it owns a real NCCL Comm.)
|
||||||
|
let mut leader_state = std::mem::take(&mut self.leader_nccl);
|
||||||
|
let (returned_state, leader_resp) = tokio::task::spawn_blocking(move || {
|
||||||
|
let resp = leader_state.init(leader_cfg, &comm_id_for_leader);
|
||||||
|
(leader_state, resp)
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.context("leader NCCL init task panicked")?;
|
||||||
|
self.leader_nccl = returned_state;
|
||||||
|
match leader_resp {
|
||||||
|
rpc::WorkerResponse::InitOk => {}
|
||||||
|
rpc::WorkerResponse::Error { kind, message } => {
|
||||||
|
anyhow::bail!("leader rank 0 init failed [{kind}]: {message}");
|
||||||
|
}
|
||||||
|
other => anyhow::bail!("leader rank 0 init: unexpected {other:?}"),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Read InitOk from each worker. By now every worker has
|
||||||
|
// completed its Comm::from_rank call (NCCL released them
|
||||||
|
// when the leader joined the handshake) and is writing its
|
||||||
|
// response.
|
||||||
|
for w in &mut self.workers {
|
||||||
|
let resp = w.recv_only().await?;
|
||||||
|
match &resp {
|
||||||
|
rpc::WorkerResponse::InitOk => {}
|
||||||
|
rpc::WorkerResponse::Error { kind, message } => {
|
||||||
|
anyhow::bail!("worker rank {} init failed [{kind}]: {message}", w.rank);
|
||||||
|
}
|
||||||
|
other => anyhow::bail!(
|
||||||
|
"worker rank {} init: expected InitOk, got {other:?}",
|
||||||
|
w.rank
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tracing::info!(
|
||||||
|
world_size = self.world_size,
|
||||||
|
"NCCL communicator established across all ranks"
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate the NCCL communicator: every rank `all_reduce`s a
|
||||||
|
/// sentinel `1u32` with `ReduceOp::Sum`; the expected total is
|
||||||
|
/// `world_size`. Confirms the handshake is live, not just
|
||||||
|
/// configured.
|
||||||
|
///
|
||||||
|
/// Must be called after `init_nccl()`; before that the leader has
|
||||||
|
/// no Comm and the workers reply with `nccl_not_initialised`.
|
||||||
|
pub async fn nccl_sanity_check(&mut self) -> Result<()> {
|
||||||
|
// 1. Trigger the all_reduce on every worker (write-only).
|
||||||
|
for w in &mut self.workers {
|
||||||
|
w.send_only(&WorkerRequest::NcclSanityCheck).await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Leader's own all_reduce, in spawn_blocking. NCCL operations
|
||||||
|
// block until every rank participates.
|
||||||
|
let mut leader_state = std::mem::take(&mut self.leader_nccl);
|
||||||
|
let (returned_state, leader_resp) = tokio::task::spawn_blocking(move || {
|
||||||
|
let resp = leader_state.sanity_check();
|
||||||
|
(leader_state, resp)
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.context("leader NCCL sanity task panicked")?;
|
||||||
|
self.leader_nccl = returned_state;
|
||||||
|
|
||||||
|
let expected = self.world_size;
|
||||||
|
let leader_sum = match leader_resp {
|
||||||
|
rpc::WorkerResponse::NcclSanityResult { observed_sum } => observed_sum,
|
||||||
|
rpc::WorkerResponse::Error { kind, message } => {
|
||||||
|
anyhow::bail!("leader rank 0 sanity failed [{kind}]: {message}");
|
||||||
|
}
|
||||||
|
other => anyhow::bail!("leader rank 0 sanity: unexpected {other:?}"),
|
||||||
|
};
|
||||||
|
if leader_sum != expected {
|
||||||
|
anyhow::bail!("leader observed_sum={leader_sum}, expected {expected}");
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Read sanity result from each worker. All must match
|
||||||
|
// world_size — anything else means the collective didn't
|
||||||
|
// complete consistently across ranks.
|
||||||
|
for w in &mut self.workers {
|
||||||
|
let resp = w.recv_only().await?;
|
||||||
|
match resp {
|
||||||
|
rpc::WorkerResponse::NcclSanityResult { observed_sum }
|
||||||
|
if observed_sum == expected => {}
|
||||||
|
rpc::WorkerResponse::NcclSanityResult { observed_sum } => {
|
||||||
|
anyhow::bail!(
|
||||||
|
"worker rank {} observed_sum={observed_sum}, expected {expected}",
|
||||||
|
w.rank
|
||||||
|
);
|
||||||
|
}
|
||||||
|
rpc::WorkerResponse::Error { kind, message } => {
|
||||||
|
anyhow::bail!("worker rank {} sanity failed [{kind}]: {message}", w.rank);
|
||||||
|
}
|
||||||
|
other => anyhow::bail!("worker rank {} sanity: unexpected {other:?}", w.rank),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tracing::info!(
|
||||||
|
world_size = expected,
|
||||||
|
"NCCL sanity check OK across all ranks"
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Ping every worker and return their Pong payloads in rank order.
|
||||||
|
/// Useful right after `spawn` to confirm the lifecycle plumbing is
|
||||||
|
/// intact before kicking off any heavier work.
|
||||||
|
pub async fn ping_all(&mut self) -> Result<Vec<WorkerResponse>> {
|
||||||
|
let mut out = Vec::with_capacity(self.workers.len());
|
||||||
|
for w in &mut self.workers {
|
||||||
|
let resp = w.request(&WorkerRequest::Ping).await?;
|
||||||
|
match &resp {
|
||||||
|
WorkerResponse::Pong { rank, .. } if *rank == w.rank => {}
|
||||||
|
WorkerResponse::Pong { rank, .. } => {
|
||||||
|
anyhow::bail!("rank mismatch: expected {}, got {rank}", w.rank);
|
||||||
|
}
|
||||||
|
other => anyhow::bail!("expected Pong from rank {}, got {other:?}", w.rank),
|
||||||
|
}
|
||||||
|
out.push(resp);
|
||||||
|
}
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send `Shutdown` to every worker, await each `Bye`, and reap the
|
||||||
|
/// children. Best-effort — individual worker failures are logged
|
||||||
|
/// but don't abort the rest of the sweep.
|
||||||
|
pub async fn shutdown(mut self) -> Result<()> {
|
||||||
|
for w in &mut self.workers {
|
||||||
|
match w.request(&WorkerRequest::Shutdown).await {
|
||||||
|
Ok(WorkerResponse::Bye) => {}
|
||||||
|
Ok(other) => tracing::warn!(
|
||||||
|
rank = w.rank,
|
||||||
|
response = ?other,
|
||||||
|
"expected Bye on shutdown"
|
||||||
|
),
|
||||||
|
Err(e) => tracing::warn!(rank = w.rank, error = %e, "shutdown request failed"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for w in &mut self.workers {
|
||||||
|
match w.child.wait().await {
|
||||||
|
Ok(status) => tracing::info!(rank = w.rank, %status, "worker exited"),
|
||||||
|
Err(e) => tracing::warn!(rank = w.rank, error = %e, "wait on worker failed"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn world_size(&self) -> u32 {
|
||||||
|
self.world_size
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn binary_path(&self) -> &PathBuf {
|
||||||
|
&self.exe
|
||||||
|
}
|
||||||
|
}
|
||||||
243
crates/neuron/src/harness/tp/nccl_state.rs
Normal file
243
crates/neuron/src/harness/tp/nccl_state.rs
Normal file
@@ -0,0 +1,243 @@
|
|||||||
|
//! NCCL state held by both the worker process and the leader's pool.
|
||||||
|
//!
|
||||||
|
//! Split into its own module so the worker (`tp/worker.rs`) and the
|
||||||
|
//! leader (`tp/mod.rs`) share the same hex-encoding/decoding code and
|
||||||
|
//! the same shape of `Option<Comm>` state machine.
|
||||||
|
//!
|
||||||
|
//! When the `cuda` feature is off, `NcclState` is a zero-sized
|
||||||
|
//! placeholder that returns `Error{kind="cuda_feature_not_enabled"}`
|
||||||
|
//! from every operation. When it's on, the same struct holds the
|
||||||
|
//! actual `cudarc::nccl::Comm`.
|
||||||
|
|
||||||
|
use super::rpc::WorkerResponse;
|
||||||
|
use super::worker::WorkerConfig;
|
||||||
|
|
||||||
|
/// Encode bytes as lowercase hex. Used for ferrying NCCL `Id::internal()`
|
||||||
|
/// across the leader→worker RPC boundary inside a JSON string.
|
||||||
|
pub fn encode_hex(bytes: &[u8]) -> String {
|
||||||
|
let mut out = String::with_capacity(bytes.len() * 2);
|
||||||
|
for b in bytes {
|
||||||
|
use std::fmt::Write;
|
||||||
|
let _ = write!(out, "{b:02x}");
|
||||||
|
}
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decode lowercase-or-uppercase hex into bytes. Errors on odd length
|
||||||
|
/// or non-hex characters; the caller bubbles those up via the RPC's
|
||||||
|
/// `Error{kind="bad_request"}` variant.
|
||||||
|
pub fn decode_hex(s: &str) -> Result<Vec<u8>, String> {
|
||||||
|
if !s.len().is_multiple_of(2) {
|
||||||
|
return Err(format!("hex string has odd length {}", s.len()));
|
||||||
|
}
|
||||||
|
(0..s.len())
|
||||||
|
.step_by(2)
|
||||||
|
.map(|i| {
|
||||||
|
u8::from_str_radix(&s[i..i + 2], 16).map_err(|e| format!("bad hex byte at {i}: {e}"))
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
pub struct NcclState;
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
impl Default for NcclState {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
impl NcclState {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn init(&mut self, _cfg: WorkerConfig, _comm_id_hex: &str) -> WorkerResponse {
|
||||||
|
WorkerResponse::Error {
|
||||||
|
kind: "cuda_feature_not_enabled".into(),
|
||||||
|
message: "this neuron binary was built without --features cuda; \
|
||||||
|
NCCL Init requires CUDA"
|
||||||
|
.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn sanity_check(&mut self) -> WorkerResponse {
|
||||||
|
WorkerResponse::Error {
|
||||||
|
kind: "cuda_feature_not_enabled".into(),
|
||||||
|
message: "NCCL sanity check requires --features cuda".into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
mod cuda_impl {
|
||||||
|
use super::*;
|
||||||
|
use cudarc::driver::CudaContext;
|
||||||
|
use cudarc::nccl::{Comm, Id, ReduceOp};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
/// Number of bytes in NCCL's unique-id type; matches `Id::internal()`'s
|
||||||
|
/// `[c_char; 128]`. Wire-encoded as 256 lowercase hex chars.
|
||||||
|
const NCCL_ID_BYTES: usize = 128;
|
||||||
|
|
||||||
|
pub struct NcclState {
|
||||||
|
comm: Option<Comm>,
|
||||||
|
/// Held alongside the Comm so the device isn't dropped
|
||||||
|
/// underneath the NCCL handle.
|
||||||
|
#[allow(dead_code)]
|
||||||
|
ctx: Option<Arc<CudaContext>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for NcclState {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl NcclState {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
comm: None,
|
||||||
|
ctx: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SAFETY: `cudarc::nccl::Comm` contains a raw `ncclComm_t` pointer
|
||||||
|
// (libnccl-allocated state). NCCL requires that operations against
|
||||||
|
// one Comm be issued one at a time; we serialise access by storing
|
||||||
|
// NcclState behind a Mutex in `WorkerPool`. The Comm itself is
|
||||||
|
// move-safe — NCCL doesn't track the calling OS thread, only the
|
||||||
|
// stream the operations are dispatched against.
|
||||||
|
unsafe impl Send for NcclState {}
|
||||||
|
unsafe impl Sync for NcclState {}
|
||||||
|
|
||||||
|
/// Generate a fresh NCCL `Id` and return it hex-encoded. Used by
|
||||||
|
/// the leader to mint the shared communicator id which is then
|
||||||
|
/// broadcast to every worker via the RPC `Init` message.
|
||||||
|
pub fn generate_comm_id_hex() -> Result<String, String> {
|
||||||
|
// NcclError lacks a Display impl in cudarc 0.19.x — surface
|
||||||
|
// via Debug throughout this module.
|
||||||
|
let id = Id::new().map_err(|e| format!("Id::new(): {e:?}"))?;
|
||||||
|
let bytes_u8: [u8; NCCL_ID_BYTES] = std::array::from_fn(|i| id.internal()[i] as u8);
|
||||||
|
Ok(encode_hex(&bytes_u8))
|
||||||
|
}
|
||||||
|
|
||||||
|
impl NcclState {
|
||||||
|
pub fn init(&mut self, cfg: WorkerConfig, comm_id_hex: &str) -> WorkerResponse {
|
||||||
|
match try_init(self, cfg, comm_id_hex) {
|
||||||
|
Ok(()) => WorkerResponse::InitOk,
|
||||||
|
Err(msg) => WorkerResponse::Error {
|
||||||
|
kind: "nccl_init_failed".into(),
|
||||||
|
message: msg,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn sanity_check(&mut self) -> WorkerResponse {
|
||||||
|
let Some(comm) = self.comm.as_ref() else {
|
||||||
|
return WorkerResponse::Error {
|
||||||
|
kind: "nccl_not_initialised".into(),
|
||||||
|
message: "sanity_check requires Init to have completed first".into(),
|
||||||
|
};
|
||||||
|
};
|
||||||
|
match try_sanity_check(comm) {
|
||||||
|
Ok(sum) => WorkerResponse::NcclSanityResult { observed_sum: sum },
|
||||||
|
Err(msg) => WorkerResponse::Error {
|
||||||
|
kind: "nccl_sanity_failed".into(),
|
||||||
|
message: msg,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn try_init(state: &mut NcclState, cfg: WorkerConfig, comm_id_hex: &str) -> Result<(), String> {
|
||||||
|
let bytes = decode_hex(comm_id_hex)?;
|
||||||
|
if bytes.len() != NCCL_ID_BYTES {
|
||||||
|
return Err(format!(
|
||||||
|
"comm_id is {} bytes, expected {NCCL_ID_BYTES}",
|
||||||
|
bytes.len()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
let id_bytes: [std::ffi::c_char; NCCL_ID_BYTES] =
|
||||||
|
std::array::from_fn(|i| bytes[i] as std::ffi::c_char);
|
||||||
|
let id = Id::uninit(id_bytes);
|
||||||
|
|
||||||
|
let ctx = CudaContext::new(cfg.cuda_device as usize)
|
||||||
|
.map_err(|e| format!("CudaContext::new({}) failed: {e}", cfg.cuda_device))?;
|
||||||
|
let stream = ctx.default_stream();
|
||||||
|
let comm = Comm::from_rank(stream, cfg.rank as usize, cfg.world_size as usize, id)
|
||||||
|
.map_err(|e| {
|
||||||
|
format!(
|
||||||
|
"Comm::from_rank(rank={}, world={}) failed: {e:?}",
|
||||||
|
cfg.rank, cfg.world_size
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
state.ctx = Some(ctx);
|
||||||
|
state.comm = Some(comm);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn try_sanity_check(comm: &Comm) -> Result<u32, String> {
|
||||||
|
let stream = comm.stream().clone();
|
||||||
|
let input = stream
|
||||||
|
.clone_htod(&[1u32])
|
||||||
|
.map_err(|e| format!("htod sentinel: {e}"))?;
|
||||||
|
let mut output = stream
|
||||||
|
.alloc_zeros::<u32>(1)
|
||||||
|
.map_err(|e| format!("alloc output: {e}"))?;
|
||||||
|
// cudarc::nccl::NcclError doesn't impl Display in 0.19.x —
|
||||||
|
// surface via Debug so we still see the variant + ncclResult
|
||||||
|
// code instead of a generic "{e}" failure.
|
||||||
|
comm.all_reduce(&input, &mut output, &ReduceOp::Sum)
|
||||||
|
.map_err(|e| format!("all_reduce: {e:?}"))?;
|
||||||
|
let result = stream
|
||||||
|
.clone_dtoh(&output)
|
||||||
|
.map_err(|e| format!("dtoh result: {e}"))?;
|
||||||
|
Ok(result[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub use cuda_impl::{NcclState, generate_comm_id_hex};
|
||||||
|
|
||||||
|
/// Non-cuda stub for the leader: returns a clear marker error rather
|
||||||
|
/// than letting `init_nccl` succeed vacuously.
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
pub fn generate_comm_id_hex() -> Result<String, String> {
|
||||||
|
Err("cuda_feature_not_enabled: build with --features cuda".into())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn hex_roundtrip() {
|
||||||
|
let original: Vec<u8> = (0u8..=255).collect();
|
||||||
|
let encoded = encode_hex(&original);
|
||||||
|
assert_eq!(encoded.len(), 512);
|
||||||
|
let decoded = decode_hex(&encoded).expect("decode");
|
||||||
|
assert_eq!(decoded, original);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn hex_decode_rejects_odd_length() {
|
||||||
|
assert!(decode_hex("a").is_err());
|
||||||
|
assert!(decode_hex("abc").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn hex_decode_rejects_non_hex() {
|
||||||
|
assert!(decode_hex("zz").is_err());
|
||||||
|
assert!(decode_hex("ab_d").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn hex_encode_is_lowercase_padded() {
|
||||||
|
assert_eq!(encode_hex(&[0x0a, 0xff]), "0aff");
|
||||||
|
}
|
||||||
|
}
|
||||||
186
crates/neuron/src/harness/tp/rpc.rs
Normal file
186
crates/neuron/src/harness/tp/rpc.rs
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
//! Wire protocol between the neuron leader process and its
|
||||||
|
//! `--worker` subprocesses.
|
||||||
|
//!
|
||||||
|
//! Every frame is one newline-delimited JSON object on the worker's
|
||||||
|
//! stdin (request) or stdout (response). Both directions are tagged
|
||||||
|
//! sum types from the start so new ops in Stage 7b/7c slot in without
|
||||||
|
//! breaking compatibility — no "14 message types and a version field"
|
||||||
|
//! drift later. Adding a new variant is the canonical way to evolve
|
||||||
|
//! the protocol; existing peers that don't recognise an op return
|
||||||
|
//! `WorkerResponse::Error { kind: "unknown_op", .. }`.
|
||||||
|
//!
|
||||||
|
//! The serialised shape uses `tag = "op"` so a request looks like:
|
||||||
|
//! {"op":"ping"}
|
||||||
|
//! {"op":"init","comm_id":"a1b2..."}
|
||||||
|
//! and a response:
|
||||||
|
//! {"op":"pong","rank":0,"world_size":2,"cuda_device":0}
|
||||||
|
//! {"op":"error","kind":"nccl_init_failed","message":"..."}
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
/// Leader → worker. Worker handles one at a time; replies with exactly
|
||||||
|
/// one `WorkerResponse` per request.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(tag = "op", rename_all = "snake_case")]
|
||||||
|
pub enum WorkerRequest {
|
||||||
|
/// Liveness probe. Worker replies with `Pong` containing its own
|
||||||
|
/// identity. Used by the leader to confirm the subprocess is up
|
||||||
|
/// and ready before kicking off any heavier work.
|
||||||
|
Ping,
|
||||||
|
|
||||||
|
/// One-shot NCCL communicator setup. The leader generates the
|
||||||
|
/// `comm_id` once (rank 0 of NCCL), broadcasts it to every worker
|
||||||
|
/// via this message, then every rank (leader included) calls
|
||||||
|
/// `Comm::from_rank` with the same id — NCCL blocks until all
|
||||||
|
/// `world_size` ranks check in. The hex-encoded bytes are the
|
||||||
|
/// canonical `cudarc::nccl::Id::internal()` content.
|
||||||
|
Init {
|
||||||
|
/// Hex-encoded NCCL id bytes (128 bytes → 256 hex chars).
|
||||||
|
comm_id: String,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Sanity check: after Init, every rank runs an `all_reduce` over
|
||||||
|
/// a sentinel value (`1u32`). The expected sum is `world_size`.
|
||||||
|
/// Worker replies with the observed value so the leader can verify
|
||||||
|
/// the NCCL handshake is genuinely live, not just configured.
|
||||||
|
NcclSanityCheck,
|
||||||
|
|
||||||
|
/// Worker should release resources and exit. Worker replies `Bye`
|
||||||
|
/// and then closes stdout / exits zero. The leader reaps the
|
||||||
|
/// child via the `tokio::process::Child` it kept.
|
||||||
|
Shutdown,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Worker → leader. Always exactly one of these per `WorkerRequest`.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(tag = "op", rename_all = "snake_case")]
|
||||||
|
pub enum WorkerResponse {
|
||||||
|
/// Reply to `Ping`. Carries enough identity for the leader to log
|
||||||
|
/// what it actually got back.
|
||||||
|
Pong {
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
cuda_device: u32,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Reply to `Init`. Empty payload — success is the absence of
|
||||||
|
/// `Error`. NCCL's internal blocking handshake means by the time
|
||||||
|
/// this comes back, every other rank has also reached
|
||||||
|
/// `Comm::from_rank`.
|
||||||
|
InitOk,
|
||||||
|
|
||||||
|
/// Reply to `NcclSanityCheck`. The observed sum after a single
|
||||||
|
/// `all_reduce(SUM, 1u32)` across all ranks. The leader checks
|
||||||
|
/// this matches `world_size`.
|
||||||
|
NcclSanityResult { observed_sum: u32 },
|
||||||
|
|
||||||
|
/// Reply to `Shutdown`. Worker exits immediately after writing this.
|
||||||
|
Bye,
|
||||||
|
|
||||||
|
/// Any request can produce this instead of its dedicated success
|
||||||
|
/// variant. `kind` is a machine-readable category so the leader
|
||||||
|
/// can branch on failure mode without string-matching `message`.
|
||||||
|
Error {
|
||||||
|
/// Short tag — `nccl_init_failed`, `unknown_op`, etc.
|
||||||
|
kind: String,
|
||||||
|
/// Human-readable detail for logs.
|
||||||
|
message: String,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
fn roundtrip<T>(value: &T) -> T
|
||||||
|
where
|
||||||
|
T: Serialize + for<'de> Deserialize<'de>,
|
||||||
|
{
|
||||||
|
serde_json::from_str(&serde_json::to_string(value).expect("serialise"))
|
||||||
|
.expect("deserialise")
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn request_ping_round_trip() {
|
||||||
|
let req = WorkerRequest::Ping;
|
||||||
|
let wire = serde_json::to_string(&req).unwrap();
|
||||||
|
assert_eq!(wire, r#"{"op":"ping"}"#);
|
||||||
|
match roundtrip(&req) {
|
||||||
|
WorkerRequest::Ping => {}
|
||||||
|
other => panic!("expected Ping, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn request_init_carries_hex_id() {
|
||||||
|
let req = WorkerRequest::Init {
|
||||||
|
comm_id: "deadbeef".into(),
|
||||||
|
};
|
||||||
|
let wire = serde_json::to_string(&req).unwrap();
|
||||||
|
assert_eq!(wire, r#"{"op":"init","comm_id":"deadbeef"}"#);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn request_shutdown_round_trip() {
|
||||||
|
assert_eq!(
|
||||||
|
serde_json::to_string(&WorkerRequest::Shutdown).unwrap(),
|
||||||
|
r#"{"op":"shutdown"}"#
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn response_pong_round_trip() {
|
||||||
|
let resp = WorkerResponse::Pong {
|
||||||
|
rank: 1,
|
||||||
|
world_size: 4,
|
||||||
|
cuda_device: 1,
|
||||||
|
};
|
||||||
|
let wire = serde_json::to_string(&resp).unwrap();
|
||||||
|
assert!(wire.contains(r#""op":"pong""#));
|
||||||
|
assert!(wire.contains(r#""rank":1"#));
|
||||||
|
assert!(wire.contains(r#""world_size":4"#));
|
||||||
|
match roundtrip(&resp) {
|
||||||
|
WorkerResponse::Pong {
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
cuda_device,
|
||||||
|
} => {
|
||||||
|
assert_eq!(rank, 1);
|
||||||
|
assert_eq!(world_size, 4);
|
||||||
|
assert_eq!(cuda_device, 1);
|
||||||
|
}
|
||||||
|
other => panic!("expected Pong, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn response_error_carries_kind_and_message() {
|
||||||
|
let resp = WorkerResponse::Error {
|
||||||
|
kind: "nccl_init_failed".into(),
|
||||||
|
message: "could not bind device".into(),
|
||||||
|
};
|
||||||
|
let wire = serde_json::to_string(&resp).unwrap();
|
||||||
|
assert!(wire.contains(r#""op":"error""#));
|
||||||
|
assert!(wire.contains(r#""kind":"nccl_init_failed""#));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn response_sanity_result_round_trip() {
|
||||||
|
let resp = WorkerResponse::NcclSanityResult { observed_sum: 4 };
|
||||||
|
match roundtrip(&resp) {
|
||||||
|
WorkerResponse::NcclSanityResult { observed_sum } => {
|
||||||
|
assert_eq!(observed_sum, 4);
|
||||||
|
}
|
||||||
|
other => panic!("expected NcclSanityResult, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Unknown ops on the wire deserialise to an error rather than
|
||||||
|
/// silently mis-matching — confirms our `serde(tag = "op")`
|
||||||
|
/// configuration rejects unknowns instead of doing fuzzy matching.
|
||||||
|
#[test]
|
||||||
|
fn unknown_op_fails_to_parse() {
|
||||||
|
let result: Result<WorkerRequest, _> = serde_json::from_str(r#"{"op":"explode"}"#);
|
||||||
|
assert!(result.is_err(), "should reject unknown op, got {result:?}");
|
||||||
|
}
|
||||||
|
}
|
||||||
134
crates/neuron/src/harness/tp/tp_linear.rs
Normal file
134
crates/neuron/src/harness/tp/tp_linear.rs
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
//! Tensor-parallel linear layers built on candle's `ShardedVarBuilder`
|
||||||
|
//! and `Shard` sharding hints.
|
||||||
|
//!
|
||||||
|
//! candle reads only the rank's slice of each weight tensor from
|
||||||
|
//! safetensors via `view.slice(start..stop)` — no full-tensor host
|
||||||
|
//! materialisation. That's a memory-efficiency win over hand-rolled
|
||||||
|
//! "load full + narrow" sharding (which the earlier
|
||||||
|
//! `sharded_linear.rs` exploration demonstrated but didn't pay for).
|
||||||
|
//!
|
||||||
|
//! Two layer types:
|
||||||
|
//!
|
||||||
|
//! - [`ColumnParallelLinear`] — output-sharded; forward is a plain
|
||||||
|
//! local matmul. The downstream consumer either accepts a sharded
|
||||||
|
//! activation (next layer is also column-parallel) or all-gathers.
|
||||||
|
//! - [`RowParallelLinear`] — input-sharded; forward = local matmul
|
||||||
|
//! then `AllReduce` `CustomOp1` to sum partials across ranks.
|
||||||
|
//!
|
||||||
|
//! Both assume **no bias** — every Qwen3-family weight layout we
|
||||||
|
//! actually target (Qwen3, Qwen3-Coder, Qwen3.6 base, etc.) sets
|
||||||
|
//! `attention_bias=false` and the MLP layers are no-bias. Adding bias
|
||||||
|
//! support is mechanical when a future model needs it; the design
|
||||||
|
//! choice would be: column-parallel shards the bias along dim 0;
|
||||||
|
//! row-parallel holds the bias only on rank 0 so the post-`AllReduce`
|
||||||
|
//! sum carries it exactly once.
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use candle_core::{Module, Tensor};
|
||||||
|
use candle_nn::Linear;
|
||||||
|
use candle_nn::var_builder::{Shard, ShardedVarBuilder};
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
use super::all_reduce::AllReduce;
|
||||||
|
|
||||||
|
/// Helper to build a [`Shard`] hint for a given dimension.
|
||||||
|
pub(crate) fn shard(dim: usize, rank: u32, world_size: u32) -> Shard {
|
||||||
|
Shard {
|
||||||
|
dim,
|
||||||
|
rank: rank as usize,
|
||||||
|
world_size: world_size as usize,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Output-dim sharded linear (column-parallel). Holds a standard
|
||||||
|
/// `candle_nn::Linear` whose `weight` is the rank's slice of the full
|
||||||
|
/// `[out_features, in_features]` tensor along dim 0.
|
||||||
|
pub struct ColumnParallelLinear {
|
||||||
|
inner: Linear,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ColumnParallelLinear {
|
||||||
|
/// Load this rank's column-parallel slice from a
|
||||||
|
/// `ShardedVarBuilder`. The provided `vb` must already be `pp`-ed
|
||||||
|
/// to the layer's path (e.g. `vb.pp("model.layers.0.self_attn.q_proj")`).
|
||||||
|
pub fn load(vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
|
||||||
|
let weight = vb
|
||||||
|
.get_with_hints((), "weight", shard(0, rank, world_size))
|
||||||
|
.with_context(|| format!("load column-parallel '{}' weight", vb.prefix()))?;
|
||||||
|
Ok(Self {
|
||||||
|
inner: Linear::new(weight, None),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for ColumnParallelLinear {
|
||||||
|
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
|
||||||
|
self.inner.forward(x)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Input-dim sharded linear (row-parallel).
|
||||||
|
///
|
||||||
|
/// Holds a sharded `Linear` plus an `AllReduce` op the forward chains
|
||||||
|
/// after the local matmul to recover the full activation.
|
||||||
|
pub struct RowParallelLinear {
|
||||||
|
inner: Linear,
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
all_reduce: AllReduce,
|
||||||
|
/// Whether the AllReduce should run. Column-parallel ↔ row-parallel
|
||||||
|
/// is a pair: the column output is sharded, the row input is
|
||||||
|
/// sharded, and the AllReduce gives back the full output. For
|
||||||
|
/// `world_size = 1` the AllReduce is a no-op so we skip it.
|
||||||
|
needs_reduce: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RowParallelLinear {
|
||||||
|
/// Load this rank's row-parallel slice from a `ShardedVarBuilder`.
|
||||||
|
///
|
||||||
|
/// Under `cuda`, `comm` is the NCCL communicator the row-parallel
|
||||||
|
/// `AllReduce` runs against. On CPU builds the parameter is
|
||||||
|
/// elided — forward returns the partial sum, which is the *wrong*
|
||||||
|
/// answer for inference but lets us compile-check the model.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub fn load(
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
comm: std::sync::Arc<cudarc::nccl::Comm>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let weight = vb
|
||||||
|
.get_with_hints((), "weight", shard(1, rank, world_size))
|
||||||
|
.with_context(|| format!("load row-parallel '{}' weight", vb.prefix()))?;
|
||||||
|
Ok(Self {
|
||||||
|
inner: Linear::new(weight, None),
|
||||||
|
all_reduce: AllReduce::new(comm),
|
||||||
|
needs_reduce: world_size > 1,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
pub fn load(vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
|
||||||
|
let weight = vb
|
||||||
|
.get_with_hints((), "weight", shard(1, rank, world_size))
|
||||||
|
.with_context(|| format!("load row-parallel '{}' weight", vb.prefix()))?;
|
||||||
|
Ok(Self {
|
||||||
|
inner: Linear::new(weight, None),
|
||||||
|
needs_reduce: world_size > 1,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for RowParallelLinear {
|
||||||
|
/// Local matmul followed by an `AllReduce` (when `cuda` and
|
||||||
|
/// `world_size > 1`). On CPU or single-rank, returns the partial
|
||||||
|
/// output directly — which is *only* correct for `world_size == 1`.
|
||||||
|
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
|
||||||
|
let local = self.inner.forward(x)?;
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
if self.needs_reduce {
|
||||||
|
return local.apply_op1_no_bwd(&self.all_reduce);
|
||||||
|
}
|
||||||
|
let _ = self.needs_reduce;
|
||||||
|
Ok(local)
|
||||||
|
}
|
||||||
|
}
|
||||||
605
crates/neuron/src/harness/tp/tp_qwen3.rs
Normal file
605
crates/neuron/src/harness/tp/tp_qwen3.rs
Normal file
@@ -0,0 +1,605 @@
|
|||||||
|
//! Tensor-parallel Qwen3 dense model.
|
||||||
|
//!
|
||||||
|
//! Mirrors `candle_transformers::models::qwen3` structurally, but with:
|
||||||
|
//!
|
||||||
|
//! - Attention's `q_proj` / `k_proj` / `v_proj` as
|
||||||
|
//! [`ColumnParallelLinear`] (output sharded along the head dimension —
|
||||||
|
//! per-rank `num_heads = total/world_size`, ditto for kv heads).
|
||||||
|
//! - Attention's `o_proj` as [`RowParallelLinear`] (input sharded; the
|
||||||
|
//! trailing `AllReduce` recovers the full activation).
|
||||||
|
//! - MLP's `gate_proj` / `up_proj` as [`ColumnParallelLinear`] (sharded
|
||||||
|
//! along `intermediate_size`).
|
||||||
|
//! - MLP's `down_proj` as [`RowParallelLinear`].
|
||||||
|
//! - `embed_tokens`, all `RmsNorm`s, and `lm_head` **replicated** on
|
||||||
|
//! every rank. The per-rank duplicate weight is bounded and lets us
|
||||||
|
//! skip the embedding all-gather and the lm-head column-shard +
|
||||||
|
//! all-gather; both are pure latency optimisations that don't change
|
||||||
|
//! correctness.
|
||||||
|
//! - `kv_cache` holds the per-rank slice of K/V already (because they
|
||||||
|
//! came out of a column-parallel projection). No cache resharding
|
||||||
|
//! needed across ranks.
|
||||||
|
//!
|
||||||
|
//! Divisibility requirement, checked at load time:
|
||||||
|
//!
|
||||||
|
//! - `num_attention_heads % world_size == 0`
|
||||||
|
//! - `num_key_value_heads % world_size == 0`
|
||||||
|
//! - `intermediate_size % world_size == 0`
|
||||||
|
//!
|
||||||
|
//! Anything else bails — the safetensors slice would lose data otherwise.
|
||||||
|
//! This is the same divisibility-bail pattern that landed in
|
||||||
|
//! `EricLBuehler/mistral.rs` PR #2054.
|
||||||
|
//!
|
||||||
|
//! Replicated tensors (norms, embedding, lm_head) are loaded by asking
|
||||||
|
//! the `ShardedVarBuilder` for the full tensor via `vb.get(shape, name)`
|
||||||
|
//! — which defaults to `Shard { world_size: 1 }` and falls through to
|
||||||
|
//! the unsharded backend path.
|
||||||
|
|
||||||
|
use anyhow::{Context, Result, bail};
|
||||||
|
use candle_core::{DType, Device, IndexOp, Module, Tensor};
|
||||||
|
use candle_nn::var_builder::ShardedVarBuilder;
|
||||||
|
use candle_nn::{Activation, Embedding, Linear, RmsNorm, kv_cache::ConcatKvCache};
|
||||||
|
use candle_transformers::utils::repeat_kv;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
use cudarc::nccl::Comm;
|
||||||
|
|
||||||
|
use super::tp_linear::{ColumnParallelLinear, RowParallelLinear};
|
||||||
|
|
||||||
|
pub use candle_transformers::models::qwen3::Config;
|
||||||
|
|
||||||
|
/// Replicated rotary-embedding lookup. Re-implementation of the
|
||||||
|
/// `pub(crate)` candle equivalent — we can't reach into the upstream
|
||||||
|
/// type, so the inv-freq / sin / cos construction lives here.
|
||||||
|
pub(crate) struct Qwen3RotaryEmbedding {
|
||||||
|
sin: Tensor,
|
||||||
|
cos: Tensor,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Qwen3RotaryEmbedding {
|
||||||
|
pub(crate) fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
|
||||||
|
let dim = cfg.head_dim;
|
||||||
|
let max_seq_len = cfg.max_position_embeddings;
|
||||||
|
let inv_freq: Vec<_> = (0..dim)
|
||||||
|
.step_by(2)
|
||||||
|
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
|
||||||
|
.collect();
|
||||||
|
let inv_freq_len = inv_freq.len();
|
||||||
|
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
|
||||||
|
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.reshape((max_seq_len, 1))?;
|
||||||
|
let freqs = t.matmul(&inv_freq)?;
|
||||||
|
Ok(Self {
|
||||||
|
sin: freqs.sin()?.to_dtype(dtype)?,
|
||||||
|
cos: freqs.cos()?.to_dtype(dtype)?,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn apply(
|
||||||
|
&self,
|
||||||
|
q: &Tensor,
|
||||||
|
k: &Tensor,
|
||||||
|
offset: usize,
|
||||||
|
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
|
let (_, _, seq_len, _) = q.dims4()?;
|
||||||
|
let cos = self.cos.narrow(0, offset, seq_len)?;
|
||||||
|
let sin = self.sin.narrow(0, offset, seq_len)?;
|
||||||
|
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
|
||||||
|
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
|
||||||
|
Ok((q_embed, k_embed))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper: load a replicated tensor by asking the ShardedVarBuilder for
|
||||||
|
/// the full tensor (world_size=1 hint falls through to SimpleBackend).
|
||||||
|
fn load_replicated<S: Into<candle_core::Shape>>(
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
shape: S,
|
||||||
|
name: &str,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
vb.get(shape, name)
|
||||||
|
.with_context(|| format!("load replicated '{}/{name}'", vb.prefix()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_rms_norm(vb: &ShardedVarBuilder, size: usize, eps: f64) -> Result<RmsNorm> {
|
||||||
|
let weight = load_replicated(vb, size, "weight")?;
|
||||||
|
Ok(RmsNorm::new(weight, eps))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// TP MLP. SwiGLU = `down(silu(gate(x)) * up(x))`.
|
||||||
|
pub(crate) struct TpQwen3MLP {
|
||||||
|
gate_proj: ColumnParallelLinear,
|
||||||
|
up_proj: ColumnParallelLinear,
|
||||||
|
down_proj: RowParallelLinear,
|
||||||
|
act_fn: Activation,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TpQwen3MLP {
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub fn load(
|
||||||
|
cfg: &Config,
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
comm: Arc<Comm>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
if !cfg.intermediate_size.is_multiple_of(world_size as usize) {
|
||||||
|
bail!(
|
||||||
|
"intermediate_size {} not divisible by world_size {}",
|
||||||
|
cfg.intermediate_size,
|
||||||
|
world_size
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Ok(Self {
|
||||||
|
gate_proj: ColumnParallelLinear::load(&vb.pp("gate_proj"), rank, world_size)?,
|
||||||
|
up_proj: ColumnParallelLinear::load(&vb.pp("up_proj"), rank, world_size)?,
|
||||||
|
down_proj: RowParallelLinear::load(&vb.pp("down_proj"), rank, world_size, comm)?,
|
||||||
|
act_fn: cfg.hidden_act,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
pub fn load(cfg: &Config, vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
|
||||||
|
if !cfg.intermediate_size.is_multiple_of(world_size as usize) {
|
||||||
|
bail!(
|
||||||
|
"intermediate_size {} not divisible by world_size {}",
|
||||||
|
cfg.intermediate_size,
|
||||||
|
world_size
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Ok(Self {
|
||||||
|
gate_proj: ColumnParallelLinear::load(&vb.pp("gate_proj"), rank, world_size)?,
|
||||||
|
up_proj: ColumnParallelLinear::load(&vb.pp("up_proj"), rank, world_size)?,
|
||||||
|
down_proj: RowParallelLinear::load(&vb.pp("down_proj"), rank, world_size)?,
|
||||||
|
act_fn: cfg.hidden_act,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for TpQwen3MLP {
|
||||||
|
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
|
||||||
|
let lhs = x.apply(&self.gate_proj)?.apply(&self.act_fn)?;
|
||||||
|
let rhs = x.apply(&self.up_proj)?;
|
||||||
|
(lhs * rhs)?.apply(&self.down_proj)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// TP attention. Carries per-rank head counts and the q/k per-head
|
||||||
|
/// RmsNorms (which are replicated and operate on a flattened B*H*L
|
||||||
|
/// axis, so the same code path works irrespective of how H was split).
|
||||||
|
pub(crate) struct TpQwen3Attention {
|
||||||
|
q_proj: ColumnParallelLinear,
|
||||||
|
k_proj: ColumnParallelLinear,
|
||||||
|
v_proj: ColumnParallelLinear,
|
||||||
|
o_proj: RowParallelLinear,
|
||||||
|
q_norm: RmsNorm,
|
||||||
|
k_norm: RmsNorm,
|
||||||
|
local_num_heads: usize,
|
||||||
|
local_num_kv_heads: usize,
|
||||||
|
num_kv_groups: usize,
|
||||||
|
head_dim: usize,
|
||||||
|
local_hidden_size: usize,
|
||||||
|
rotary_emb: Arc<Qwen3RotaryEmbedding>,
|
||||||
|
kv_cache: ConcatKvCache,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TpQwen3Attention {
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub fn load(
|
||||||
|
cfg: &Config,
|
||||||
|
rotary_emb: Arc<Qwen3RotaryEmbedding>,
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
comm: Arc<Comm>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Self::load_inner(
|
||||||
|
cfg,
|
||||||
|
rotary_emb,
|
||||||
|
vb,
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
comm,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
pub fn load(
|
||||||
|
cfg: &Config,
|
||||||
|
rotary_emb: Arc<Qwen3RotaryEmbedding>,
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Self::load_inner(cfg, rotary_emb, vb, rank, world_size)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_inner(
|
||||||
|
cfg: &Config,
|
||||||
|
rotary_emb: Arc<Qwen3RotaryEmbedding>,
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
#[cfg(feature = "cuda")] comm: Arc<Comm>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
if cfg.use_sliding_window {
|
||||||
|
bail!("sliding window is not yet supported in the TP path");
|
||||||
|
}
|
||||||
|
if cfg.attention_bias {
|
||||||
|
bail!("attention_bias=true is not supported by ColumnParallel/RowParallelLinear yet");
|
||||||
|
}
|
||||||
|
let ws = world_size as usize;
|
||||||
|
if !cfg.num_attention_heads.is_multiple_of(ws) {
|
||||||
|
bail!(
|
||||||
|
"num_attention_heads {} not divisible by world_size {}",
|
||||||
|
cfg.num_attention_heads,
|
||||||
|
world_size
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if !cfg.num_key_value_heads.is_multiple_of(ws) {
|
||||||
|
bail!(
|
||||||
|
"num_key_value_heads {} not divisible by world_size {}",
|
||||||
|
cfg.num_key_value_heads,
|
||||||
|
world_size
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let head_dim = cfg.head_dim;
|
||||||
|
let local_num_heads = cfg.num_attention_heads / ws;
|
||||||
|
let local_num_kv_heads = cfg.num_key_value_heads / ws;
|
||||||
|
let num_kv_groups = local_num_heads / local_num_kv_heads;
|
||||||
|
let local_hidden_size = head_dim * local_num_heads;
|
||||||
|
|
||||||
|
let q_proj = ColumnParallelLinear::load(&vb.pp("q_proj"), rank, world_size)?;
|
||||||
|
let k_proj = ColumnParallelLinear::load(&vb.pp("k_proj"), rank, world_size)?;
|
||||||
|
let v_proj = ColumnParallelLinear::load(&vb.pp("v_proj"), rank, world_size)?;
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
let o_proj = RowParallelLinear::load(&vb.pp("o_proj"), rank, world_size, comm)?;
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
let o_proj = RowParallelLinear::load(&vb.pp("o_proj"), rank, world_size)?;
|
||||||
|
|
||||||
|
let q_norm = load_rms_norm(&vb.pp("q_norm"), head_dim, cfg.rms_norm_eps)?;
|
||||||
|
let k_norm = load_rms_norm(&vb.pp("k_norm"), head_dim, cfg.rms_norm_eps)?;
|
||||||
|
|
||||||
|
// dim=2 because we cat along the seq axis of (B, H, L, D) tensors.
|
||||||
|
let kv_cache = ConcatKvCache::new(2);
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
q_proj,
|
||||||
|
k_proj,
|
||||||
|
v_proj,
|
||||||
|
o_proj,
|
||||||
|
q_norm,
|
||||||
|
k_norm,
|
||||||
|
local_num_heads,
|
||||||
|
local_num_kv_heads,
|
||||||
|
num_kv_groups,
|
||||||
|
head_dim,
|
||||||
|
local_hidden_size,
|
||||||
|
rotary_emb,
|
||||||
|
kv_cache,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(
|
||||||
|
&mut self,
|
||||||
|
x: &Tensor,
|
||||||
|
attn_mask: Option<&Tensor>,
|
||||||
|
offset: usize,
|
||||||
|
) -> candle_core::Result<Tensor> {
|
||||||
|
let (b, l, _) = x.dims3()?;
|
||||||
|
|
||||||
|
// 1. Projections (column-parallel → output is sharded).
|
||||||
|
let q = self.q_proj.forward(x)?;
|
||||||
|
let k = self.k_proj.forward(x)?;
|
||||||
|
let v = self.v_proj.forward(x)?;
|
||||||
|
|
||||||
|
// 2. Reshape: (B, L, H, D) → (B, H, L, D).
|
||||||
|
let q = q
|
||||||
|
.reshape((b, l, self.local_num_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?;
|
||||||
|
let k = k
|
||||||
|
.reshape((b, l, self.local_num_kv_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?;
|
||||||
|
let v = v
|
||||||
|
.reshape((b, l, self.local_num_kv_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?;
|
||||||
|
|
||||||
|
// 3. Per-head RmsNorm (replicated weight, flat input).
|
||||||
|
let q_flat = q.flatten(0, 2)?;
|
||||||
|
let k_flat = k.flatten(0, 2)?;
|
||||||
|
let q_flat = self.q_norm.forward(&q_flat)?;
|
||||||
|
let k_flat = self.k_norm.forward(&k_flat)?;
|
||||||
|
let q = q_flat.reshape((b, self.local_num_heads, l, self.head_dim))?;
|
||||||
|
let k = k_flat.reshape((b, self.local_num_kv_heads, l, self.head_dim))?;
|
||||||
|
|
||||||
|
// 4. Rotary.
|
||||||
|
let (q, k) = self.rotary_emb.apply(&q, &k, offset)?;
|
||||||
|
|
||||||
|
// 5. Accumulate KV.
|
||||||
|
let (k, v) = self.kv_cache.append(&k, &v)?;
|
||||||
|
|
||||||
|
// 6. GQA repeat_kv on the rank-local K/V.
|
||||||
|
let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
|
||||||
|
let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;
|
||||||
|
|
||||||
|
// 7. Attention scores.
|
||||||
|
let scale = 1.0 / (self.head_dim as f64).sqrt();
|
||||||
|
let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
|
||||||
|
if let Some(m) = attn_mask {
|
||||||
|
scores = scores.broadcast_add(m)?;
|
||||||
|
}
|
||||||
|
let probs = candle_nn::ops::softmax_last_dim(&scores)?;
|
||||||
|
let ctx = probs.matmul(&v)?;
|
||||||
|
|
||||||
|
// 8. Output projection (row-parallel → AllReduce inside).
|
||||||
|
ctx.transpose(1, 2)?
|
||||||
|
.reshape((b, l, self.local_hidden_size))?
|
||||||
|
.apply(&self.o_proj)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
self.kv_cache.reset();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct TpDecoderLayer {
|
||||||
|
self_attn: TpQwen3Attention,
|
||||||
|
mlp: TpQwen3MLP,
|
||||||
|
ln1: RmsNorm,
|
||||||
|
ln2: RmsNorm,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TpDecoderLayer {
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
fn load(
|
||||||
|
cfg: &Config,
|
||||||
|
rotary_emb: Arc<Qwen3RotaryEmbedding>,
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
comm: Arc<Comm>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let self_attn = TpQwen3Attention::load(
|
||||||
|
cfg,
|
||||||
|
rotary_emb,
|
||||||
|
&vb.pp("self_attn"),
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
comm.clone(),
|
||||||
|
)?;
|
||||||
|
let mlp = TpQwen3MLP::load(cfg, &vb.pp("mlp"), rank, world_size, comm)?;
|
||||||
|
let ln1 = load_rms_norm(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?;
|
||||||
|
let ln2 = load_rms_norm(
|
||||||
|
&vb.pp("post_attention_layernorm"),
|
||||||
|
cfg.hidden_size,
|
||||||
|
cfg.rms_norm_eps,
|
||||||
|
)?;
|
||||||
|
Ok(Self {
|
||||||
|
self_attn,
|
||||||
|
mlp,
|
||||||
|
ln1,
|
||||||
|
ln2,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
fn load(
|
||||||
|
cfg: &Config,
|
||||||
|
rotary_emb: Arc<Qwen3RotaryEmbedding>,
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let self_attn =
|
||||||
|
TpQwen3Attention::load(cfg, rotary_emb, &vb.pp("self_attn"), rank, world_size)?;
|
||||||
|
let mlp = TpQwen3MLP::load(cfg, &vb.pp("mlp"), rank, world_size)?;
|
||||||
|
let ln1 = load_rms_norm(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?;
|
||||||
|
let ln2 = load_rms_norm(
|
||||||
|
&vb.pp("post_attention_layernorm"),
|
||||||
|
cfg.hidden_size,
|
||||||
|
cfg.rms_norm_eps,
|
||||||
|
)?;
|
||||||
|
Ok(Self {
|
||||||
|
self_attn,
|
||||||
|
mlp,
|
||||||
|
ln1,
|
||||||
|
ln2,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(
|
||||||
|
&mut self,
|
||||||
|
x: &Tensor,
|
||||||
|
mask: Option<&Tensor>,
|
||||||
|
offset: usize,
|
||||||
|
) -> candle_core::Result<Tensor> {
|
||||||
|
let h = self.ln1.forward(x)?;
|
||||||
|
let h = self.self_attn.forward(&h, mask, offset)?;
|
||||||
|
let x = (x + h)?;
|
||||||
|
let h2 = self.ln2.forward(&x)?;
|
||||||
|
let h2 = h2.apply(&self.mlp)?;
|
||||||
|
x + h2
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clear_kv_cache(&mut self) {
|
||||||
|
self.self_attn.clear_kv_cache();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Base TP Qwen3 transformer — embedding, decoder stack, final norm.
|
||||||
|
/// The lm_head sits on top in [`TpQwen3ForCausalLM`].
|
||||||
|
pub struct TpQwen3Model {
|
||||||
|
embed_tokens: Embedding,
|
||||||
|
layers: Vec<TpDecoderLayer>,
|
||||||
|
norm: RmsNorm,
|
||||||
|
device: Device,
|
||||||
|
dtype: DType,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TpQwen3Model {
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub fn load(
|
||||||
|
cfg: &Config,
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
comm: Arc<Comm>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let dtype = vb.dtype();
|
||||||
|
let device = vb.device().clone();
|
||||||
|
let rotary = Arc::new(Qwen3RotaryEmbedding::new(dtype, cfg, &device)?);
|
||||||
|
|
||||||
|
let embed_vb = vb.pp("model.embed_tokens");
|
||||||
|
let embed_weight = load_replicated(&embed_vb, (cfg.vocab_size, cfg.hidden_size), "weight")?;
|
||||||
|
let embed_tokens = Embedding::new(embed_weight, cfg.hidden_size);
|
||||||
|
|
||||||
|
let vb_l = vb.pp("model.layers");
|
||||||
|
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||||
|
for i in 0..cfg.num_hidden_layers {
|
||||||
|
layers.push(TpDecoderLayer::load(
|
||||||
|
cfg,
|
||||||
|
rotary.clone(),
|
||||||
|
&vb_l.pp(i),
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
comm.clone(),
|
||||||
|
)?);
|
||||||
|
}
|
||||||
|
let norm = load_rms_norm(&vb.pp("model.norm"), cfg.hidden_size, cfg.rms_norm_eps)?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
embed_tokens,
|
||||||
|
layers,
|
||||||
|
norm,
|
||||||
|
device,
|
||||||
|
dtype,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
pub fn load(cfg: &Config, vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
|
||||||
|
let dtype = vb.dtype();
|
||||||
|
let device = vb.device().clone();
|
||||||
|
let rotary = Arc::new(Qwen3RotaryEmbedding::new(dtype, cfg, &device)?);
|
||||||
|
|
||||||
|
let embed_vb = vb.pp("model.embed_tokens");
|
||||||
|
let embed_weight = load_replicated(&embed_vb, (cfg.vocab_size, cfg.hidden_size), "weight")?;
|
||||||
|
let embed_tokens = Embedding::new(embed_weight, cfg.hidden_size);
|
||||||
|
|
||||||
|
let vb_l = vb.pp("model.layers");
|
||||||
|
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||||
|
for i in 0..cfg.num_hidden_layers {
|
||||||
|
layers.push(TpDecoderLayer::load(
|
||||||
|
cfg,
|
||||||
|
rotary.clone(),
|
||||||
|
&vb_l.pp(i),
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
)?);
|
||||||
|
}
|
||||||
|
let norm = load_rms_norm(&vb.pp("model.norm"), cfg.hidden_size, cfg.rms_norm_eps)?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
embed_tokens,
|
||||||
|
layers,
|
||||||
|
norm,
|
||||||
|
device,
|
||||||
|
dtype,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn embed_weight(&self) -> &Tensor {
|
||||||
|
self.embed_tokens.embeddings()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
for l in &mut self.layers {
|
||||||
|
l.clear_kv_cache();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> candle_core::Result<Tensor> {
|
||||||
|
let minf = f32::NEG_INFINITY;
|
||||||
|
let mask: Vec<_> = (0..tgt)
|
||||||
|
.flat_map(|i| (0..(tgt + offset)).map(move |j| if j <= i + offset { 0. } else { minf }))
|
||||||
|
.collect();
|
||||||
|
Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
|
||||||
|
let (b, l) = input.dims2()?;
|
||||||
|
let mut h = self.embed_tokens.forward(input)?;
|
||||||
|
|
||||||
|
let causal = if l == 1 {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(self.causal_mask(b, l, offset)?)
|
||||||
|
};
|
||||||
|
|
||||||
|
for layer in &mut self.layers {
|
||||||
|
h = layer.forward(&h, causal.as_ref(), offset)?;
|
||||||
|
}
|
||||||
|
self.norm.forward(&h)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// TP Qwen3 with a (replicated) language-model head on top.
|
||||||
|
pub struct TpQwen3ForCausalLM {
|
||||||
|
base: TpQwen3Model,
|
||||||
|
lm_head: Linear,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TpQwen3ForCausalLM {
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub fn load(
|
||||||
|
cfg: &Config,
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
comm: Arc<Comm>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let base = TpQwen3Model::load(cfg, vb, rank, world_size, comm)?;
|
||||||
|
let lm_head = build_lm_head(cfg, vb, &base)?;
|
||||||
|
Ok(Self { base, lm_head })
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
pub fn load(cfg: &Config, vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
|
||||||
|
let base = TpQwen3Model::load(cfg, vb, rank, world_size)?;
|
||||||
|
let lm_head = build_lm_head(cfg, vb, &base)?;
|
||||||
|
Ok(Self { base, lm_head })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
|
||||||
|
let (_, l) = input.dims2()?;
|
||||||
|
let hidden = self.base.forward(input, offset)?;
|
||||||
|
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
self.base.clear_kv_cache();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn device(&self) -> &Device {
|
||||||
|
&self.base.device
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn dtype(&self) -> DType {
|
||||||
|
self.base.dtype
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_lm_head(cfg: &Config, vb: &ShardedVarBuilder, base: &TpQwen3Model) -> Result<Linear> {
|
||||||
|
if cfg.tie_word_embeddings {
|
||||||
|
Ok(Linear::new(base.embed_weight().clone(), None))
|
||||||
|
} else {
|
||||||
|
let weight = load_replicated(
|
||||||
|
&vb.pp("lm_head"),
|
||||||
|
(cfg.vocab_size, cfg.hidden_size),
|
||||||
|
"weight",
|
||||||
|
)?;
|
||||||
|
Ok(Linear::new(weight, None))
|
||||||
|
}
|
||||||
|
}
|
||||||
102
crates/neuron/src/harness/tp/worker.rs
Normal file
102
crates/neuron/src/harness/tp/worker.rs
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
//! Entry point for `neuron --worker`.
|
||||||
|
//!
|
||||||
|
//! The worker reads one newline-delimited JSON `WorkerRequest` from
|
||||||
|
//! stdin per loop iteration, dispatches synchronously, and writes
|
||||||
|
//! exactly one `WorkerResponse` JSON line to stdout. tracing goes to
|
||||||
|
//! stderr so it doesn't collide with the RPC stream.
|
||||||
|
//!
|
||||||
|
//! NCCL operations (`Init`, `NcclSanityCheck`) are real when built
|
||||||
|
//! with the `cuda` feature; without it they reply with
|
||||||
|
//! `Error{kind="cuda_feature_not_enabled"}` so the leader can tell
|
||||||
|
//! the difference between a misconfigured build and a genuine NCCL
|
||||||
|
//! failure.
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||||
|
|
||||||
|
use super::nccl_state::NcclState;
|
||||||
|
use super::rpc::{WorkerRequest, WorkerResponse};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct WorkerConfig {
|
||||||
|
pub rank: u32,
|
||||||
|
pub world_size: u32,
|
||||||
|
pub cuda_device: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Drive the worker RPC loop until `Shutdown` or EOF on stdin.
|
||||||
|
pub async fn run(config: WorkerConfig) -> Result<()> {
|
||||||
|
tracing::info!(
|
||||||
|
rank = config.rank,
|
||||||
|
world_size = config.world_size,
|
||||||
|
cuda_device = config.cuda_device,
|
||||||
|
"tp worker starting"
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut state = WorkerState::new(config);
|
||||||
|
let stdin = tokio::io::stdin();
|
||||||
|
let mut reader = BufReader::new(stdin).lines();
|
||||||
|
let mut stdout = tokio::io::stdout();
|
||||||
|
|
||||||
|
while let Some(line) = reader.next_line().await? {
|
||||||
|
if line.trim().is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let req: WorkerRequest = match serde_json::from_str(&line) {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => {
|
||||||
|
let resp = WorkerResponse::Error {
|
||||||
|
kind: "bad_request".into(),
|
||||||
|
message: format!("parse {line:?}: {e}"),
|
||||||
|
};
|
||||||
|
write_response(&mut stdout, &resp).await?;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let resp = state.handle(req).await;
|
||||||
|
let is_bye = matches!(resp, WorkerResponse::Bye);
|
||||||
|
write_response(&mut stdout, &resp).await?;
|
||||||
|
if is_bye {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!(rank = config.rank, "tp worker exiting");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn write_response(stdout: &mut tokio::io::Stdout, resp: &WorkerResponse) -> Result<()> {
|
||||||
|
let mut line = serde_json::to_string(resp)?;
|
||||||
|
line.push('\n');
|
||||||
|
stdout.write_all(line.as_bytes()).await?;
|
||||||
|
stdout.flush().await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
struct WorkerState {
|
||||||
|
config: WorkerConfig,
|
||||||
|
nccl: NcclState,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WorkerState {
|
||||||
|
fn new(config: WorkerConfig) -> Self {
|
||||||
|
Self {
|
||||||
|
config,
|
||||||
|
nccl: NcclState::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle(&mut self, req: WorkerRequest) -> WorkerResponse {
|
||||||
|
match req {
|
||||||
|
WorkerRequest::Ping => WorkerResponse::Pong {
|
||||||
|
rank: self.config.rank,
|
||||||
|
world_size: self.config.world_size,
|
||||||
|
cuda_device: self.config.cuda_device,
|
||||||
|
},
|
||||||
|
WorkerRequest::Init { comm_id } => self.nccl.init(self.config, &comm_id),
|
||||||
|
WorkerRequest::NcclSanityCheck => self.nccl.sanity_check(),
|
||||||
|
WorkerRequest::Shutdown => WorkerResponse::Bye,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,3 +3,4 @@ pub mod config;
|
|||||||
pub mod discovery;
|
pub mod discovery;
|
||||||
pub mod harness;
|
pub mod harness;
|
||||||
pub mod health;
|
pub mod health;
|
||||||
|
pub mod startup;
|
||||||
|
|||||||
@@ -1,21 +1,52 @@
|
|||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use neuron::{api, config::NeuronConfig, discovery, harness::HarnessRegistry, health};
|
use neuron::{
|
||||||
|
api,
|
||||||
|
config::NeuronConfig,
|
||||||
|
discovery,
|
||||||
|
harness::{HarnessRegistry, tp},
|
||||||
|
health, startup,
|
||||||
|
};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
use tracing_subscriber::EnvFilter;
|
use tracing_subscriber::EnvFilter;
|
||||||
|
|
||||||
|
/// Top-level CLI. The same binary runs as either the public neuron
|
||||||
|
/// daemon (default) or a tensor-parallel worker subprocess (when
|
||||||
|
/// `--worker` is set) spawned by the leader on the same host.
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
#[command(name = "neuron")]
|
#[command(name = "neuron")]
|
||||||
#[command(about = "Per-node daemon for cortex inference clusters")]
|
#[command(about = "Per-node daemon for cortex inference clusters")]
|
||||||
#[command(version)]
|
#[command(version)]
|
||||||
struct Args {
|
struct Args {
|
||||||
/// Port to listen on (overrides config file).
|
/// Run in tensor-parallel worker mode. The leader process spawns
|
||||||
|
/// one of these per non-zero NCCL rank and drives it over
|
||||||
|
/// newline-delimited JSON on stdin/stdout. Worker mode skips
|
||||||
|
/// discovery, the HTTP listener, and the health poller — it's a
|
||||||
|
/// pure RPC loop.
|
||||||
|
#[arg(long, default_value_t = false)]
|
||||||
|
worker: bool,
|
||||||
|
|
||||||
|
/// NCCL rank for worker mode. Ignored when `--worker` is not set.
|
||||||
|
#[arg(long, default_value_t = 0)]
|
||||||
|
rank: u32,
|
||||||
|
|
||||||
|
/// Total NCCL world size for worker mode. Ignored when `--worker`
|
||||||
|
/// is not set.
|
||||||
|
#[arg(long, default_value_t = 1)]
|
||||||
|
tp_size: u32,
|
||||||
|
|
||||||
|
/// CUDA device index for worker mode. Ignored when `--worker` is
|
||||||
|
/// not set.
|
||||||
|
#[arg(long, default_value_t = 0)]
|
||||||
|
cuda_device: u32,
|
||||||
|
|
||||||
|
/// Port to listen on (overrides config file). Daemon mode only.
|
||||||
#[arg(short, long)]
|
#[arg(short, long)]
|
||||||
port: Option<u16>,
|
port: Option<u16>,
|
||||||
|
|
||||||
/// Path to the neuron config file.
|
/// Path to the neuron config file. Daemon mode only.
|
||||||
#[arg(short, long, default_value = "neuron.toml")]
|
#[arg(short, long, default_value = "neuron.toml")]
|
||||||
config: String,
|
config: String,
|
||||||
}
|
}
|
||||||
@@ -23,6 +54,7 @@ struct Args {
|
|||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> Result<()> {
|
async fn main() -> Result<()> {
|
||||||
tracing_subscriber::fmt()
|
tracing_subscriber::fmt()
|
||||||
|
.with_writer(std::io::stderr)
|
||||||
.with_env_filter(
|
.with_env_filter(
|
||||||
EnvFilter::try_from_default_env()
|
EnvFilter::try_from_default_env()
|
||||||
.unwrap_or_else(|_| EnvFilter::new("info,neuron=debug")),
|
.unwrap_or_else(|_| EnvFilter::new("info,neuron=debug")),
|
||||||
@@ -31,12 +63,26 @@ async fn main() -> Result<()> {
|
|||||||
|
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
|
|
||||||
|
if args.worker {
|
||||||
|
return tp::worker::run(tp::worker::WorkerConfig {
|
||||||
|
rank: args.rank,
|
||||||
|
world_size: args.tp_size,
|
||||||
|
cuda_device: args.cuda_device,
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
daemon(args).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn daemon(args: Args) -> Result<()> {
|
||||||
let cfg = NeuronConfig::load(&args.config).unwrap_or_else(|e| {
|
let cfg = NeuronConfig::load(&args.config).unwrap_or_else(|e| {
|
||||||
tracing::warn!(path = %args.config, error = %e, "config not found, using defaults");
|
tracing::warn!(path = %args.config, error = %e, "config not found, using defaults");
|
||||||
NeuronConfig::default()
|
NeuronConfig::default()
|
||||||
});
|
});
|
||||||
|
|
||||||
let port = args.port.unwrap_or(cfg.port);
|
let port = args.port.unwrap_or(cfg.port);
|
||||||
|
let bind_url = format!("http://localhost:{port}");
|
||||||
let start_time = Instant::now();
|
let start_time = Instant::now();
|
||||||
|
|
||||||
tracing::info!("running hardware discovery");
|
tracing::info!("running hardware discovery");
|
||||||
@@ -47,9 +93,18 @@ async fn main() -> Result<()> {
|
|||||||
"discovery complete"
|
"discovery complete"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Build harness registry from config.
|
// Build harness registry from config. In-process harnesses (candle)
|
||||||
let registry = HarnessRegistry::from_configs(&cfg.harnesses);
|
// need to know neuron's own bind URL so they can return it from
|
||||||
|
// inference_endpoint.
|
||||||
|
let registry = HarnessRegistry::from_configs(&cfg.harnesses, &bind_url, &cfg.harness);
|
||||||
discovery_result.harnesses = registry.names();
|
discovery_result.harnesses = registry.names();
|
||||||
|
let candle = registry.candle();
|
||||||
|
|
||||||
|
// Activation: load default models before binding the listener.
|
||||||
|
// Each load may take tens of seconds to several minutes depending
|
||||||
|
// on model size and HF cache state — keep TimeoutStartSec in the
|
||||||
|
// systemd unit generous enough to cover the slowest entry.
|
||||||
|
startup::load_default_models(®istry, &cfg.default_models).await;
|
||||||
|
|
||||||
let health_cache = Arc::new(health::HealthCache::new());
|
let health_cache = Arc::new(health::HealthCache::new());
|
||||||
health_cache
|
health_cache
|
||||||
@@ -65,13 +120,24 @@ async fn main() -> Result<()> {
|
|||||||
discovery: discovery_result,
|
discovery: discovery_result,
|
||||||
health_cache,
|
health_cache,
|
||||||
registry: RwLock::new(registry),
|
registry: RwLock::new(registry),
|
||||||
|
candle,
|
||||||
});
|
});
|
||||||
|
|
||||||
let app = api::neuron_routes().with_state(state);
|
let app = api::neuron_routes().with_state(Arc::clone(&state));
|
||||||
let addr: std::net::SocketAddr = format!("0.0.0.0:{port}").parse()?;
|
let addr: std::net::SocketAddr = format!("0.0.0.0:{port}").parse()?;
|
||||||
tracing::info!("neuron listening on {addr}");
|
tracing::info!("neuron listening on {addr}");
|
||||||
let listener = tokio::net::TcpListener::bind(addr).await?;
|
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||||
axum::serve(listener, app).await?;
|
axum::serve(listener, app)
|
||||||
|
.with_graceful_shutdown(startup::shutdown_signal())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// Deactivation: serve has returned (graceful shutdown signal
|
||||||
|
// received and connections drained). Release CUDA contexts / VRAM
|
||||||
|
// by unloading every model before exiting; systemd's TimeoutStopSec
|
||||||
|
// bounds how long this phase may take.
|
||||||
|
let registry = state.registry.read().await;
|
||||||
|
startup::unload_all_models(®istry).await;
|
||||||
|
tracing::info!("shutdown complete");
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
97
crates/neuron/src/startup.rs
Normal file
97
crates/neuron/src/startup.rs
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
//! Activation- and deactivation-time orchestration.
|
||||||
|
//!
|
||||||
|
//! Wired from `main.rs` around the HTTP listener — activation runs
|
||||||
|
//! before bind, deactivation runs after axum returns from its
|
||||||
|
//! graceful-shutdown future. Kept in its own module so the logic is
|
||||||
|
//! unit-testable without spinning up a full neuron process.
|
||||||
|
|
||||||
|
use crate::harness::HarnessRegistry;
|
||||||
|
use cortex_core::harness::ModelSpec;
|
||||||
|
use std::time::Instant;
|
||||||
|
use tokio::signal;
|
||||||
|
|
||||||
|
/// Load each spec sequentially against the registry, treating
|
||||||
|
/// individual failures as warnings rather than fatal errors.
|
||||||
|
///
|
||||||
|
/// VRAM contention makes parallel loads risky; the sequential path is
|
||||||
|
/// boring but correct. The function logs elapsed time per load so an
|
||||||
|
/// operator can see which model is hogging activation.
|
||||||
|
pub async fn load_default_models(registry: &HarnessRegistry, specs: &[ModelSpec]) {
|
||||||
|
if specs.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
tracing::info!(count = specs.len(), "loading default models");
|
||||||
|
for spec in specs {
|
||||||
|
let start = Instant::now();
|
||||||
|
match registry.load_model(spec).await {
|
||||||
|
Ok(()) => tracing::info!(
|
||||||
|
model = %spec.model_id,
|
||||||
|
elapsed_ms = start.elapsed().as_millis() as u64,
|
||||||
|
"loaded default model"
|
||||||
|
),
|
||||||
|
Err(e) => tracing::warn!(
|
||||||
|
model = %spec.model_id,
|
||||||
|
error = %e,
|
||||||
|
elapsed_ms = start.elapsed().as_millis() as u64,
|
||||||
|
"failed to load default model, continuing"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Future that resolves on SIGINT (Ctrl-C) or SIGTERM (systemd stop).
|
||||||
|
///
|
||||||
|
/// Wired into `axum::serve(...).with_graceful_shutdown(shutdown_signal())`
|
||||||
|
/// so the HTTP listener stops accepting new connections, lets in-flight
|
||||||
|
/// requests drain, and then yields control back to main for cleanup.
|
||||||
|
pub async fn shutdown_signal() {
|
||||||
|
let ctrl_c = async {
|
||||||
|
signal::ctrl_c().await.ok();
|
||||||
|
};
|
||||||
|
let terminate = async {
|
||||||
|
signal::unix::signal(signal::unix::SignalKind::terminate())
|
||||||
|
.expect("install SIGTERM handler")
|
||||||
|
.recv()
|
||||||
|
.await;
|
||||||
|
};
|
||||||
|
tokio::select! {
|
||||||
|
_ = ctrl_c => tracing::info!("received SIGINT, shutting down"),
|
||||||
|
_ = terminate => tracing::info!("received SIGTERM, shutting down"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Unload every model currently registered. Called from `main.rs` after
|
||||||
|
/// axum's graceful shutdown future resolves, so CUDA contexts and VRAM
|
||||||
|
/// are released before the process exits rather than left to the OS to
|
||||||
|
/// reclaim. Per-model failures are logged and skipped — keep cleanup
|
||||||
|
/// going even when one harness is unhealthy.
|
||||||
|
pub async fn unload_all_models(registry: &HarnessRegistry) {
|
||||||
|
let listed = match registry.list_all_models().await {
|
||||||
|
Ok(m) => m,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(error = %e, "failed to list models during shutdown");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if listed.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!(count = listed.len(), "unloading models for shutdown");
|
||||||
|
for model in listed {
|
||||||
|
let start = Instant::now();
|
||||||
|
match registry.unload_model(&model.id).await {
|
||||||
|
Ok(()) => tracing::info!(
|
||||||
|
model = %model.id,
|
||||||
|
elapsed_ms = start.elapsed().as_millis() as u64,
|
||||||
|
"unloaded"
|
||||||
|
),
|
||||||
|
Err(e) => tracing::warn!(
|
||||||
|
model = %model.id,
|
||||||
|
error = %e,
|
||||||
|
"unload failed during shutdown"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
56
crates/neuron/tests/activation.rs
Normal file
56
crates/neuron/tests/activation.rs
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
//! Activation-time behaviour: load_default_models continues past
|
||||||
|
//! individual failures so a single broken catalogue entry doesn't
|
||||||
|
//! prevent the rest of the fleet from starting.
|
||||||
|
|
||||||
|
use cortex_core::harness::{HarnessConfig, ModelSpec};
|
||||||
|
use neuron::config::HarnessSettings;
|
||||||
|
use neuron::harness::HarnessRegistry;
|
||||||
|
use neuron::startup;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_load_default_models_skips_unknown_harness() {
|
||||||
|
let registry = HarnessRegistry::from_configs(
|
||||||
|
&[HarnessConfig {
|
||||||
|
name: "candle".into(),
|
||||||
|
}],
|
||||||
|
"http://localhost:0",
|
||||||
|
&HarnessSettings::default(),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Both entries fail synchronously inside the registry — no network
|
||||||
|
// call escapes (the harness lookup mismatches before hf-hub is
|
||||||
|
// touched). The function should still return cleanly.
|
||||||
|
let specs = vec![
|
||||||
|
ModelSpec {
|
||||||
|
model_id: "model-a".into(),
|
||||||
|
harness: "no-such-harness".into(),
|
||||||
|
quant: None,
|
||||||
|
tensor_parallel: None,
|
||||||
|
devices: None,
|
||||||
|
},
|
||||||
|
ModelSpec {
|
||||||
|
model_id: "model-b".into(),
|
||||||
|
harness: "no-such-harness".into(),
|
||||||
|
quant: None,
|
||||||
|
tensor_parallel: None,
|
||||||
|
devices: None,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
startup::load_default_models(®istry, &specs).await;
|
||||||
|
|
||||||
|
let listed = registry
|
||||||
|
.list_all_models()
|
||||||
|
.await
|
||||||
|
.expect("list_all_models should succeed");
|
||||||
|
assert!(
|
||||||
|
listed.is_empty(),
|
||||||
|
"no models should be loaded after failed entries"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_load_default_models_empty_is_noop() {
|
||||||
|
let registry = HarnessRegistry::new();
|
||||||
|
startup::load_default_models(®istry, &[]).await;
|
||||||
|
}
|
||||||
@@ -14,6 +14,7 @@ async fn spawn_neuron(discovery: DiscoveryResponse) -> String {
|
|||||||
discovery,
|
discovery,
|
||||||
health_cache,
|
health_cache,
|
||||||
registry: RwLock::new(registry),
|
registry: RwLock::new(registry),
|
||||||
|
candle: None,
|
||||||
});
|
});
|
||||||
|
|
||||||
let app = api::neuron_routes().with_state(state);
|
let app = api::neuron_routes().with_state(state);
|
||||||
@@ -135,56 +136,30 @@ async fn test_models_empty_registry() {
|
|||||||
assert!(body.as_array().unwrap().is_empty());
|
assert!(body.as_array().unwrap().is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Spawn a mock mistral.rs backend and a neuron with the mistralrs harness
|
/// Verify the candle harness registers, list is empty by default, and a
|
||||||
/// pointing at it, then test the full model lifecycle through neuron's API.
|
/// load attempt for an obviously-bogus model id returns a 4xx error
|
||||||
|
/// without crashing the daemon. Real load/unload exercising actual GGUF
|
||||||
|
/// download is covered by `tests/candle_lifecycle.rs` (cuda-integration).
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_models_via_mistralrs_harness() {
|
async fn test_candle_harness_registers_and_rejects_bogus_model() {
|
||||||
use axum::routing::{get, post};
|
|
||||||
use axum::{Json, Router};
|
|
||||||
use cortex_core::harness::HarnessConfig;
|
use cortex_core::harness::HarnessConfig;
|
||||||
use serde_json::Value;
|
use neuron::config::HarnessSettings;
|
||||||
|
|
||||||
// Mock mistral.rs backend.
|
let registry = HarnessRegistry::from_configs(
|
||||||
let mock_app = Router::new()
|
&[HarnessConfig {
|
||||||
.route(
|
name: "candle".into(),
|
||||||
"/v1/models",
|
}],
|
||||||
get(|| async {
|
"http://localhost:13131",
|
||||||
Json(json!({
|
&HarnessSettings::default(),
|
||||||
"data": [
|
);
|
||||||
{"id": "test-model", "status": "loaded"},
|
|
||||||
{"id": "other-model", "status": "unloaded"}
|
|
||||||
]
|
|
||||||
}))
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
.route(
|
|
||||||
"/v1/models/unload",
|
|
||||||
post(|Json(_body): Json<Value>| async { Json(json!({"status": "ok"})) }),
|
|
||||||
)
|
|
||||||
.route(
|
|
||||||
"/v1/models/reload",
|
|
||||||
post(|Json(_body): Json<Value>| async { Json(json!({"status": "ok"})) }),
|
|
||||||
);
|
|
||||||
|
|
||||||
let mock_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
|
||||||
let mock_addr = mock_listener.local_addr().unwrap();
|
|
||||||
tokio::spawn(async move {
|
|
||||||
axum::serve(mock_listener, mock_app).await.unwrap();
|
|
||||||
});
|
|
||||||
let mock_url = format!("http://{mock_addr}");
|
|
||||||
|
|
||||||
// Build neuron with mistralrs harness pointing at mock.
|
|
||||||
let registry = HarnessRegistry::from_configs(&[HarnessConfig {
|
|
||||||
name: "mistralrs".into(),
|
|
||||||
endpoint: Some(mock_url.clone()),
|
|
||||||
systemd_unit: None,
|
|
||||||
}]);
|
|
||||||
|
|
||||||
|
let candle = registry.candle();
|
||||||
let health_cache = Arc::new(HealthCache::new());
|
let health_cache = Arc::new(HealthCache::new());
|
||||||
let state = Arc::new(NeuronState {
|
let state = Arc::new(NeuronState {
|
||||||
discovery: fake_discovery(),
|
discovery: fake_discovery(),
|
||||||
health_cache,
|
health_cache,
|
||||||
registry: RwLock::new(registry),
|
registry: RwLock::new(registry),
|
||||||
|
candle,
|
||||||
});
|
});
|
||||||
|
|
||||||
let app = api::neuron_routes().with_state(state);
|
let app = api::neuron_routes().with_state(state);
|
||||||
@@ -197,7 +172,6 @@ async fn test_models_via_mistralrs_harness() {
|
|||||||
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
|
|
||||||
// GET /models — should return models from mock mistralrs.
|
|
||||||
let resp = client
|
let resp = client
|
||||||
.get(format!("{neuron_url}/models"))
|
.get(format!("{neuron_url}/models"))
|
||||||
.send()
|
.send()
|
||||||
@@ -205,45 +179,140 @@ async fn test_models_via_mistralrs_harness() {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(resp.status(), 200);
|
assert_eq!(resp.status(), 200);
|
||||||
let models: Vec<serde_json::Value> = resp.json().await.unwrap();
|
let models: Vec<serde_json::Value> = resp.json().await.unwrap();
|
||||||
assert_eq!(models.len(), 2);
|
assert!(models.is_empty());
|
||||||
assert_eq!(models[0]["id"], "test-model");
|
|
||||||
assert_eq!(models[0]["harness"], "mistralrs");
|
|
||||||
assert_eq!(models[0]["status"], "loaded");
|
|
||||||
assert_eq!(models[1]["id"], "other-model");
|
|
||||||
assert_eq!(models[1]["status"], "unloaded");
|
|
||||||
|
|
||||||
// GET /models/test-model/endpoint — should return mock URL.
|
// Sending a wrong-harness spec should be rejected synchronously
|
||||||
let resp = client
|
// without touching the network or the model registry.
|
||||||
.get(format!("{neuron_url}/models/test-model/endpoint"))
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert_eq!(resp.status(), 200);
|
|
||||||
let body: serde_json::Value = resp.json().await.unwrap();
|
|
||||||
assert_eq!(body["url"], mock_url);
|
|
||||||
|
|
||||||
// POST /models/unload — should succeed.
|
|
||||||
let resp = client
|
|
||||||
.post(format!("{neuron_url}/models/unload"))
|
|
||||||
.json(&json!({"model_id": "test-model"}))
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert_eq!(resp.status(), 200);
|
|
||||||
let body: serde_json::Value = resp.json().await.unwrap();
|
|
||||||
assert_eq!(body["status"], "unloaded");
|
|
||||||
|
|
||||||
// POST /models/load — should succeed.
|
|
||||||
let resp = client
|
let resp = client
|
||||||
.post(format!("{neuron_url}/models/load"))
|
.post(format!("{neuron_url}/models/load"))
|
||||||
|
.json(&json!({"model_id": "definitely/not-real", "harness": "not-candle"}))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(resp.status(), 400);
|
||||||
|
|
||||||
|
// Registry still empty.
|
||||||
|
let resp = client
|
||||||
|
.get(format!("{neuron_url}/models"))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let models: Vec<serde_json::Value> = resp.json().await.unwrap();
|
||||||
|
assert!(models.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `/v1/chat/completions` returns 503 when no candle harness is registered.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_chat_completions_no_candle_harness() {
|
||||||
|
let registry = HarnessRegistry::new();
|
||||||
|
let health_cache = Arc::new(HealthCache::new());
|
||||||
|
let state = Arc::new(NeuronState {
|
||||||
|
discovery: fake_discovery(),
|
||||||
|
health_cache,
|
||||||
|
registry: RwLock::new(registry),
|
||||||
|
candle: None,
|
||||||
|
});
|
||||||
|
let app = api::neuron_routes().with_state(state);
|
||||||
|
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
axum::serve(listener, app).await.unwrap();
|
||||||
|
});
|
||||||
|
let url = format!("http://{addr}");
|
||||||
|
|
||||||
|
let resp = reqwest::Client::new()
|
||||||
|
.post(format!("{url}/v1/chat/completions"))
|
||||||
.json(&json!({
|
.json(&json!({
|
||||||
"model_id": "test-model",
|
"model": "anything",
|
||||||
"harness": "mistralrs"
|
"messages": [{"role": "user", "content": "hi"}]
|
||||||
}))
|
}))
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(resp.status(), 200);
|
assert_eq!(resp.status(), 503);
|
||||||
let body: serde_json::Value = resp.json().await.unwrap();
|
}
|
||||||
assert_eq!(body["status"], "loaded");
|
|
||||||
|
/// `/v1/chat/completions` returns 404 when the requested model isn't loaded.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_chat_completions_model_not_loaded() {
|
||||||
|
use cortex_core::harness::HarnessConfig;
|
||||||
|
use neuron::config::HarnessSettings;
|
||||||
|
|
||||||
|
let registry = HarnessRegistry::from_configs(
|
||||||
|
&[HarnessConfig {
|
||||||
|
name: "candle".into(),
|
||||||
|
}],
|
||||||
|
"http://localhost:0",
|
||||||
|
&HarnessSettings::default(),
|
||||||
|
);
|
||||||
|
let candle = registry.candle();
|
||||||
|
let health_cache = Arc::new(HealthCache::new());
|
||||||
|
let state = Arc::new(NeuronState {
|
||||||
|
discovery: fake_discovery(),
|
||||||
|
health_cache,
|
||||||
|
registry: RwLock::new(registry),
|
||||||
|
candle,
|
||||||
|
});
|
||||||
|
let app = api::neuron_routes().with_state(state);
|
||||||
|
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
axum::serve(listener, app).await.unwrap();
|
||||||
|
});
|
||||||
|
let url = format!("http://{addr}");
|
||||||
|
|
||||||
|
let resp = reqwest::Client::new()
|
||||||
|
.post(format!("{url}/v1/chat/completions"))
|
||||||
|
.json(&json!({
|
||||||
|
"model": "definitely/not-loaded",
|
||||||
|
"messages": [{"role": "user", "content": "hi"}]
|
||||||
|
}))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(resp.status(), 404);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `/v1/chat/completions` with `stream: true` returns 404 when the
|
||||||
|
/// model isn't loaded — same surface as the non-streaming path. The
|
||||||
|
/// streaming code only kicks in once the model lookup succeeds.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_chat_completions_streaming_model_not_loaded() {
|
||||||
|
use cortex_core::harness::HarnessConfig;
|
||||||
|
use neuron::config::HarnessSettings;
|
||||||
|
|
||||||
|
let registry = HarnessRegistry::from_configs(
|
||||||
|
&[HarnessConfig {
|
||||||
|
name: "candle".into(),
|
||||||
|
}],
|
||||||
|
"http://localhost:0",
|
||||||
|
&HarnessSettings::default(),
|
||||||
|
);
|
||||||
|
let candle = registry.candle();
|
||||||
|
let health_cache = Arc::new(HealthCache::new());
|
||||||
|
let state = Arc::new(NeuronState {
|
||||||
|
discovery: fake_discovery(),
|
||||||
|
health_cache,
|
||||||
|
registry: RwLock::new(registry),
|
||||||
|
candle,
|
||||||
|
});
|
||||||
|
let app = api::neuron_routes().with_state(state);
|
||||||
|
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
axum::serve(listener, app).await.unwrap();
|
||||||
|
});
|
||||||
|
let url = format!("http://{addr}");
|
||||||
|
|
||||||
|
let resp = reqwest::Client::new()
|
||||||
|
.post(format!("{url}/v1/chat/completions"))
|
||||||
|
.json(&json!({
|
||||||
|
"model": "definitely/not-loaded",
|
||||||
|
"messages": [{"role": "user", "content": "hi"}],
|
||||||
|
"stream": true
|
||||||
|
}))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(resp.status(), 404);
|
||||||
}
|
}
|
||||||
|
|||||||
87
crates/neuron/tests/candle_lifecycle.rs
Normal file
87
crates/neuron/tests/candle_lifecycle.rs
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
//! Real model load/unload lifecycle through the candle harness.
|
||||||
|
//!
|
||||||
|
//! Gated behind the `cuda-integration` feature because it downloads a
|
||||||
|
//! real (small) GGUF from HuggingFace and materialises tensors on the
|
||||||
|
//! configured device. Run on a host with network access and either a
|
||||||
|
//! CUDA GPU (when built with `--features cuda`) or enough CPU RAM to
|
||||||
|
//! hold the model.
|
||||||
|
//!
|
||||||
|
//! Usage:
|
||||||
|
//! cargo test -p neuron --features cuda-integration --test candle_lifecycle
|
||||||
|
//!
|
||||||
|
//! Optional environment variables:
|
||||||
|
//! NEURON_TEST_MODEL_ID — HuggingFace repo to load (default: a small
|
||||||
|
//! public Qwen3 GGUF repo).
|
||||||
|
//! NEURON_TEST_QUANT — quant substring matched against GGUF
|
||||||
|
//! filenames (default: "Q4_K_M").
|
||||||
|
//! HF_HOME — HuggingFace cache directory.
|
||||||
|
|
||||||
|
#![cfg(feature = "cuda-integration")]
|
||||||
|
|
||||||
|
use cortex_core::harness::{HarnessConfig, ModelSpec};
|
||||||
|
use neuron::config::HarnessSettings;
|
||||||
|
use neuron::harness::HarnessRegistry;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_candle_qwen3_load_unload_lifecycle() {
|
||||||
|
let _ = tracing_subscriber::fmt()
|
||||||
|
.with_test_writer()
|
||||||
|
.with_env_filter("info,neuron=debug")
|
||||||
|
.try_init();
|
||||||
|
|
||||||
|
let model_id = std::env::var("NEURON_TEST_MODEL_ID")
|
||||||
|
.unwrap_or_else(|_| "Qwen/Qwen3-0.6B-GGUF".to_string());
|
||||||
|
let quant = std::env::var("NEURON_TEST_QUANT").unwrap_or_else(|_| "Q4_K_M".to_string());
|
||||||
|
|
||||||
|
let mut settings = HarnessSettings::default();
|
||||||
|
if let Ok(home) = std::env::var("HF_HOME") {
|
||||||
|
settings.candle.hf_cache = Some(PathBuf::from(home));
|
||||||
|
}
|
||||||
|
|
||||||
|
let registry = HarnessRegistry::from_configs(
|
||||||
|
&[HarnessConfig {
|
||||||
|
name: "candle".into(),
|
||||||
|
}],
|
||||||
|
"http://localhost:13131",
|
||||||
|
&settings,
|
||||||
|
);
|
||||||
|
|
||||||
|
let spec = ModelSpec {
|
||||||
|
model_id: model_id.clone(),
|
||||||
|
harness: "candle".into(),
|
||||||
|
quant: Some(quant),
|
||||||
|
tensor_parallel: None,
|
||||||
|
devices: Some(vec![0]),
|
||||||
|
};
|
||||||
|
|
||||||
|
registry
|
||||||
|
.load_model(&spec)
|
||||||
|
.await
|
||||||
|
.expect("load_model should succeed");
|
||||||
|
|
||||||
|
let models = registry.list_all_models().await.expect("list_all_models");
|
||||||
|
assert_eq!(models.len(), 1, "expected exactly one loaded model");
|
||||||
|
assert_eq!(models[0].id, model_id);
|
||||||
|
assert_eq!(models[0].harness, "candle");
|
||||||
|
assert_eq!(models[0].status, "loaded");
|
||||||
|
|
||||||
|
let url = registry.inference_endpoint(&model_id).await;
|
||||||
|
assert_eq!(url, Some("http://localhost:13131".into()));
|
||||||
|
|
||||||
|
// Re-loading the same model should be rejected.
|
||||||
|
let again = registry.load_model(&spec).await;
|
||||||
|
assert!(again.is_err(), "second load should error");
|
||||||
|
|
||||||
|
registry
|
||||||
|
.unload_model(&model_id)
|
||||||
|
.await
|
||||||
|
.expect("unload_model should succeed");
|
||||||
|
|
||||||
|
let models = registry.list_all_models().await.expect("list_all_models");
|
||||||
|
assert!(models.is_empty(), "registry should be empty after unload");
|
||||||
|
|
||||||
|
// Unloading a model that isn't loaded should error.
|
||||||
|
let err = registry.unload_model(&model_id).await;
|
||||||
|
assert!(err.is_err(), "unload of missing model should error");
|
||||||
|
}
|
||||||
32
crates/neuron/tests/shutdown.rs
Normal file
32
crates/neuron/tests/shutdown.rs
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
//! Deactivation behaviour: unload_all_models tolerates an empty
|
||||||
|
//! registry and continues past per-model unload failures.
|
||||||
|
|
||||||
|
use cortex_core::harness::HarnessConfig;
|
||||||
|
use neuron::config::HarnessSettings;
|
||||||
|
use neuron::harness::HarnessRegistry;
|
||||||
|
use neuron::startup;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_unload_all_models_empty_registry_is_noop() {
|
||||||
|
let registry = HarnessRegistry::new();
|
||||||
|
startup::unload_all_models(®istry).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_unload_all_models_with_no_loaded_models() {
|
||||||
|
let registry = HarnessRegistry::from_configs(
|
||||||
|
&[HarnessConfig {
|
||||||
|
name: "candle".into(),
|
||||||
|
}],
|
||||||
|
"http://localhost:0",
|
||||||
|
&HarnessSettings::default(),
|
||||||
|
);
|
||||||
|
|
||||||
|
startup::unload_all_models(®istry).await;
|
||||||
|
|
||||||
|
let listed = registry
|
||||||
|
.list_all_models()
|
||||||
|
.await
|
||||||
|
.expect("list_all_models should still succeed after shutdown cleanup");
|
||||||
|
assert!(listed.is_empty());
|
||||||
|
}
|
||||||
145
crates/neuron/tests/tp_worker_lifecycle.rs
Normal file
145
crates/neuron/tests/tp_worker_lifecycle.rs
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
//! Stage 7a-i: confirm the TP worker subprocess lifecycle round-trips.
|
||||||
|
//!
|
||||||
|
//! Spawns two worker subprocesses via the leader→worker stdio RPC,
|
||||||
|
//! pings each, and cleanly shuts them down. No CUDA required —
|
||||||
|
//! `Init` and `NcclSanityCheck` are stubbed in 7a-i, so this test
|
||||||
|
//! runs on any host the workspace builds on.
|
||||||
|
|
||||||
|
use neuron::harness::tp::{WorkerPool, rpc::WorkerResponse};
|
||||||
|
|
||||||
|
/// Path to the neuron binary built by cargo for this test process.
|
||||||
|
/// cargo populates `CARGO_BIN_EXE_neuron` at compile time for sibling-
|
||||||
|
/// binary tests; production paths in main.rs use `/proc/self/exe`.
|
||||||
|
const NEURON_BIN: &str = env!("CARGO_BIN_EXE_neuron");
|
||||||
|
|
||||||
|
/// Two workers (so we spawn one subprocess: rank 0 is in-process,
|
||||||
|
/// rank 1 is the child). Verify the spawned worker responds to Ping
|
||||||
|
/// with its own identity, then shut it down cleanly.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_spawn_ping_shutdown() {
|
||||||
|
// cuda_devices: rank 0 → device 0 (leader, unused here),
|
||||||
|
// rank 1 → device 1 (worker; not actually opened in 7a-i).
|
||||||
|
let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 2, &[0, 1])
|
||||||
|
.await
|
||||||
|
.expect("spawn worker pool");
|
||||||
|
|
||||||
|
let pongs = pool.ping_all().await.expect("ping all workers");
|
||||||
|
assert_eq!(pongs.len(), 1, "expected one Pong (rank 1 only)");
|
||||||
|
match &pongs[0] {
|
||||||
|
WorkerResponse::Pong {
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
cuda_device,
|
||||||
|
} => {
|
||||||
|
assert_eq!(*rank, 1);
|
||||||
|
assert_eq!(*world_size, 2);
|
||||||
|
assert_eq!(*cuda_device, 1);
|
||||||
|
}
|
||||||
|
other => panic!("expected Pong, got {other:?}"),
|
||||||
|
}
|
||||||
|
|
||||||
|
pool.shutdown().await.expect("clean shutdown");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Three workers — exercise the loop in `ping_all` / `shutdown`.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_spawn_three_workers() {
|
||||||
|
let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 3, &[0, 1, 2])
|
||||||
|
.await
|
||||||
|
.expect("spawn worker pool");
|
||||||
|
|
||||||
|
let pongs = pool.ping_all().await.expect("ping all workers");
|
||||||
|
assert_eq!(pongs.len(), 2, "expected two Pongs (ranks 1 and 2)");
|
||||||
|
for (i, resp) in pongs.iter().enumerate() {
|
||||||
|
match resp {
|
||||||
|
WorkerResponse::Pong {
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
cuda_device,
|
||||||
|
} => {
|
||||||
|
let expected_rank = (i + 1) as u32;
|
||||||
|
assert_eq!(*rank, expected_rank);
|
||||||
|
assert_eq!(*world_size, 3);
|
||||||
|
assert_eq!(*cuda_device, expected_rank);
|
||||||
|
}
|
||||||
|
other => panic!("expected Pong, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pool.shutdown().await.expect("clean shutdown");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 7a-ii: without the cuda feature, Init must fail with a clear
|
||||||
|
/// `cuda_feature_not_enabled` marker rather than silently succeeding.
|
||||||
|
/// This is the local-dev-box test; the real NCCL handshake is exercised
|
||||||
|
/// by `tp_worker_lifecycle_cuda.rs` (gated on `cuda-integration`).
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_init_returns_cuda_feature_not_enabled_without_cuda() {
|
||||||
|
use neuron::harness::tp::rpc::WorkerRequest;
|
||||||
|
use std::process::Stdio;
|
||||||
|
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||||
|
use tokio::process::Command;
|
||||||
|
|
||||||
|
// Spawn a single worker by hand to send Init directly (the pool's
|
||||||
|
// public API doesn't expose Init yet — that lands in 7a-ii).
|
||||||
|
let mut child = Command::new(NEURON_BIN)
|
||||||
|
.arg("--worker")
|
||||||
|
.arg("--rank")
|
||||||
|
.arg("1")
|
||||||
|
.arg("--tp-size")
|
||||||
|
.arg("2")
|
||||||
|
.arg("--cuda-device")
|
||||||
|
.arg("1")
|
||||||
|
.stdin(Stdio::piped())
|
||||||
|
.stdout(Stdio::piped())
|
||||||
|
.stderr(Stdio::null())
|
||||||
|
.kill_on_drop(true)
|
||||||
|
.spawn()
|
||||||
|
.expect("spawn worker");
|
||||||
|
|
||||||
|
let mut stdin = child.stdin.take().expect("stdin");
|
||||||
|
let stdout = child.stdout.take().expect("stdout");
|
||||||
|
let mut lines = BufReader::new(stdout).lines();
|
||||||
|
|
||||||
|
let req = WorkerRequest::Init {
|
||||||
|
comm_id: "ff".repeat(128),
|
||||||
|
};
|
||||||
|
let mut payload = serde_json::to_string(&req).unwrap();
|
||||||
|
payload.push('\n');
|
||||||
|
stdin.write_all(payload.as_bytes()).await.unwrap();
|
||||||
|
stdin.flush().await.unwrap();
|
||||||
|
|
||||||
|
let reply = lines
|
||||||
|
.next_line()
|
||||||
|
.await
|
||||||
|
.expect("read line")
|
||||||
|
.expect("got line");
|
||||||
|
let resp: WorkerResponse = serde_json::from_str(&reply).expect("parse reply");
|
||||||
|
match resp {
|
||||||
|
WorkerResponse::Error { kind, .. } => {
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
{
|
||||||
|
// With cuda enabled the response depends on whether
|
||||||
|
// CUDA hardware is actually present. Accept either
|
||||||
|
// the success contract or a real NCCL failure.
|
||||||
|
let _ = kind;
|
||||||
|
}
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
assert_eq!(kind, "cuda_feature_not_enabled");
|
||||||
|
}
|
||||||
|
WorkerResponse::InitOk => {
|
||||||
|
// Real NCCL succeeded — only possible with cuda feature
|
||||||
|
// AND a working NCCL stack AND another rank actually
|
||||||
|
// joining. Don't fail; just acknowledge.
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
panic!("InitOk without cuda feature is impossible");
|
||||||
|
}
|
||||||
|
other => panic!("expected Error or InitOk, got {other:?}"),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean shutdown.
|
||||||
|
stdin.write_all(b"{\"op\":\"shutdown\"}\n").await.unwrap();
|
||||||
|
stdin.flush().await.unwrap();
|
||||||
|
let _ = lines.next_line().await; // Bye
|
||||||
|
let _ = child.wait().await;
|
||||||
|
}
|
||||||
43
crates/neuron/tests/tp_worker_lifecycle_cuda.rs
Normal file
43
crates/neuron/tests/tp_worker_lifecycle_cuda.rs
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
//! Stage 7a-ii: real NCCL handshake across the worker pool.
|
||||||
|
//!
|
||||||
|
//! Gated behind the `cuda-integration` feature because it requires
|
||||||
|
//! libnccl AND multiple CUDA devices on the running host. Run on
|
||||||
|
//! beast (2× RTX 5090) via:
|
||||||
|
//!
|
||||||
|
//! cargo test -p neuron --features cuda-integration \
|
||||||
|
//! --test tp_worker_lifecycle_cuda
|
||||||
|
//!
|
||||||
|
//! Steps: spawn N-1 workers, call `init_nccl`, run `nccl_sanity_check`
|
||||||
|
//! (every rank `all_reduce`s `1u32` with Sum; expected total =
|
||||||
|
//! world_size), shut down cleanly.
|
||||||
|
|
||||||
|
#![cfg(feature = "cuda-integration")]
|
||||||
|
|
||||||
|
use neuron::harness::tp::WorkerPool;
|
||||||
|
|
||||||
|
const NEURON_BIN: &str = env!("CARGO_BIN_EXE_neuron");
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_init_and_sanity_check_two_ranks() {
|
||||||
|
let _ = tracing_subscriber::fmt()
|
||||||
|
.with_test_writer()
|
||||||
|
.with_env_filter("info,neuron=debug")
|
||||||
|
.try_init();
|
||||||
|
|
||||||
|
// 2 ranks: leader = rank 0 on device 0, worker = rank 1 on device 1.
|
||||||
|
let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 2, &[0, 1])
|
||||||
|
.await
|
||||||
|
.expect("spawn worker pool");
|
||||||
|
|
||||||
|
pool.ping_all().await.expect("pong all workers");
|
||||||
|
|
||||||
|
pool.init_nccl(0)
|
||||||
|
.await
|
||||||
|
.expect("init_nccl: NCCL handshake across all ranks");
|
||||||
|
|
||||||
|
pool.nccl_sanity_check()
|
||||||
|
.await
|
||||||
|
.expect("nccl_sanity_check: observed_sum == world_size on all ranks");
|
||||||
|
|
||||||
|
pool.shutdown().await.expect("clean shutdown");
|
||||||
|
}
|
||||||
7
data/cortex-firewalld.xml
Normal file
7
data/cortex-firewalld.xml
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
<service>
|
||||||
|
<short>cortex</short>
|
||||||
|
<description>Cortex — inference gateway for multi-node GPU clusters</description>
|
||||||
|
<port protocol="tcp" port="31313"/>
|
||||||
|
<port protocol="tcp" port="31314"/>
|
||||||
|
</service>
|
||||||
6
data/neuron-firewalld.xml
Normal file
6
data/neuron-firewalld.xml
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
<service>
|
||||||
|
<short>helexa-neuron</short>
|
||||||
|
<description>Neuron — per-node GPU discovery and harness daemon for cortex</description>
|
||||||
|
<port protocol="tcp" port="13131"/>
|
||||||
|
</service>
|
||||||
@@ -10,6 +10,22 @@ Restart=on-failure
|
|||||||
RestartSec=5
|
RestartSec=5
|
||||||
User=neuron
|
User=neuron
|
||||||
Group=neuron
|
Group=neuron
|
||||||
|
# /var/lib/neuron is the neuron user's $HOME — hf-hub writes its
|
||||||
|
# default cache there (~/.cache/huggingface/hub). Without this directive
|
||||||
|
# systemd doesn't create the directory and hf-hub downloads fail with
|
||||||
|
# "fetch GGUF <file>: failed to create cache dir".
|
||||||
|
StateDirectory=neuron
|
||||||
|
StateDirectoryMode=0755
|
||||||
|
# Loading default_models from neuron.toml happens before the HTTP
|
||||||
|
# listener binds; large models can take many minutes to download and
|
||||||
|
# materialise on first activation. systemd's default TimeoutStartSec
|
||||||
|
# (90s) is far too short; allow 30 minutes.
|
||||||
|
TimeoutStartSec=1800s
|
||||||
|
# On stop, neuron drains in-flight requests then unloads every model
|
||||||
|
# to release CUDA contexts cleanly. Allow generous time for big-model
|
||||||
|
# unloads; systemd will SIGKILL after this bound.
|
||||||
|
TimeoutStopSec=120s
|
||||||
|
KillSignal=SIGTERM
|
||||||
|
|
||||||
[Install]
|
[Install]
|
||||||
WantedBy=multi-user.target
|
WantedBy=multi-user.target
|
||||||
|
|||||||
101
helexa-neuron.spec
Normal file
101
helexa-neuron.spec
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
Name: helexa-neuron
|
||||||
|
Version: 0.1.16
|
||||||
|
Release: 1%{?dist}
|
||||||
|
Summary: Per-node GPU discovery and harness management daemon for cortex
|
||||||
|
# Package name disambiguates from Fedora's existing "neuron" package
|
||||||
|
# (NEURON neural simulation environment from Yale). Binary, systemd
|
||||||
|
# unit, and system user are still called "neuron" for brevity.
|
||||||
|
|
||||||
|
License: GPL-3.0-or-later
|
||||||
|
URL: https://git.lair.cafe/helexa/cortex
|
||||||
|
Source0: %{name}-%{version}.tar.gz
|
||||||
|
Source1: %{name}-%{version}-vendor.tar.gz
|
||||||
|
|
||||||
|
ExclusiveArch: x86_64
|
||||||
|
|
||||||
|
BuildRequires: rust >= 1.85
|
||||||
|
BuildRequires: cargo
|
||||||
|
BuildRequires: gcc
|
||||||
|
BuildRequires: gcc-c++
|
||||||
|
BuildRequires: cmake
|
||||||
|
BuildRequires: perl-interpreter
|
||||||
|
BuildRequires: pkgconfig(openssl)
|
||||||
|
BuildRequires: systemd-rpm-macros
|
||||||
|
|
||||||
|
Requires(pre): shadow-utils
|
||||||
|
Requires: systemd
|
||||||
|
Requires: firewalld-filesystem
|
||||||
|
|
||||||
|
# systemd-rpm-macros ships a unit dep generator that parses User=/Group=
|
||||||
|
# from our .service file and emits Requires: user(neuron)/group(neuron).
|
||||||
|
# rpm's sysusers provides-generator emits the unversioned form for groups
|
||||||
|
# but only a versioned user(neuron) = <base64> for users with GECOS/home/
|
||||||
|
# shell. Provide the unversioned user(neuron) explicitly so dnf can resolve
|
||||||
|
# the auto-generated Requires. Without this, dnf5 silently filters the
|
||||||
|
# package and reports "Nothing to do".
|
||||||
|
Provides: user(neuron)
|
||||||
|
|
||||||
|
%description
|
||||||
|
Neuron is a per-node daemon for cortex inference clusters. It discovers
|
||||||
|
local GPU hardware via nvidia-smi, runs in-process inference via
|
||||||
|
huggingface/candle, and exposes an HTTP API for model lifecycle
|
||||||
|
management (load, unload, list, inference endpoint).
|
||||||
|
|
||||||
|
%prep
|
||||||
|
%autosetup
|
||||||
|
tar xf %{SOURCE1}
|
||||||
|
mkdir -p .cargo
|
||||||
|
cat > .cargo/config.toml << 'EOF'
|
||||||
|
[source.crates-io]
|
||||||
|
replace-with = "vendored-sources"
|
||||||
|
|
||||||
|
[source.vendored-sources]
|
||||||
|
directory = "vendor"
|
||||||
|
EOF
|
||||||
|
|
||||||
|
%build
|
||||||
|
cargo build --release -p neuron
|
||||||
|
|
||||||
|
%install
|
||||||
|
install -Dm755 target/release/neuron %{buildroot}%{_bindir}/neuron
|
||||||
|
install -Dm644 data/neuron.service %{buildroot}%{_unitdir}/neuron.service
|
||||||
|
install -Dm644 data/neuron-sysusers.conf %{buildroot}%{_sysusersdir}/neuron.conf
|
||||||
|
install -Dm644 data/neuron-firewalld.xml %{buildroot}%{_prefix}/lib/firewalld/services/helexa-neuron.xml
|
||||||
|
install -dm755 %{buildroot}%{_sysconfdir}/neuron
|
||||||
|
install -Dm644 neuron.example.toml %{buildroot}%{_sysconfdir}/neuron/neuron.toml
|
||||||
|
|
||||||
|
%pre
|
||||||
|
%sysusers_create_compat %{_builddir}/%{name}-%{version}/data/neuron-sysusers.conf
|
||||||
|
|
||||||
|
%post
|
||||||
|
%systemd_post neuron.service
|
||||||
|
|
||||||
|
%preun
|
||||||
|
%systemd_preun neuron.service
|
||||||
|
|
||||||
|
%postun
|
||||||
|
%systemd_postun_with_restart neuron.service
|
||||||
|
|
||||||
|
%files
|
||||||
|
%license LICENSE
|
||||||
|
%doc README.md
|
||||||
|
%{_bindir}/neuron
|
||||||
|
%{_unitdir}/neuron.service
|
||||||
|
%{_sysusersdir}/neuron.conf
|
||||||
|
%{_prefix}/lib/firewalld/services/helexa-neuron.xml
|
||||||
|
%dir %{_sysconfdir}/neuron
|
||||||
|
%config(noreplace) %{_sysconfdir}/neuron/neuron.toml
|
||||||
|
|
||||||
|
%changelog
|
||||||
|
* Thu Apr 16 2026 Gitea Actions <actions@git.lair.cafe> - 0.1.16-1
|
||||||
|
- chore: ignore local deploy script
|
||||||
|
- chore: move default ports out of common-collision ranges
|
||||||
|
- ci: drop actions/cache for cargo registry and target
|
||||||
|
|
||||||
|
* Thu Apr 16 2026 Gitea Actions <actions@git.lair.cafe> - 0.1.14-1
|
||||||
|
- ci: publish both packages to a single helexa/helexa COPR project
|
||||||
|
- fix(rpm): rename neuron package to helexa-neuron
|
||||||
|
- ci: commit generated %changelog entries back to main
|
||||||
|
|
||||||
|
* Wed Apr 15 2026 Rob Thijssen <grenade@rob.tn> - 0.1.0-1
|
||||||
|
- Initial package
|
||||||
@@ -6,7 +6,7 @@
|
|||||||
|
|
||||||
[[models]]
|
[[models]]
|
||||||
id = "your-org/large-model"
|
id = "your-org/large-model"
|
||||||
harness = "mistralrs"
|
harness = "candle"
|
||||||
quant = "Q4_K_M"
|
quant = "Q4_K_M"
|
||||||
vram_mb = 19000
|
vram_mb = 19000
|
||||||
min_devices = 2
|
min_devices = 2
|
||||||
@@ -15,7 +15,7 @@ pinned_on = ["gpu-large"]
|
|||||||
|
|
||||||
[[models]]
|
[[models]]
|
||||||
id = "your-org/medium-model"
|
id = "your-org/medium-model"
|
||||||
harness = "mistralrs"
|
harness = "candle"
|
||||||
quant = "Q6_K"
|
quant = "Q6_K"
|
||||||
vram_mb = 12000
|
vram_mb = 12000
|
||||||
min_devices = 1
|
min_devices = 1
|
||||||
@@ -23,7 +23,7 @@ pinned_on = ["gpu-medium"]
|
|||||||
|
|
||||||
[[models]]
|
[[models]]
|
||||||
id = "your-org/embedding-model"
|
id = "your-org/embedding-model"
|
||||||
harness = "mistralrs"
|
harness = "candle"
|
||||||
quant = "Q8_0"
|
quant = "Q8_0"
|
||||||
vram_mb = 8000
|
vram_mb = 8000
|
||||||
min_devices = 1
|
min_devices = 1
|
||||||
|
|||||||
@@ -3,14 +3,38 @@
|
|||||||
# Copy to /etc/neuron/neuron.toml and adjust for your environment.
|
# Copy to /etc/neuron/neuron.toml and adjust for your environment.
|
||||||
#
|
#
|
||||||
# Environment variable overrides use NEURON_ prefix with __ separators:
|
# Environment variable overrides use NEURON_ prefix with __ separators:
|
||||||
# NEURON_PORT=9090
|
# NEURON_PORT=13131
|
||||||
|
|
||||||
port = 9090
|
port = 13131
|
||||||
|
|
||||||
# -- Harnesses ---------------------------------------------------------------
|
# -- Harnesses ---------------------------------------------------------------
|
||||||
# Each [[harnesses]] entry declares an inference engine managed by neuron.
|
# Each [[harnesses]] entry enables an inference engine. Currently only
|
||||||
|
# "candle" is supported — it runs in-process and uses huggingface/candle
|
||||||
|
# for inference on local CUDA devices (or CPU when CUDA is unavailable).
|
||||||
|
|
||||||
[[harnesses]]
|
[[harnesses]]
|
||||||
name = "mistralrs"
|
name = "candle"
|
||||||
endpoint = "http://localhost:8080"
|
|
||||||
systemd_unit = "mistralrs.service"
|
# -- Candle harness settings -------------------------------------------------
|
||||||
|
# Optional tuning for the candle harness.
|
||||||
|
|
||||||
|
[harness.candle]
|
||||||
|
# HuggingFace cache directory for model weights. When unset, hf-hub's
|
||||||
|
# default (~/.cache/huggingface) is used.
|
||||||
|
# hf_cache = "/var/lib/neuron/hf-cache"
|
||||||
|
|
||||||
|
# -- Default models ----------------------------------------------------------
|
||||||
|
# Models listed here are loaded automatically when the neuron service
|
||||||
|
# activates. Loading is sequential — a slow or failing entry doesn't
|
||||||
|
# block the rest of the fleet, but it does push out the time before
|
||||||
|
# neuron starts serving HTTP, so keep the list short. Operators can
|
||||||
|
# load additional models on demand via POST /models/load.
|
||||||
|
#
|
||||||
|
# Make sure data/neuron.service's TimeoutStartSec is generous enough to
|
||||||
|
# cover the slowest entry's first-time download + materialisation.
|
||||||
|
|
||||||
|
# [[default_models]]
|
||||||
|
# model_id = "Qwen/Qwen3-0.6B-GGUF"
|
||||||
|
# harness = "candle"
|
||||||
|
# quant = "Q4_K_M"
|
||||||
|
# devices = [0]
|
||||||
|
|||||||
81
neuron.spec
81
neuron.spec
@@ -1,81 +0,0 @@
|
|||||||
Name: neuron
|
|
||||||
Version: 0.1.7
|
|
||||||
Release: 1%{?dist}
|
|
||||||
Summary: Per-node GPU discovery and harness management daemon for cortex
|
|
||||||
|
|
||||||
License: GPL-3.0-or-later
|
|
||||||
URL: https://git.lair.cafe/helexa/cortex
|
|
||||||
Source0: %{name}-%{version}.tar.gz
|
|
||||||
Source1: %{name}-%{version}-vendor.tar.gz
|
|
||||||
|
|
||||||
ExclusiveArch: x86_64
|
|
||||||
|
|
||||||
BuildRequires: rust >= 1.85
|
|
||||||
BuildRequires: cargo
|
|
||||||
BuildRequires: gcc
|
|
||||||
BuildRequires: gcc-c++
|
|
||||||
BuildRequires: cmake
|
|
||||||
BuildRequires: perl-interpreter
|
|
||||||
BuildRequires: pkgconfig(openssl)
|
|
||||||
BuildRequires: systemd-rpm-macros
|
|
||||||
|
|
||||||
Requires(pre): shadow-utils
|
|
||||||
Requires: systemd
|
|
||||||
|
|
||||||
# rpm's sysusers provides-generator only emits versioned user(neuron) when
|
|
||||||
# the u-line has GECOS/home/shell fields. %attr(,,neuron) in %files emits
|
|
||||||
# an unversioned Requires: user(neuron), so we provide it explicitly.
|
|
||||||
Provides: user(neuron)
|
|
||||||
Provides: group(neuron)
|
|
||||||
|
|
||||||
%description
|
|
||||||
Neuron is a per-node daemon for cortex inference clusters. It discovers
|
|
||||||
local GPU hardware via nvidia-smi, manages inference harnesses (mistral.rs,
|
|
||||||
llama.cpp), and exposes an HTTP API for model lifecycle management.
|
|
||||||
|
|
||||||
%prep
|
|
||||||
%autosetup
|
|
||||||
tar xf %{SOURCE1}
|
|
||||||
mkdir -p .cargo
|
|
||||||
cat > .cargo/config.toml << 'EOF'
|
|
||||||
[source.crates-io]
|
|
||||||
replace-with = "vendored-sources"
|
|
||||||
|
|
||||||
[source.vendored-sources]
|
|
||||||
directory = "vendor"
|
|
||||||
EOF
|
|
||||||
|
|
||||||
%build
|
|
||||||
cargo build --release -p neuron
|
|
||||||
|
|
||||||
%install
|
|
||||||
install -Dm755 target/release/neuron %{buildroot}%{_bindir}/neuron
|
|
||||||
install -Dm644 data/neuron.service %{buildroot}%{_unitdir}/neuron.service
|
|
||||||
install -Dm644 data/neuron-sysusers.conf %{buildroot}%{_sysusersdir}/neuron.conf
|
|
||||||
install -dm750 %{buildroot}%{_sysconfdir}/neuron
|
|
||||||
install -Dm640 neuron.example.toml %{buildroot}%{_sysconfdir}/neuron/neuron.toml
|
|
||||||
|
|
||||||
%pre
|
|
||||||
%sysusers_create_compat %{_builddir}/%{name}-%{version}/data/neuron-sysusers.conf
|
|
||||||
|
|
||||||
%post
|
|
||||||
%systemd_post neuron.service
|
|
||||||
|
|
||||||
%preun
|
|
||||||
%systemd_preun neuron.service
|
|
||||||
|
|
||||||
%postun
|
|
||||||
%systemd_postun_with_restart neuron.service
|
|
||||||
|
|
||||||
%files
|
|
||||||
%license LICENSE
|
|
||||||
%doc README.md
|
|
||||||
%{_bindir}/neuron
|
|
||||||
%{_unitdir}/neuron.service
|
|
||||||
%{_sysusersdir}/neuron.conf
|
|
||||||
%dir %attr(750,root,neuron) %{_sysconfdir}/neuron
|
|
||||||
%config(noreplace) %attr(640,root,neuron) %{_sysconfdir}/neuron/neuron.toml
|
|
||||||
|
|
||||||
%changelog
|
|
||||||
* Tue Apr 15 2026 Rob Thijssen <grenade@rob.tn> - 0.1.0-1
|
|
||||||
- Initial package
|
|
||||||
106
rpm/cortex-prerelease.spec
Normal file
106
rpm/cortex-prerelease.spec
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
# Prebuilt-binary spec for cortex.
|
||||||
|
#
|
||||||
|
# Unlike cortex.spec (which builds from source via cargo), this spec
|
||||||
|
# wraps a pre-built `cortex` binary produced by an upstream CI job and
|
||||||
|
# packages it for rpm.lair.cafe. The %build phase is a no-op.
|
||||||
|
#
|
||||||
|
# Required defines at rpmbuild time:
|
||||||
|
# cortex_version e.g. "0.1.16"
|
||||||
|
# cortex_prerelease e.g. "0.1.20260518140530.gitabcdef0"
|
||||||
|
# ^^^^^^^^^^^^^^^^^^ ^^^^^^^^
|
||||||
|
# commit time (sec) commit sha
|
||||||
|
# (used as Release; the timestamp prefix
|
||||||
|
# keeps same-day builds strictly ordered.)
|
||||||
|
|
||||||
|
%global _build_id_links none
|
||||||
|
%global debug_package %{nil}
|
||||||
|
%global __strip /usr/bin/true
|
||||||
|
|
||||||
|
%{!?cortex_version: %global cortex_version 0.0.0}
|
||||||
|
%if 0%{?cortex_prerelease:1}
|
||||||
|
%global cortex_release %{cortex_prerelease}
|
||||||
|
%else
|
||||||
|
%global cortex_release 1
|
||||||
|
%endif
|
||||||
|
|
||||||
|
Name: cortex
|
||||||
|
Version: %{cortex_version}
|
||||||
|
Release: %{cortex_release}%{?dist}
|
||||||
|
Summary: Inference gateway for multi-node GPU clusters (prebuilt)
|
||||||
|
|
||||||
|
License: GPL-3.0-or-later
|
||||||
|
URL: https://git.lair.cafe/helexa/cortex
|
||||||
|
|
||||||
|
Source0: cortex
|
||||||
|
Source1: cortex.service
|
||||||
|
Source2: cortex-sysusers.conf
|
||||||
|
Source3: cortex-firewalld.xml
|
||||||
|
Source4: cortex.example.toml
|
||||||
|
Source5: models.example.toml
|
||||||
|
Source6: LICENSE
|
||||||
|
|
||||||
|
ExclusiveArch: x86_64
|
||||||
|
|
||||||
|
Requires(pre): shadow-utils
|
||||||
|
Requires: systemd
|
||||||
|
Requires: firewalld-filesystem
|
||||||
|
|
||||||
|
Provides: user(cortex)
|
||||||
|
|
||||||
|
%description
|
||||||
|
Cortex is a Rust reverse-proxy that sits in front of multiple neuron
|
||||||
|
inference daemons and presents a unified OpenAI and Anthropic
|
||||||
|
compatible API surface.
|
||||||
|
|
||||||
|
This package wraps a binary built upstream in CI; the source-build
|
||||||
|
spec (cortex.spec) remains available for stable releases.
|
||||||
|
|
||||||
|
%prep
|
||||||
|
cp %{SOURCE0} ./cortex
|
||||||
|
cp %{SOURCE1} .
|
||||||
|
cp %{SOURCE2} .
|
||||||
|
cp %{SOURCE3} .
|
||||||
|
cp %{SOURCE4} .
|
||||||
|
cp %{SOURCE5} .
|
||||||
|
cp %{SOURCE6} .
|
||||||
|
|
||||||
|
%build
|
||||||
|
# Already built in the upstream CI build job.
|
||||||
|
|
||||||
|
%install
|
||||||
|
install -Dm755 cortex %{buildroot}%{_bindir}/cortex
|
||||||
|
install -Dm644 cortex.service %{buildroot}%{_unitdir}/cortex.service
|
||||||
|
install -Dm644 cortex-sysusers.conf %{buildroot}%{_sysusersdir}/cortex.conf
|
||||||
|
install -Dm644 cortex-firewalld.xml %{buildroot}%{_prefix}/lib/firewalld/services/cortex.xml
|
||||||
|
install -dm755 %{buildroot}%{_sysconfdir}/cortex
|
||||||
|
install -Dm644 cortex.example.toml %{buildroot}%{_sysconfdir}/cortex/cortex.toml
|
||||||
|
install -Dm644 models.example.toml %{buildroot}%{_sysconfdir}/cortex/models.toml
|
||||||
|
|
||||||
|
%pre
|
||||||
|
getent group cortex >/dev/null || groupadd -r cortex
|
||||||
|
getent passwd cortex >/dev/null || \
|
||||||
|
useradd -r -g cortex -d /var/lib/cortex -s /sbin/nologin \
|
||||||
|
-c "Cortex inference gateway" cortex
|
||||||
|
|
||||||
|
%post
|
||||||
|
%systemd_post cortex.service
|
||||||
|
|
||||||
|
%preun
|
||||||
|
%systemd_preun cortex.service
|
||||||
|
|
||||||
|
%postun
|
||||||
|
%systemd_postun_with_restart cortex.service
|
||||||
|
|
||||||
|
%files
|
||||||
|
%license LICENSE
|
||||||
|
%{_bindir}/cortex
|
||||||
|
%{_unitdir}/cortex.service
|
||||||
|
%{_sysusersdir}/cortex.conf
|
||||||
|
%{_prefix}/lib/firewalld/services/cortex.xml
|
||||||
|
%dir %{_sysconfdir}/cortex
|
||||||
|
%config(noreplace) %{_sysconfdir}/cortex/cortex.toml
|
||||||
|
%config(noreplace) %{_sysconfdir}/cortex/models.toml
|
||||||
|
|
||||||
|
%changelog
|
||||||
|
* Mon May 18 2026 Gitea Actions <actions@git.lair.cafe> - %{cortex_version}-%{cortex_release}
|
||||||
|
- Prerelease build from upstream CI binary.
|
||||||
126
rpm/helexa-neuron-prerelease.spec
Normal file
126
rpm/helexa-neuron-prerelease.spec
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
# Prebuilt-binary spec for helexa-neuron flavoured by CUDA compute capability.
|
||||||
|
#
|
||||||
|
# Unlike helexa-neuron.spec (which builds from source via cargo), this
|
||||||
|
# spec wraps a pre-built `neuron-{flavour}` binary produced by an
|
||||||
|
# upstream CI job and packages it for rpm.lair.cafe. The %build phase
|
||||||
|
# is a no-op.
|
||||||
|
#
|
||||||
|
# Required defines at rpmbuild time:
|
||||||
|
# neuron_version e.g. "0.1.16"
|
||||||
|
# neuron_flavour e.g. "ada", "blackwell" — matches the CI build
|
||||||
|
# matrix's compute_cap label.
|
||||||
|
# neuron_prerelease e.g. "0.1.20260518140530.gitabcdef0"
|
||||||
|
# ^^^^^^^^^^^^^^^^^^ ^^^^^^^^
|
||||||
|
# commit time (sec) commit sha
|
||||||
|
# (used as Release; the timestamp prefix
|
||||||
|
# keeps same-day builds strictly ordered.)
|
||||||
|
#
|
||||||
|
# One flavour can be installed at a time on a given host; flavour
|
||||||
|
# packages Conflict with each other.
|
||||||
|
|
||||||
|
%global _build_id_links none
|
||||||
|
%global debug_package %{nil}
|
||||||
|
%global __strip /usr/bin/true
|
||||||
|
|
||||||
|
%{!?neuron_version: %global neuron_version 0.0.0}
|
||||||
|
%{!?neuron_flavour: %global neuron_flavour blackwell}
|
||||||
|
%if 0%{?neuron_prerelease:1}
|
||||||
|
%global neuron_release %{neuron_prerelease}
|
||||||
|
%else
|
||||||
|
%global neuron_release 1
|
||||||
|
%endif
|
||||||
|
|
||||||
|
Name: helexa-neuron-%{neuron_flavour}
|
||||||
|
Version: %{neuron_version}
|
||||||
|
Release: %{neuron_release}%{?dist}
|
||||||
|
Summary: Per-node GPU inference daemon (candle, %{neuron_flavour} flavour)
|
||||||
|
|
||||||
|
License: GPL-3.0-or-later
|
||||||
|
URL: https://git.lair.cafe/helexa/cortex
|
||||||
|
|
||||||
|
Source0: neuron-%{neuron_flavour}
|
||||||
|
Source1: neuron.service
|
||||||
|
Source2: neuron-sysusers.conf
|
||||||
|
Source3: neuron-firewalld.xml
|
||||||
|
Source4: neuron.example.toml
|
||||||
|
Source5: LICENSE
|
||||||
|
|
||||||
|
ExclusiveArch: x86_64
|
||||||
|
|
||||||
|
# Binary links against the CUDA runtime, cuDNN, NCCL, etc. Suppress
|
||||||
|
# auto-detected exact soname deps — users may have CUDA from various
|
||||||
|
# sources (rpmfusion, nvidia-direct) at different compatible versions;
|
||||||
|
# a runtime dlopen failure surfaces a clearer error than rpm dep
|
||||||
|
# resolution would.
|
||||||
|
%global __requires_exclude ^lib(cuda|cudart|cudnn|cublas|cublasLt|curand|nvrtc|nccl)
|
||||||
|
|
||||||
|
Requires(pre): shadow-utils
|
||||||
|
Requires: systemd
|
||||||
|
Requires: firewalld-filesystem
|
||||||
|
|
||||||
|
Provides: helexa-neuron = %{neuron_version}-%{neuron_release}
|
||||||
|
Provides: user(neuron)
|
||||||
|
|
||||||
|
# Mutual exclusion across flavours and the source-build variant.
|
||||||
|
Conflicts: helexa-neuron
|
||||||
|
Conflicts: helexa-neuron-ada
|
||||||
|
Conflicts: helexa-neuron-ampere
|
||||||
|
Conflicts: helexa-neuron-blackwell
|
||||||
|
# (The Conflicts: with self is filtered by rpm at install time.)
|
||||||
|
|
||||||
|
%description
|
||||||
|
Neuron is the per-node daemon for cortex inference clusters. It
|
||||||
|
discovers local GPU hardware via nvidia-smi, runs in-process
|
||||||
|
inference via huggingface/candle, and exposes an HTTP API for model
|
||||||
|
lifecycle management (load, unload, list, inference endpoint).
|
||||||
|
|
||||||
|
This is the %{neuron_flavour} flavour, built for that CUDA compute
|
||||||
|
capability. Install the flavour matching the GPUs on this host.
|
||||||
|
|
||||||
|
%prep
|
||||||
|
cp %{SOURCE0} ./neuron
|
||||||
|
cp %{SOURCE1} .
|
||||||
|
cp %{SOURCE2} .
|
||||||
|
cp %{SOURCE3} .
|
||||||
|
cp %{SOURCE4} .
|
||||||
|
cp %{SOURCE5} .
|
||||||
|
|
||||||
|
%build
|
||||||
|
# Already built in the upstream CI build job (with --features cuda).
|
||||||
|
|
||||||
|
%install
|
||||||
|
install -Dm755 neuron %{buildroot}%{_bindir}/neuron
|
||||||
|
install -Dm644 neuron.service %{buildroot}%{_unitdir}/neuron.service
|
||||||
|
install -Dm644 neuron-sysusers.conf %{buildroot}%{_sysusersdir}/neuron.conf
|
||||||
|
install -Dm644 neuron-firewalld.xml %{buildroot}%{_prefix}/lib/firewalld/services/helexa-neuron.xml
|
||||||
|
install -dm755 %{buildroot}%{_sysconfdir}/neuron
|
||||||
|
install -Dm644 neuron.example.toml %{buildroot}%{_sysconfdir}/neuron/neuron.toml
|
||||||
|
|
||||||
|
%pre
|
||||||
|
getent group neuron >/dev/null || groupadd -r neuron
|
||||||
|
getent passwd neuron >/dev/null || \
|
||||||
|
useradd -r -g neuron -d /var/lib/neuron -s /sbin/nologin \
|
||||||
|
-G video,render \
|
||||||
|
-c "Neuron GPU node daemon" neuron
|
||||||
|
|
||||||
|
%post
|
||||||
|
%systemd_post neuron.service
|
||||||
|
|
||||||
|
%preun
|
||||||
|
%systemd_preun neuron.service
|
||||||
|
|
||||||
|
%postun
|
||||||
|
%systemd_postun_with_restart neuron.service
|
||||||
|
|
||||||
|
%files
|
||||||
|
%license LICENSE
|
||||||
|
%{_bindir}/neuron
|
||||||
|
%{_unitdir}/neuron.service
|
||||||
|
%{_sysusersdir}/neuron.conf
|
||||||
|
%{_prefix}/lib/firewalld/services/helexa-neuron.xml
|
||||||
|
%dir %{_sysconfdir}/neuron
|
||||||
|
%config(noreplace) %{_sysconfdir}/neuron/neuron.toml
|
||||||
|
|
||||||
|
%changelog
|
||||||
|
* Mon May 18 2026 Gitea Actions <actions@git.lair.cafe> - %{neuron_version}-%{neuron_release}
|
||||||
|
- Prerelease build from upstream CI binary (%{neuron_flavour} flavour).
|
||||||
1
rpm/rpmmacros
Normal file
1
rpm/rpmmacros
Normal file
@@ -0,0 +1 @@
|
|||||||
|
%_openpgp_sign_id @GPG_NAME@
|
||||||
256
script/deploy.sh
Executable file
256
script/deploy.sh
Executable file
@@ -0,0 +1,256 @@
|
|||||||
|
#!/bin/env bash
|
||||||
|
#
|
||||||
|
# Rolling deploy across the helexa fleet, driven by asset/manifest.yml.
|
||||||
|
# Installs / upgrades cortex on the gateway host and the appropriate
|
||||||
|
# helexa-neuron-<flavour> package on each neuron host, then restarts
|
||||||
|
# their services.
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||||
|
REPO_DIR="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
||||||
|
MANIFEST="${REPO_DIR}/asset/manifest.yml"
|
||||||
|
|
||||||
|
if [[ ! -f "${MANIFEST}" ]]; then
|
||||||
|
echo "fatal: manifest not found at ${MANIFEST}" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Parse the manifest with yq. NOTE: this expects the pip-installed yq
|
||||||
|
# (a jq wrapper using jq syntax) — `pip install yq`. The Fedora rpm
|
||||||
|
# `yq` is mikefarah/yq and uses different (yaml-native) syntax; if a
|
||||||
|
# host has that one instead these queries will fail.
|
||||||
|
cortex_host=$(yq -r '.cortex.host' "${MANIFEST}")
|
||||||
|
|
||||||
|
# Emit one TAB-separated 'host\tflavour' line per neuron.
|
||||||
|
mapfile -t neuron_entries < <(
|
||||||
|
yq -r '.neurons[] | .host + "\t" + .flavour' "${MANIFEST}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return the installed package's "version-release" string, or
|
||||||
|
# "(not installed)" when rpm reports the package as absent. Capture
|
||||||
|
# rpm's output into a variable so its "package X is not installed"
|
||||||
|
# stdout message (rpm writes that to stdout, not stderr, when -q fails)
|
||||||
|
# doesn't leak into the result.
|
||||||
|
installed_nvr() {
|
||||||
|
local host="$1" pkg="$2"
|
||||||
|
local nvr
|
||||||
|
if nvr=$(ssh "${host}" "rpm -q --qf '%{version}-%{release}' ${pkg} 2>/dev/null"); then
|
||||||
|
echo "${nvr}"
|
||||||
|
else
|
||||||
|
echo "(not installed)"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# Ensure the rpm.lair.cafe unstable repo is configured AND enabled on
|
||||||
|
# the remote host.
|
||||||
|
#
|
||||||
|
# The upstream .repo file at https://rpm.lair.cafe/lair-cafe-unstable.repo
|
||||||
|
# ships with `enabled=0` so a host that just fetched it won't start
|
||||||
|
# pulling unstable packages by accident. We have to explicitly flip
|
||||||
|
# enabled=1 via `dnf config-manager setopt`. Both addrepo and setopt
|
||||||
|
# are idempotent.
|
||||||
|
#
|
||||||
|
# Non-fatal — if either step fails the subsequent `dnf install` will
|
||||||
|
# surface a clearer diagnostic on its own.
|
||||||
|
ensure_lair_repo() {
|
||||||
|
local host="$1"
|
||||||
|
if ! ssh "${host}" "test -f /etc/yum.repos.d/lair-cafe-unstable.repo" 2>/dev/null; then
|
||||||
|
echo "[${host}] adding rpm.lair.cafe unstable repo"
|
||||||
|
if ! ssh "${host}" sudo dnf config-manager addrepo \
|
||||||
|
--from-repofile=https://rpm.lair.cafe/lair-cafe-unstable.repo \
|
||||||
|
>/dev/null 2>&1; then
|
||||||
|
echo "[${host}] WARNING: failed to add lair.cafe repo file (proceeding anyway)"
|
||||||
|
return 0
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
# The .repo file ships enabled=0; flip it on. Cheap, idempotent.
|
||||||
|
if ! ssh "${host}" sudo dnf config-manager setopt \
|
||||||
|
lair-cafe-unstable.enabled=1 >/dev/null 2>&1; then
|
||||||
|
echo "[${host}] WARNING: failed to enable lair-cafe-unstable (proceeding anyway)"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# Ensure libcudnn.so.9 is resolvable on the remote host so the
|
||||||
|
# neuron binary (built with --features cudnn) doesn't fail at startup
|
||||||
|
# with "cannot open shared object file: No such file or directory".
|
||||||
|
#
|
||||||
|
# Probes ldconfig first — if cuDNN was installed manually (.tar/.run
|
||||||
|
# install), it'll be cached by ldconfig and we don't touch it.
|
||||||
|
# Otherwise adds NVIDIA's RHEL9 CUDA repo (the Fedora 43 CUDA repo
|
||||||
|
# doesn't ship cuDNN packages — only the RHEL9 one does) and installs
|
||||||
|
# libcudnn9-cuda-13.
|
||||||
|
ensure_cudnn_runtime() {
|
||||||
|
local host="$1"
|
||||||
|
if ssh "${host}" "ldconfig -p | grep -q libcudnn.so.9" 2>/dev/null; then
|
||||||
|
return 0
|
||||||
|
fi
|
||||||
|
echo "[${host}] installing cuDNN runtime"
|
||||||
|
if ! ssh "${host}" "test -f /etc/yum.repos.d/cuda-rhel9-x86_64.repo" 2>/dev/null; then
|
||||||
|
if ! ssh "${host}" sudo dnf config-manager addrepo \
|
||||||
|
--from-repofile=https://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo \
|
||||||
|
>/dev/null 2>&1; then
|
||||||
|
echo "[${host}] WARNING: failed to add rhel9 CUDA repo (proceeding anyway)"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
if ! ssh "${host}" sudo dnf install -y libcudnn9-cuda-13 >/dev/null 2>&1; then
|
||||||
|
echo "[${host}] WARNING: failed to install libcudnn9-cuda-13"
|
||||||
|
echo "[${host}] neuron may fail to start; install cuDNN manually if so"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# True when the named package needs to be installed or upgraded on the
|
||||||
|
# remote host — either it's not present, or a newer version exists in
|
||||||
|
# the repo. False only when the installed version is current.
|
||||||
|
#
|
||||||
|
# `dnf check-update <pkg>` returns 0 when the package isn't installed
|
||||||
|
# at all (there's nothing to update), so we have to probe with rpm -q
|
||||||
|
# first to distinguish "absent" from "current". Other dnf failures
|
||||||
|
# collapse into "needs update" so the subsequent install step surfaces
|
||||||
|
# the real diagnostic rather than this check swallowing it.
|
||||||
|
needs_update() {
|
||||||
|
local host="$1" pkg="$2"
|
||||||
|
# Not installed → needs work.
|
||||||
|
if ! ssh "${host}" "rpm -q ${pkg}" >/dev/null 2>&1; then
|
||||||
|
return 0
|
||||||
|
fi
|
||||||
|
# Installed; ask dnf whether the repo has something newer.
|
||||||
|
if ssh "${host}" sudo dnf check-update --refresh -q "${pkg}" >/dev/null 2>&1; then
|
||||||
|
return 1
|
||||||
|
else
|
||||||
|
return 0
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# True if the named package is currently installed on the remote host.
|
||||||
|
# Used to decide between `dnf install` (fresh) and `dnf upgrade` (stale):
|
||||||
|
# dnf5's `install` is a no-op when the package is already present at
|
||||||
|
# any version — it does NOT auto-upgrade to the latest available — so
|
||||||
|
# the wrong command silently leaves the host on an old build.
|
||||||
|
is_installed() {
|
||||||
|
local host="$1" pkg="$2"
|
||||||
|
ssh "${host}" "rpm -q ${pkg}" >/dev/null 2>&1
|
||||||
|
}
|
||||||
|
|
||||||
|
# Install or upgrade the named package on the remote, picking the
|
||||||
|
# right dnf verb based on the installed-or-not state. Returns 0 with
|
||||||
|
# dnf's combined stdout/stderr captured in __DNF_OUTPUT__ on success,
|
||||||
|
# and 1 with the same captured output on failure.
|
||||||
|
__DNF_OUTPUT__=""
|
||||||
|
install_or_upgrade() {
|
||||||
|
local host="$1" pkg="$2"
|
||||||
|
local cmd
|
||||||
|
if is_installed "${host}" "${pkg}"; then
|
||||||
|
cmd="upgrade"
|
||||||
|
else
|
||||||
|
cmd="install"
|
||||||
|
fi
|
||||||
|
if __DNF_OUTPUT__=$(
|
||||||
|
ssh "${host}" sudo dnf "${cmd}" --refresh --allowerasing -y "${pkg}" 2>&1
|
||||||
|
); then
|
||||||
|
return 0
|
||||||
|
else
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# cortex (gateway)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
ensure_lair_repo "${cortex_host}"
|
||||||
|
cortex_nvr=$(installed_nvr "${cortex_host}" cortex)
|
||||||
|
if needs_update "${cortex_host}" cortex; then
|
||||||
|
echo "[${cortex_host}] cortex update available (current: ${cortex_nvr})"
|
||||||
|
# Stop the service only if the unit file exists — fresh installs
|
||||||
|
# don't have it, and `systemctl stop` on a missing unit returns
|
||||||
|
# non-zero, which would otherwise short-circuit the install branch
|
||||||
|
# under set -e.
|
||||||
|
if ssh "${cortex_host}" "[ ! -f /usr/lib/systemd/system/cortex.service ] || sudo systemctl stop cortex.service"; then
|
||||||
|
echo "[${cortex_host}] stopped cortex service"
|
||||||
|
if install_or_upgrade "${cortex_host}" cortex; then
|
||||||
|
cortex_nvr=$(installed_nvr "${cortex_host}" cortex)
|
||||||
|
echo "[${cortex_host}] installed/upgraded cortex to ${cortex_nvr}"
|
||||||
|
else
|
||||||
|
echo "[${cortex_host}] failed to install/upgrade cortex:"
|
||||||
|
echo "${__DNF_OUTPUT__}" | sed "s/^/[${cortex_host}] /"
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
echo "[${cortex_host}] failed to stop cortex service"
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
echo "[${cortex_host}] cortex is up to date (${cortex_nvr})"
|
||||||
|
ssh "${cortex_host}" sudo systemctl stop cortex.service || true
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Sync cortex.toml whether the package was upgraded or not — the config
|
||||||
|
# can change without a package bump.
|
||||||
|
if rsync \
|
||||||
|
--archive \
|
||||||
|
--compress \
|
||||||
|
--rsync-path 'sudo rsync' \
|
||||||
|
--chown root:root \
|
||||||
|
--chmod 644 \
|
||||||
|
"${REPO_DIR}/cortex.toml" \
|
||||||
|
"${cortex_host}:/etc/cortex/cortex.toml"; then
|
||||||
|
echo "[${cortex_host}] sync'd cortex.toml"
|
||||||
|
else
|
||||||
|
echo "[${cortex_host}] failed to sync cortex.toml"
|
||||||
|
fi
|
||||||
|
|
||||||
|
ssh "${cortex_host}" sudo systemctl daemon-reload
|
||||||
|
if ssh "${cortex_host}" systemctl is-active --quiet cortex.service; then
|
||||||
|
echo "[${cortex_host}] cortex service is active"
|
||||||
|
elif ssh "${cortex_host}" sudo systemctl start cortex.service; then
|
||||||
|
echo "[${cortex_host}] started cortex service"
|
||||||
|
else
|
||||||
|
echo "[${cortex_host}] failed to start cortex service"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# neuron (per-host, flavour from manifest)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
for entry in "${neuron_entries[@]}"; do
|
||||||
|
IFS=$'\t' read -r neuron_host neuron_flavour <<< "${entry}"
|
||||||
|
package="helexa-neuron-${neuron_flavour}"
|
||||||
|
|
||||||
|
ensure_lair_repo "${neuron_host}"
|
||||||
|
ensure_cudnn_runtime "${neuron_host}"
|
||||||
|
neuron_nvr=$(installed_nvr "${neuron_host}" "${package}")
|
||||||
|
if needs_update "${neuron_host}" "${package}"; then
|
||||||
|
echo "[${neuron_host}] ${package} update available (current: ${neuron_nvr})"
|
||||||
|
if ssh "${neuron_host}" "[ ! -f /usr/lib/systemd/system/neuron.service ] || sudo systemctl stop neuron.service"; then
|
||||||
|
echo "[${neuron_host}] stopped neuron service"
|
||||||
|
# --allowerasing lets dnf swap out a previously-installed
|
||||||
|
# bare helexa-neuron or a different flavour without manual
|
||||||
|
# intervention. The Conflicts: clauses in the spec ensure
|
||||||
|
# only one flavour is ever resident.
|
||||||
|
if install_or_upgrade "${neuron_host}" "${package}"; then
|
||||||
|
neuron_nvr=$(installed_nvr "${neuron_host}" "${package}")
|
||||||
|
echo "[${neuron_host}] installed/upgraded ${package} to ${neuron_nvr}"
|
||||||
|
# Ensure firewalld allows neuron port
|
||||||
|
ssh "${neuron_host}" "sudo firewall-cmd --query-service=helexa-neuron --quiet 2>/dev/null || sudo firewall-cmd --add-service=helexa-neuron --permanent && sudo firewall-cmd --reload" 2>/dev/null || true
|
||||||
|
if ssh "${neuron_host}" "sudo systemctl daemon-reload && sudo systemctl start neuron.service"; then
|
||||||
|
echo "[${neuron_host}] started neuron service"
|
||||||
|
else
|
||||||
|
echo "[${neuron_host}] failed to start neuron service"
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
echo "[${neuron_host}] failed to install ${package}:"
|
||||||
|
echo "${__DNF_OUTPUT__}" | sed "s/^/[${neuron_host}] /"
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
echo "[${neuron_host}] failed to stop neuron service"
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
echo "[${neuron_host}] ${package} is up to date (${neuron_nvr})"
|
||||||
|
if ssh "${neuron_host}" systemctl is-active --quiet neuron.service; then
|
||||||
|
echo "[${neuron_host}] neuron service is active"
|
||||||
|
elif ssh "${neuron_host}" sudo systemctl start neuron.service; then
|
||||||
|
echo "[${neuron_host}] started neuron service"
|
||||||
|
else
|
||||||
|
echo "[${neuron_host}] failed to start neuron service"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
done
|
||||||
154
script/generate-packages-json.py
Executable file
154
script/generate-packages-json.py
Executable file
@@ -0,0 +1,154 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Parse RPM repodata and emit a packages.json manifest for the UI."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import gzip
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import xml.etree.ElementTree as ET
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
RPM_NS = "http://linux.duke.edu/metadata/common"
|
||||||
|
OTHER_NS = "http://linux.duke.edu/metadata/other"
|
||||||
|
REPO_NS = "http://linux.duke.edu/metadata/repo"
|
||||||
|
|
||||||
|
|
||||||
|
def find_repodata_file(repodata_dir, data_type):
|
||||||
|
"""Read repomd.xml and return the path to a specific data type's file."""
|
||||||
|
repomd_path = os.path.join(repodata_dir, "repomd.xml")
|
||||||
|
tree = ET.parse(repomd_path)
|
||||||
|
root = tree.getroot()
|
||||||
|
|
||||||
|
for data in root.findall(f"{{{REPO_NS}}}data"):
|
||||||
|
if data.get("type") == data_type:
|
||||||
|
location = data.find(f"{{{REPO_NS}}}location")
|
||||||
|
if location is not None:
|
||||||
|
href = location.get("href", "")
|
||||||
|
return os.path.join(os.path.dirname(repodata_dir), href)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def open_compressed(path):
|
||||||
|
"""Open a gzip or zstd compressed file for reading."""
|
||||||
|
if path.endswith(".zst"):
|
||||||
|
result = subprocess.run(
|
||||||
|
["zstdcat", path], capture_output=True, check=True
|
||||||
|
)
|
||||||
|
import io
|
||||||
|
return io.BytesIO(result.stdout)
|
||||||
|
else:
|
||||||
|
return gzip.open(path, "rb")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_primary(repodata_dir):
|
||||||
|
"""Parse primary.xml.{gz,zst} and return package metadata."""
|
||||||
|
path = find_repodata_file(repodata_dir, "primary")
|
||||||
|
if not path:
|
||||||
|
print("error: primary metadata not found in repomd.xml", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
packages = {}
|
||||||
|
with open_compressed(path) as f:
|
||||||
|
tree = ET.parse(f)
|
||||||
|
|
||||||
|
for pkg in tree.getroot().findall(f"{{{RPM_NS}}}package"):
|
||||||
|
if pkg.get("type") != "rpm":
|
||||||
|
continue
|
||||||
|
|
||||||
|
name = pkg.findtext(f"{{{RPM_NS}}}name", "")
|
||||||
|
version_el = pkg.find(f"{{{RPM_NS}}}version")
|
||||||
|
ver = version_el.get("ver", "") if version_el is not None else ""
|
||||||
|
rel = version_el.get("rel", "") if version_el is not None else ""
|
||||||
|
arch = pkg.findtext(f"{{{RPM_NS}}}arch", "")
|
||||||
|
|
||||||
|
size_el = pkg.find(f"{{{RPM_NS}}}size")
|
||||||
|
size = int(size_el.get("package", "0")) if size_el is not None else 0
|
||||||
|
|
||||||
|
time_el = pkg.find(f"{{{RPM_NS}}}time")
|
||||||
|
build_time = int(time_el.get("build", "0")) if time_el is not None else 0
|
||||||
|
|
||||||
|
location_el = pkg.find(f"{{{RPM_NS}}}location")
|
||||||
|
filename = os.path.basename(location_el.get("href", "")) if location_el is not None else ""
|
||||||
|
|
||||||
|
key = f"{name}-{ver}-{rel}"
|
||||||
|
packages[key] = {
|
||||||
|
"name": name,
|
||||||
|
"version": ver,
|
||||||
|
"release": rel,
|
||||||
|
"arch": arch,
|
||||||
|
"summary": pkg.findtext(f"{{{RPM_NS}}}summary", ""),
|
||||||
|
"size": size,
|
||||||
|
"buildTime": build_time,
|
||||||
|
"rpmFilename": filename,
|
||||||
|
"changelog": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
return packages
|
||||||
|
|
||||||
|
|
||||||
|
def parse_other(repodata_dir, packages):
|
||||||
|
"""Parse other.xml.gz and attach changelog entries to packages."""
|
||||||
|
path = find_repodata_file(repodata_dir, "other")
|
||||||
|
if not path:
|
||||||
|
return
|
||||||
|
|
||||||
|
with open_compressed(path) as f:
|
||||||
|
tree = ET.parse(f)
|
||||||
|
|
||||||
|
for pkg in tree.getroot().findall(f"{{{OTHER_NS}}}package"):
|
||||||
|
name = pkg.get("name", "")
|
||||||
|
version_el = pkg.find(f"{{{OTHER_NS}}}version")
|
||||||
|
ver = version_el.get("ver", "") if version_el is not None else ""
|
||||||
|
rel = version_el.get("rel", "") if version_el is not None else ""
|
||||||
|
key = f"{name}-{ver}-{rel}"
|
||||||
|
|
||||||
|
if key not in packages:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for entry in pkg.findall(f"{{{OTHER_NS}}}changelog"):
|
||||||
|
packages[key]["changelog"].append({
|
||||||
|
"author": entry.get("author", ""),
|
||||||
|
"date": int(entry.get("date", "0")),
|
||||||
|
"text": (entry.text or "").strip(),
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
|
parser.add_argument(
|
||||||
|
"--repodata-dir",
|
||||||
|
required=True,
|
||||||
|
help="path to the repodata/ directory",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output",
|
||||||
|
required=True,
|
||||||
|
help="path to write packages.json",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--base-url",
|
||||||
|
required=True,
|
||||||
|
help="public base URL for the repo (e.g. https://rpm.lair.cafe/fedora/43/x86_64)",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
packages = parse_primary(args.repodata_dir)
|
||||||
|
parse_other(args.repodata_dir, packages)
|
||||||
|
|
||||||
|
manifest = {
|
||||||
|
"generated": datetime.now(timezone.utc).isoformat(),
|
||||||
|
"baseUrl": args.base_url,
|
||||||
|
"packages": list(packages.values()),
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(args.output, "w") as f:
|
||||||
|
json.dump(manifest, f, indent=2)
|
||||||
|
|
||||||
|
print(f"wrote {len(packages)} packages to {args.output}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
162
script/validate-neuron.sh
Executable file
162
script/validate-neuron.sh
Executable file
@@ -0,0 +1,162 @@
|
|||||||
|
#!/bin/env bash
|
||||||
|
#
|
||||||
|
# End-to-end smoke test for a deployed neuron.
|
||||||
|
#
|
||||||
|
# Confirms the daemon is reachable, loads a small public Qwen3 GGUF,
|
||||||
|
# fires a reasoning probe at /v1/chat/completions, and prints the
|
||||||
|
# answer. Used to validate the candle harness on a real GPU host
|
||||||
|
# before trusting it for production traffic, and as a regression test
|
||||||
|
# after pushing new neuron builds.
|
||||||
|
#
|
||||||
|
# Usage:
|
||||||
|
# script/validate-neuron.sh [host] [model_id] [quant]
|
||||||
|
#
|
||||||
|
# Defaults:
|
||||||
|
# host = beast.hanzalova.internal
|
||||||
|
# model_id = unsloth/Qwen3-0.6B-GGUF (official Qwen3-*-GGUF repos
|
||||||
|
# ship Q8_0 only; unsloth's mirror ships the full Q-spectrum
|
||||||
|
# including Q4_K_M)
|
||||||
|
# quant = Q4_K_M
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
HOST="${1:-beast.hanzalova.internal}"
|
||||||
|
MODEL_ID="${2:-unsloth/Qwen3-0.6B-GGUF}"
|
||||||
|
# `${3-Q4_K_M}` (no colon) only uses the default when the arg is
|
||||||
|
# UNSET — passing an explicit empty string drives the dense path.
|
||||||
|
QUANT="${3-Q4_K_M}"
|
||||||
|
PORT="${NEURON_PORT:-13131}"
|
||||||
|
BASE="http://${HOST}:${PORT}"
|
||||||
|
|
||||||
|
# Reasoning probe — concrete, low-temperature answer that small models
|
||||||
|
# can still get right. "Paris" is a strong signal of basic competence
|
||||||
|
# beyond gibberish.
|
||||||
|
PROBE_PROMPT='What is the capital of Georgia (Caucasus)? Respond with the city name only, no punctuation.'
|
||||||
|
EXPECT_SUBSTR='Tbilisi'
|
||||||
|
# Qwen3 prepends <think>...</think> reasoning before the answer when the
|
||||||
|
# chat template enables thinking mode, which eats most of a small token
|
||||||
|
# budget. 256 leaves enough room for thinking + final answer.
|
||||||
|
MAX_TOKENS=256
|
||||||
|
|
||||||
|
# /models/load is synchronous — neuron blocks the response until the
|
||||||
|
# hf-hub download + GGUF parse + tensor materialisation is done. A
|
||||||
|
# fresh 0.6B-Q4_K_M is ~400 MB; on a slow link or cold cache that's
|
||||||
|
# easily a minute. Pick a generous ceiling.
|
||||||
|
LOAD_TIMEOUT=600
|
||||||
|
INFER_TIMEOUT=120
|
||||||
|
|
||||||
|
# Status messages go to stderr so command substitutions like
|
||||||
|
# `raw=$(run_probe)` capture only the function's intended return value
|
||||||
|
# (an HTTP body), not the progress chatter.
|
||||||
|
say() { printf '[%s] %s\n' "${HOST}" "$*" >&2; }
|
||||||
|
die() { say "FAIL: $*"; exit 1; }
|
||||||
|
|
||||||
|
probe_health() {
|
||||||
|
curl --silent --fail --max-time 5 "${BASE}/health" >/dev/null \
|
||||||
|
|| die "neuron not reachable at ${BASE}/health"
|
||||||
|
}
|
||||||
|
|
||||||
|
list_loaded_ids() {
|
||||||
|
# The manifest is YAML and uses yq; HTTP responses are JSON and use
|
||||||
|
# jq directly. pip-yq parses input as YAML by default, which trips
|
||||||
|
# on JSON content that happens to look like YAML aliases (chatcmpl
|
||||||
|
# ids, escaped quotes inside `<think>...</think>` blocks, etc.).
|
||||||
|
curl --silent --fail "${BASE}/models" | jq -r '.[].id'
|
||||||
|
}
|
||||||
|
|
||||||
|
is_loaded() {
|
||||||
|
list_loaded_ids 2>/dev/null | grep -Fxq "${MODEL_ID}"
|
||||||
|
}
|
||||||
|
|
||||||
|
trigger_load() {
|
||||||
|
say "POST /models/load ${MODEL_ID} (quant=${QUANT:-<dense>}, device=[0])"
|
||||||
|
say " (synchronous; may take a minute on first run while HF downloads)"
|
||||||
|
# Build the payload via jq so the optional `quant` field is
|
||||||
|
# omitted entirely when empty — that's the signal to the harness
|
||||||
|
# to take the dense safetensors load path rather than GGUF.
|
||||||
|
local payload
|
||||||
|
if [[ -z "${QUANT}" ]]; then
|
||||||
|
payload=$(jq -n -c \
|
||||||
|
--arg id "${MODEL_ID}" \
|
||||||
|
'{model_id: $id, harness: "candle", devices: [0]}')
|
||||||
|
else
|
||||||
|
payload=$(jq -n -c \
|
||||||
|
--arg id "${MODEL_ID}" \
|
||||||
|
--arg q "${QUANT}" \
|
||||||
|
'{model_id: $id, harness: "candle", quant: $q, devices: [0]}')
|
||||||
|
fi
|
||||||
|
# --write-out captures the response code on a separate line so we
|
||||||
|
# can surface a real diagnostic instead of relying on --fail.
|
||||||
|
local resp http_code body
|
||||||
|
resp=$(curl --silent --show-error --max-time "${LOAD_TIMEOUT}" \
|
||||||
|
--write-out '\n__HTTP__%{http_code}' \
|
||||||
|
-X POST "${BASE}/models/load" \
|
||||||
|
-H 'content-type: application/json' \
|
||||||
|
--data "${payload}") || die "curl /models/load failed: $?"
|
||||||
|
http_code=$(echo "${resp}" | grep -oP '(?<=__HTTP__)\d+$' | tail -1)
|
||||||
|
body=$(echo "${resp}" | sed '$ s/__HTTP__.*$//')
|
||||||
|
if [[ "${http_code}" != "200" ]]; then
|
||||||
|
die "load returned HTTP ${http_code}: ${body}"
|
||||||
|
fi
|
||||||
|
say "load returned ${http_code}: ${body}"
|
||||||
|
}
|
||||||
|
|
||||||
|
run_probe() {
|
||||||
|
say "POST /v1/chat/completions (probe: ${PROBE_PROMPT})"
|
||||||
|
local payload
|
||||||
|
payload=$(jq -n -c \
|
||||||
|
--arg model "${MODEL_ID}" \
|
||||||
|
--arg content "${PROBE_PROMPT}" \
|
||||||
|
--argjson tokens "${MAX_TOKENS}" \
|
||||||
|
'{
|
||||||
|
model: $model,
|
||||||
|
messages: [{role: "user", content: $content}],
|
||||||
|
temperature: 0.1,
|
||||||
|
max_tokens: $tokens
|
||||||
|
}')
|
||||||
|
local resp http_code body
|
||||||
|
resp=$(curl --silent --show-error --max-time "${INFER_TIMEOUT}" \
|
||||||
|
--write-out '\n__HTTP__%{http_code}' \
|
||||||
|
-X POST "${BASE}/v1/chat/completions" \
|
||||||
|
-H 'content-type: application/json' \
|
||||||
|
--data "${payload}") || die "curl /v1/chat/completions failed: $?"
|
||||||
|
http_code=$(echo "${resp}" | grep -oP '(?<=__HTTP__)\d+$' | tail -1)
|
||||||
|
body=$(echo "${resp}" | sed '$ s/__HTTP__.*$//')
|
||||||
|
if [[ "${http_code}" != "200" ]]; then
|
||||||
|
die "inference returned HTTP ${http_code}: ${body}"
|
||||||
|
fi
|
||||||
|
echo "${body}"
|
||||||
|
}
|
||||||
|
|
||||||
|
say "validating neuron at ${BASE}"
|
||||||
|
probe_health
|
||||||
|
say "/health OK"
|
||||||
|
|
||||||
|
if is_loaded; then
|
||||||
|
say "${MODEL_ID} already loaded"
|
||||||
|
else
|
||||||
|
trigger_load
|
||||||
|
fi
|
||||||
|
|
||||||
|
raw=$(run_probe)
|
||||||
|
echo "---"
|
||||||
|
# Dump the raw JSON. Don't pipe through `yq -r '.'` — yq's default
|
||||||
|
# YAML output mode chokes on JSON strings that contain `<` (and the
|
||||||
|
# `<think>` markers Qwen3 emits during reasoning are a perfect
|
||||||
|
# example). The targeted `yq -r '.path'` calls below work fine
|
||||||
|
# because jq's path filter mode bypasses the YAML re-emit.
|
||||||
|
echo "${raw}"
|
||||||
|
echo "---"
|
||||||
|
|
||||||
|
content=$(echo "${raw}" | jq -r '.choices[0].message.content // empty')
|
||||||
|
if [[ -z "${content}" ]]; then
|
||||||
|
die "no content in chat completion response"
|
||||||
|
fi
|
||||||
|
say "assistant said: ${content}"
|
||||||
|
|
||||||
|
if echo "${content}" | grep -qiF "${EXPECT_SUBSTR}"; then
|
||||||
|
say "PASS — response contains expected substring '${EXPECT_SUBSTR}'"
|
||||||
|
exit 0
|
||||||
|
else
|
||||||
|
die "response did not contain '${EXPECT_SUBSTR}'"
|
||||||
|
fi
|
||||||
Reference in New Issue
Block a user