Compare commits
183 Commits
67b9b044d3
...
phase-2-pr
| Author | SHA1 | Date | |
|---|---|---|---|
|
61adff347a
|
|||
|
0af8c8d6e7
|
|||
|
435fd10902
|
|||
|
cb303832bc
|
|||
|
44008358c5
|
|||
|
2f387f33f8
|
|||
|
fc9a8c42a3
|
|||
|
7733eecba5
|
|||
|
fdc0adb738
|
|||
|
8fa1d1962e
|
|||
|
cad7552104
|
|||
|
1818dfb337
|
|||
|
5ed1140c97
|
|||
|
957f704efa
|
|||
|
1859777332
|
|||
|
6927286cab
|
|||
|
302ccfb982
|
|||
|
df0abfe4d4
|
|||
|
b9016571f6
|
|||
|
adbc52bfcd
|
|||
|
537a0fe7f2
|
|||
|
cbadfcf112
|
|||
|
3ecbb21ece
|
|||
|
0d841a4981
|
|||
|
0bbb9b752d
|
|||
|
5aac1ffc59
|
|||
|
ec2b6450b2
|
|||
|
a494c8d43c
|
|||
|
abbedf8d8a
|
|||
|
6cc14e925c
|
|||
|
1c16732668
|
|||
|
5a0861d639
|
|||
|
33652ac651
|
|||
|
c297a54074
|
|||
|
0121a1930f
|
|||
|
13f4c36aeb
|
|||
|
4a51a54554
|
|||
|
0609f1ac5d
|
|||
|
96fc379893
|
|||
|
e267f583e1
|
|||
|
e23d5011d0
|
|||
|
249b2e5c98
|
|||
|
c59da83636
|
|||
|
f05882369d
|
|||
|
bd04d7f580
|
|||
|
1e13889392
|
|||
|
6e1c1dd0fc
|
|||
|
35876954cd
|
|||
|
740299bd9d
|
|||
|
cdf0f4e66d
|
|||
|
c4954e0eed
|
|||
|
b4f3576d82
|
|||
|
76ab24d98c
|
|||
|
b179204fd3
|
|||
|
081b532387
|
|||
|
7c19da9361
|
|||
|
24e20dcb5c
|
|||
|
becf61b9c1
|
|||
|
b9e7a76a7a
|
|||
|
800498f530
|
|||
|
d3f2d50749
|
|||
|
2740e61a23
|
|||
|
67f79c868f
|
|||
|
fc6ef0ee0f
|
|||
|
1385979e3d
|
|||
|
0a1cfcd4d0
|
|||
|
ea0e0f7911
|
|||
|
aa88d37509
|
|||
|
0f00f72b47
|
|||
|
9b0ed0b57f
|
|||
|
dc2a803266
|
|||
|
e71181499e
|
|||
|
ee663e5e99
|
|||
|
34f9b77d9d
|
|||
|
f084aaab8e
|
|||
|
68a606a79c
|
|||
|
4aa71902d0
|
|||
|
bef159b21c
|
|||
|
8d7b099b36
|
|||
|
89d98d1fb2
|
|||
|
cc95fe28d9
|
|||
|
09c945f81e
|
|||
|
05dc0bad18
|
|||
|
10c151efa5
|
|||
|
44ae927e38
|
|||
|
1ebbe87651
|
|||
|
70eb6af42b
|
|||
|
d1a4aad91d
|
|||
|
95dc8745eb
|
|||
|
495d3f7c05
|
|||
|
5c4c8e0eba
|
|||
|
07c44d5db1
|
|||
|
e7eb3dab6a
|
|||
|
180274548d
|
|||
|
a70f317729
|
|||
|
c6022aa6b9
|
|||
|
9e31d8deca
|
|||
|
b400e8b704
|
|||
|
62ca125a68
|
|||
|
735945ee81
|
|||
|
f72dee094f
|
|||
|
d46d8d4f6c
|
|||
|
9b8bd146f6
|
|||
|
96d8755245
|
|||
|
12549c9aed
|
|||
|
46527d7804
|
|||
|
8d3194f992
|
|||
|
5436af9c73
|
|||
|
8e882c0757
|
|||
|
93421f48e2
|
|||
|
05e15f3597
|
|||
|
da068ded6d
|
|||
|
2a7ede0232
|
|||
|
18ae3c30ee
|
|||
|
1a0400131e
|
|||
|
1866b99a89
|
|||
|
60176e7c2e
|
|||
|
602e8e1471
|
|||
|
e9d0a75dd5
|
|||
|
6cf87e328f
|
|||
|
f9f5fa41b6
|
|||
|
ed4d71db09
|
|||
|
39010c779f
|
|||
|
57d7ef8d3c
|
|||
|
0e9671dd7d
|
|||
|
e29c9e35f0
|
|||
|
8a2334eacb
|
|||
|
aad314cdfa
|
|||
|
6779b7526a
|
|||
|
84f5662df1
|
|||
|
249c9442e8
|
|||
|
5e17081fb4
|
|||
|
03bed93fee
|
|||
|
4a5211d830
|
|||
|
6d2dc5ff1a
|
|||
|
b713dbe669
|
|||
|
5c957d08ec
|
|||
|
729317d1ef
|
|||
|
5c2bd1a1da
|
|||
|
3cccc2c56b
|
|||
|
7f797b0265
|
|||
|
5a0360c1d5
|
|||
|
472c0e8737
|
|||
|
|
b9d8e30058 | ||
|
25f75fe552
|
|||
|
3f94c50817
|
|||
|
3e1fb60076
|
|||
|
|
9bf987888c | ||
|
abe4ff7ccc
|
|||
|
7c3390a4e1
|
|||
|
2ff062da0e
|
|||
|
|
357f858a29 | ||
|
556e5293dc
|
|||
|
1d90238b01
|
|||
|
d99b25fb8a
|
|||
|
034da319f1
|
|||
|
|
7ece281617 | ||
|
3bb5b3c425
|
|||
|
|
9fa51ad874 | ||
|
9697fbae73
|
|||
|
|
2ce1060cb8 | ||
|
142e91c3f7
|
|||
|
|
52c8b4c983 | ||
|
4a9a4fc775
|
|||
|
53a3c1e157
|
|||
|
5c7d63c658
|
|||
|
|
f161412f91 | ||
|
ba5020138f
|
|||
|
209150771e
|
|||
|
|
7c60af3464 | ||
|
ada76b0153
|
|||
|
15ded3a5bd
|
|||
|
7befa882d5
|
|||
|
d03fae960a
|
|||
|
7b2235d56b
|
|||
|
54f9f3dc36
|
|||
|
caee8bba11
|
|||
|
324dfa05c5
|
|||
|
c85d50066e
|
|||
|
6c238f4557
|
|||
|
e42e8ee81f
|
|||
|
26e5e7ead8
|
|||
|
6dc717ebcd
|
343
.gitea/workflows/build-prerelease.yml
Normal file
343
.gitea/workflows/build-prerelease.yml
Normal file
@@ -0,0 +1,343 @@
|
|||||||
|
name: build-prerelease
|
||||||
|
|
||||||
|
# Manually-dispatched workflow that builds CUDA-flavoured neuron binaries
|
||||||
|
# (and a single cortex binary), packages each as a Fedora RPM, signs
|
||||||
|
# them, and publishes to the `unstable` channel at rpm.lair.cafe.
|
||||||
|
#
|
||||||
|
# Trigger from the Gitea UI: Actions → build-prerelease → Run workflow.
|
||||||
|
# Optionally provide a `ref` to build from a non-default branch.
|
||||||
|
#
|
||||||
|
# The published packages are versioned as e.g.
|
||||||
|
# helexa-neuron-blackwell-0.1.16-0.1.20260518T140530.gitabcdef0.fc43.x86_64
|
||||||
|
# ^^^^^^^^^^^^^^^^^^ ^^^^^^^^
|
||||||
|
# commit time (s) commit sha
|
||||||
|
# so they sort BELOW the eventual 0.1.16-1 stable release, and so two
|
||||||
|
# commits on the same day are still strictly ordered by their commit
|
||||||
|
# timestamps (rather than by RPM-vercmp's alpha-vs-digit precedence
|
||||||
|
# on the SHA fragment).
|
||||||
|
|
||||||
|
on:
|
||||||
|
# Auto-build on every push to main so the unstable channel tracks
|
||||||
|
# head without a manual dispatch step.
|
||||||
|
push:
|
||||||
|
branches: [main]
|
||||||
|
# Manual dispatch still available to build from a non-main ref.
|
||||||
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
ref:
|
||||||
|
description: "Git ref to build (branch / tag / commit). Defaults to the workflow's branch."
|
||||||
|
required: false
|
||||||
|
default: ""
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
# Share the group with ci.yml so the two workflows can't run
|
||||||
|
# concurrently on the same `rust` runner (act reuses the workspace
|
||||||
|
# cache and races destroy each other's build files mid-compile).
|
||||||
|
# cancel-in-progress=false → workflows queue; if a newer push lands,
|
||||||
|
# the older run is still picked up by ci.yml's own ref-keyed
|
||||||
|
# concurrency (same group, queued).
|
||||||
|
group: cortex-runner-pool-${{ github.ref }}
|
||||||
|
cancel-in-progress: false
|
||||||
|
|
||||||
|
env:
|
||||||
|
CARGO_INCREMENTAL: "0"
|
||||||
|
CARGO_TERM_COLOR: "always"
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
prepare:
|
||||||
|
name: Resolve version stamps
|
||||||
|
runs-on: rust
|
||||||
|
outputs:
|
||||||
|
version: ${{ steps.info.outputs.version }}
|
||||||
|
release: ${{ steps.info.outputs.release }}
|
||||||
|
short_sha: ${{ steps.info.outputs.short_sha }}
|
||||||
|
commit_timestamp: ${{ steps.info.outputs.commit_timestamp }}
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ inputs.ref }}
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- id: info
|
||||||
|
run: |
|
||||||
|
set -eux
|
||||||
|
VERSION=$(awk -F\" '/^version[[:space:]]*=/ { print $2; exit }' Cargo.toml)
|
||||||
|
SHORT_SHA=$(git rev-parse --short=7 HEAD)
|
||||||
|
# Second-precise commit timestamp gives the release stamp a
|
||||||
|
# strictly monotonic numeric prefix. The earlier %Y%m%d-only
|
||||||
|
# form let same-day builds be ordered by RPM's rpmvercmp
|
||||||
|
# rules over the SHA, which is non-chronological — e.g.
|
||||||
|
# "git602e8e1" sorts newer than "gitf9f5fa4" purely because
|
||||||
|
# rpmvercmp ranks digit-prefixed segments above alpha ones.
|
||||||
|
# The SHA stays only as a debug identifier; sort order is
|
||||||
|
# decided entirely by the timestamp.
|
||||||
|
COMMIT_TIMESTAMP=$(git log -1 --format=%cd --date=format:%Y%m%d%H%M%S HEAD)
|
||||||
|
RELEASE="0.1.${COMMIT_TIMESTAMP}.git${SHORT_SHA}"
|
||||||
|
echo "version=${VERSION}" >> "$GITHUB_OUTPUT"
|
||||||
|
echo "release=${RELEASE}" >> "$GITHUB_OUTPUT"
|
||||||
|
echo "short_sha=${SHORT_SHA}" >> "$GITHUB_OUTPUT"
|
||||||
|
echo "commit_timestamp=${COMMIT_TIMESTAMP}" >> "$GITHUB_OUTPUT"
|
||||||
|
|
||||||
|
build-cortex:
|
||||||
|
name: Build cortex binary
|
||||||
|
needs: prepare
|
||||||
|
# runner-rust image already provides rust/cargo/clippy/rustfmt via
|
||||||
|
# dnf — no rustup install step needed.
|
||||||
|
runs-on: rust
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ inputs.ref }}
|
||||||
|
|
||||||
|
- name: Build cortex (release)
|
||||||
|
run: cargo build --release -p cortex-cli
|
||||||
|
|
||||||
|
- name: Stage binary
|
||||||
|
run: |
|
||||||
|
mkdir --parents artifacts
|
||||||
|
cp target/release/cortex artifacts/cortex
|
||||||
|
./artifacts/cortex --version || true
|
||||||
|
|
||||||
|
- uses: actions/upload-artifact@v3
|
||||||
|
with:
|
||||||
|
name: cortex-fc43
|
||||||
|
path: artifacts/cortex
|
||||||
|
retention-days: 1
|
||||||
|
|
||||||
|
build-neuron:
|
||||||
|
name: Build neuron-${{ matrix.flavour }}
|
||||||
|
needs: prepare
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- flavour: ampere
|
||||||
|
compute_cap: "86"
|
||||||
|
runner: cuda-13.0
|
||||||
|
cuda_home: /usr/local/cuda-13.0
|
||||||
|
build_jobs: 8
|
||||||
|
nvcc_threads: 4
|
||||||
|
cargo_features: "cuda cudnn flash-attn"
|
||||||
|
- flavour: ada
|
||||||
|
compute_cap: "89"
|
||||||
|
runner: cuda-13.0
|
||||||
|
cuda_home: /usr/local/cuda-13.0
|
||||||
|
build_jobs: 8
|
||||||
|
nvcc_threads: 4
|
||||||
|
cargo_features: "cuda cudnn flash-attn"
|
||||||
|
- flavour: blackwell
|
||||||
|
compute_cap: "120"
|
||||||
|
runner: cuda-13.0
|
||||||
|
cuda_home: /usr/local/cuda-13.0
|
||||||
|
build_jobs: 8
|
||||||
|
nvcc_threads: 4
|
||||||
|
cargo_features: "cuda cudnn flash-attn"
|
||||||
|
runs-on: ${{ matrix.runner }}
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ inputs.ref }}
|
||||||
|
|
||||||
|
- name: Build neuron with CUDA (${{ matrix.flavour }})
|
||||||
|
run: |
|
||||||
|
set -eux
|
||||||
|
export PATH="${{ matrix.cuda_home }}/bin:${PATH}"
|
||||||
|
export LD_LIBRARY_PATH="${{ matrix.cuda_home }}/targets/x86_64-linux/lib:${{ matrix.cuda_home }}/lib64:${LD_LIBRARY_PATH:-}"
|
||||||
|
export LIBRARY_PATH="${{ matrix.cuda_home }}/targets/x86_64-linux/lib:${{ matrix.cuda_home }}/lib64:${LIBRARY_PATH:-}"
|
||||||
|
cargo build --release -p neuron --features "${{ matrix.cargo_features }}"
|
||||||
|
env:
|
||||||
|
CUDA_COMPUTE_CAP: ${{ matrix.compute_cap }}
|
||||||
|
CARGO_BUILD_JOBS: ${{ matrix.build_jobs }}
|
||||||
|
NVCC_THREADS: ${{ matrix.nvcc_threads }}
|
||||||
|
|
||||||
|
- name: Stage binary
|
||||||
|
run: |
|
||||||
|
mkdir --parents artifacts
|
||||||
|
cp target/release/neuron artifacts/neuron-${{ matrix.flavour }}
|
||||||
|
file "artifacts/neuron-${{ matrix.flavour }}"
|
||||||
|
|
||||||
|
- uses: actions/upload-artifact@v3
|
||||||
|
with:
|
||||||
|
name: neuron-${{ matrix.flavour }}-fc43
|
||||||
|
path: artifacts/neuron-${{ matrix.flavour }}
|
||||||
|
retention-days: 1
|
||||||
|
|
||||||
|
package-cortex:
|
||||||
|
name: Package cortex RPM
|
||||||
|
needs: [prepare, build-cortex]
|
||||||
|
runs-on: rpm
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ inputs.ref }}
|
||||||
|
|
||||||
|
- uses: actions/download-artifact@v3
|
||||||
|
with:
|
||||||
|
name: cortex-fc43
|
||||||
|
path: artifacts/
|
||||||
|
|
||||||
|
- name: Build RPM
|
||||||
|
run: |
|
||||||
|
set -eux
|
||||||
|
rm -f ~/.rpmmacros
|
||||||
|
rpmdev-setuptree
|
||||||
|
cp artifacts/cortex ~/rpmbuild/SOURCES/
|
||||||
|
cp data/cortex.service ~/rpmbuild/SOURCES/
|
||||||
|
cp data/cortex-sysusers.conf ~/rpmbuild/SOURCES/
|
||||||
|
cp data/cortex-firewalld.xml ~/rpmbuild/SOURCES/
|
||||||
|
cp cortex.example.toml ~/rpmbuild/SOURCES/
|
||||||
|
cp models.example.toml ~/rpmbuild/SOURCES/
|
||||||
|
cp LICENSE ~/rpmbuild/SOURCES/
|
||||||
|
rpmbuild -bb rpm/cortex-prerelease.spec \
|
||||||
|
--define "cortex_version ${{ needs.prepare.outputs.version }}" \
|
||||||
|
--define "cortex_prerelease ${{ needs.prepare.outputs.release }}" \
|
||||||
|
--undefine dist \
|
||||||
|
--define "dist .fc43"
|
||||||
|
|
||||||
|
- uses: actions/upload-artifact@v3
|
||||||
|
with:
|
||||||
|
name: rpm-cortex-fc43
|
||||||
|
path: ~/rpmbuild/RPMS/x86_64/*.rpm
|
||||||
|
retention-days: 7
|
||||||
|
|
||||||
|
package-neuron:
|
||||||
|
name: Package helexa-neuron-${{ matrix.flavour }} RPM
|
||||||
|
needs: [prepare, build-neuron]
|
||||||
|
runs-on: rpm
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- flavour: ampere
|
||||||
|
- flavour: ada
|
||||||
|
- flavour: blackwell
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ inputs.ref }}
|
||||||
|
|
||||||
|
- uses: actions/download-artifact@v3
|
||||||
|
with:
|
||||||
|
name: neuron-${{ matrix.flavour }}-fc43
|
||||||
|
path: artifacts/
|
||||||
|
|
||||||
|
- name: Build RPM
|
||||||
|
run: |
|
||||||
|
set -eux
|
||||||
|
rm -f ~/.rpmmacros
|
||||||
|
rpmdev-setuptree
|
||||||
|
cp artifacts/neuron-${{ matrix.flavour }} ~/rpmbuild/SOURCES/
|
||||||
|
cp data/neuron.service ~/rpmbuild/SOURCES/
|
||||||
|
cp data/neuron-sysusers.conf ~/rpmbuild/SOURCES/
|
||||||
|
cp data/neuron-firewalld.xml ~/rpmbuild/SOURCES/
|
||||||
|
cp neuron.example.toml ~/rpmbuild/SOURCES/
|
||||||
|
cp LICENSE ~/rpmbuild/SOURCES/
|
||||||
|
rpmbuild -bb rpm/helexa-neuron-prerelease.spec \
|
||||||
|
--define "neuron_version ${{ needs.prepare.outputs.version }}" \
|
||||||
|
--define "neuron_flavour ${{ matrix.flavour }}" \
|
||||||
|
--define "neuron_prerelease ${{ needs.prepare.outputs.release }}" \
|
||||||
|
--undefine dist \
|
||||||
|
--define "dist .fc43"
|
||||||
|
|
||||||
|
- uses: actions/upload-artifact@v3
|
||||||
|
with:
|
||||||
|
name: rpm-neuron-${{ matrix.flavour }}-fc43
|
||||||
|
path: ~/rpmbuild/RPMS/x86_64/*.rpm
|
||||||
|
retention-days: 7
|
||||||
|
|
||||||
|
publish:
|
||||||
|
name: Publish to rpm.lair.cafe (unstable)
|
||||||
|
needs: [package-cortex, package-neuron]
|
||||||
|
runs-on: rpm
|
||||||
|
concurrency:
|
||||||
|
group: rpm-publish
|
||||||
|
cancel-in-progress: false
|
||||||
|
env:
|
||||||
|
RPM_REPO_HOST: oolon.kosherinata.internal
|
||||||
|
FEDORA_VERSION: "43"
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ inputs.ref }}
|
||||||
|
|
||||||
|
- name: Download all built RPMs
|
||||||
|
uses: actions/download-artifact@v3
|
||||||
|
with:
|
||||||
|
path: rpms/
|
||||||
|
pattern: rpm-*-fc43
|
||||||
|
|
||||||
|
- name: Flatten RPM artifacts
|
||||||
|
run: |
|
||||||
|
set -eux
|
||||||
|
find rpms/ -name '*.rpm' -exec mv --target-directory=rpms/ {} +
|
||||||
|
find rpms/ -mindepth 1 -type d -empty -delete
|
||||||
|
ls -la rpms/
|
||||||
|
|
||||||
|
- name: Check for sequoia-sq
|
||||||
|
run: |
|
||||||
|
if ! command -v sq &> /dev/null; then
|
||||||
|
echo "ERROR: sequoia-sq is not installed. Install with: sudo dnf install sequoia-sq"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Import signing key
|
||||||
|
env:
|
||||||
|
# Pass secrets via env so values stay out of the rendered shell
|
||||||
|
# script (which Gitea includes in step logs). Template
|
||||||
|
# expansion of ${{ secrets.X }} inside `run:` writes the literal
|
||||||
|
# value into the script and depends on Gitea's log masker to
|
||||||
|
# scrub it — fragile for multi-line keys.
|
||||||
|
RPM_SIGNING_KEY: ${{ secrets.RPM_SIGNING_KEY }}
|
||||||
|
RPM_SIGNING_KEY_ID: ${{ secrets.RPM_SIGNING_KEY_ID }}
|
||||||
|
run: |
|
||||||
|
echo "$RPM_SIGNING_KEY" | gpg --batch --import
|
||||||
|
fpr=$(gpg --batch --with-colons --list-keys "$RPM_SIGNING_KEY_ID" | awk -F: '/^fpr:/ { print $10; exit }')
|
||||||
|
echo "${fpr}:6:" | gpg --batch --import-ownertrust
|
||||||
|
sed "s/@GPG_NAME@/$RPM_SIGNING_KEY_ID/" rpm/rpmmacros > ~/.rpmmacros
|
||||||
|
|
||||||
|
- name: Sign RPMs
|
||||||
|
run: |
|
||||||
|
set -eux
|
||||||
|
for rpm in rpms/*.rpm; do
|
||||||
|
echo "signing ${rpm}..."
|
||||||
|
rpm --addsign "${rpm}"
|
||||||
|
done
|
||||||
|
|
||||||
|
- name: Set up SSH for rsync
|
||||||
|
run: |
|
||||||
|
install --directory --mode 700 ~/.ssh
|
||||||
|
echo "${RSYNC_SSH_KEY}" | install --mode 600 /dev/stdin ~/.ssh/id_ed25519
|
||||||
|
env:
|
||||||
|
RSYNC_SSH_KEY: ${{ secrets.RSYNC_SSH_KEY }}
|
||||||
|
|
||||||
|
- name: Test SSH connectivity
|
||||||
|
run: |
|
||||||
|
ssh -o StrictHostKeyChecking=accept-new "gitea_ci@${RPM_REPO_HOST}" exit
|
||||||
|
|
||||||
|
- name: Ensure unstable repo directory exists
|
||||||
|
run: |
|
||||||
|
ssh "gitea_ci@${RPM_REPO_HOST}" \
|
||||||
|
"mkdir --parents /var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable"
|
||||||
|
|
||||||
|
- name: Sync RPMs to unstable repo
|
||||||
|
run: |
|
||||||
|
rsync \
|
||||||
|
--archive \
|
||||||
|
--verbose \
|
||||||
|
--chmod D755,F644 \
|
||||||
|
rpms/*.rpm \
|
||||||
|
"gitea_ci@${RPM_REPO_HOST}:/var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable/"
|
||||||
|
|
||||||
|
- name: Update unstable repo metadata
|
||||||
|
run: |
|
||||||
|
ssh "gitea_ci@${RPM_REPO_HOST}" \
|
||||||
|
"cd /var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable && createrepo_c --update ."
|
||||||
|
|
||||||
|
- name: Generate packages.json manifest
|
||||||
|
run: |
|
||||||
|
scp script/generate-packages-json.py "gitea_ci@${RPM_REPO_HOST}:/tmp/"
|
||||||
|
ssh "gitea_ci@${RPM_REPO_HOST}" \
|
||||||
|
"python3 /tmp/generate-packages-json.py \
|
||||||
|
--repodata-dir /var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable/repodata \
|
||||||
|
--output /var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable/packages.json \
|
||||||
|
--base-url https://rpm.lair.cafe/fedora/${FEDORA_VERSION}/x86_64/unstable"
|
||||||
@@ -2,51 +2,181 @@ name: CI
|
|||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches: ['**']
|
branches: ["**"]
|
||||||
tags: ['v*']
|
tags: ["v*"]
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [main]
|
branches: [main]
|
||||||
|
|
||||||
|
# Share a concurrency group with build-prerelease.yml so the two
|
||||||
|
# workflows don't race on the same `rust` runner workspace (act's
|
||||||
|
# /root/.cache/act/<hash>/hostexecutor/ is shared across concurrent
|
||||||
|
# jobs and one job's checkout step nukes another's in-flight build
|
||||||
|
# files). cancel-in-progress=false → they queue; same-ref pushes
|
||||||
|
# coalesce per workflow via cancel-in-progress on each.
|
||||||
|
concurrency:
|
||||||
|
group: cortex-runner-pool-${{ github.ref }}
|
||||||
|
cancel-in-progress: false
|
||||||
|
|
||||||
|
env:
|
||||||
|
CARGO_INCREMENTAL: "0"
|
||||||
|
RUSTC_WRAPPER: sccache
|
||||||
|
SCCACHE_BUCKET: sccache
|
||||||
|
SCCACHE_ENDPOINT: http://caveman.kosherinata.internal:9000
|
||||||
|
SCCACHE_REGION: auto
|
||||||
|
SCCACHE_S3_USE_SSL: "false"
|
||||||
|
AWS_ACCESS_KEY_ID: ${{ secrets.SCCACHE_S3_ACCESS_KEY }}
|
||||||
|
AWS_SECRET_ACCESS_KEY: ${{ secrets.SCCACHE_S3_SECRET_KEY }}
|
||||||
|
# fmt, clippy, and test all run in parallel on the same `rust` runner
|
||||||
|
# and would otherwise share /root/.cache/act/<hash>/hostexecutor/target/,
|
||||||
|
# racing each other's cargo temp files (.tmpXXXXXX) and failing builds
|
||||||
|
# mid-compile. Give each job its own target directory so the invocations
|
||||||
|
# don't collide. sccache still backs the actual rustc cache, so the
|
||||||
|
# rebuild penalty is small.
|
||||||
|
CARGO_TARGET_DIR: target-${{ github.job }}
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
check:
|
fmt:
|
||||||
name: Format, lint, build, test
|
name: Format
|
||||||
runs-on: fedora
|
runs-on: rust
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
- run: cargo fmt --check --all
|
||||||
|
|
||||||
- name: Check formatting
|
clippy:
|
||||||
run: cargo fmt --check --all
|
name: Clippy
|
||||||
|
runs-on: rust
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
# sccache occasionally fails with spurious race-condition errors;
|
||||||
|
# retrying the same invocation succeeds without code changes.
|
||||||
|
# Allow up to 3 attempts before declaring real failure.
|
||||||
|
- name: Clippy (with retry)
|
||||||
|
run: |
|
||||||
|
for attempt in 1 2 3; do
|
||||||
|
echo "::group::clippy attempt ${attempt}"
|
||||||
|
if cargo clippy --workspace -- -D warnings; then
|
||||||
|
echo "::endgroup::"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
echo "::endgroup::"
|
||||||
|
echo "clippy failed on attempt ${attempt}"
|
||||||
|
if [ "${attempt}" -lt 3 ]; then
|
||||||
|
sleep 5
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
echo "clippy failed after 3 attempts"
|
||||||
|
exit 1
|
||||||
|
- run: sccache --show-stats
|
||||||
|
|
||||||
- name: Clippy
|
test:
|
||||||
run: cargo clippy --workspace -- -D warnings
|
name: Test
|
||||||
|
runs-on: rust
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
# See the clippy job for why this is retried.
|
||||||
|
- name: Test (with retry)
|
||||||
|
run: |
|
||||||
|
for attempt in 1 2 3; do
|
||||||
|
echo "::group::test attempt ${attempt}"
|
||||||
|
if cargo test --workspace; then
|
||||||
|
echo "::endgroup::"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
echo "::endgroup::"
|
||||||
|
echo "test failed on attempt ${attempt}"
|
||||||
|
if [ "${attempt}" -lt 3 ]; then
|
||||||
|
sleep 5
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
echo "test failed after 3 attempts"
|
||||||
|
exit 1
|
||||||
|
- run: sccache --show-stats
|
||||||
|
|
||||||
- name: Build
|
# Type-check the CUDA-only code path. Borrow-check-only — we
|
||||||
run: cargo build --workspace
|
# never run the tests here (the runner has no GPU). This catches
|
||||||
|
# the category of bug where a refactor compiles fine under the
|
||||||
|
# default feature set (which is what the `clippy` and `test` jobs
|
||||||
|
# exercise) but fails inside a `#[cfg(feature = "cuda")]` block.
|
||||||
|
# `runs-on: cuda-13.0` selects the runner that ships nvcc /
|
||||||
|
# cudarc's build prerequisites. The generic `rust` and `rpm`
|
||||||
|
# runners don't have them (the previous label `rpm` was tried
|
||||||
|
# first and tripped cudarc's `nvcc --version` build script —
|
||||||
|
# see commit history).
|
||||||
|
cuda-check:
|
||||||
|
name: CUDA type-check
|
||||||
|
runs-on: cuda-13.0
|
||||||
|
# The workflow-level env sets `RUSTC_WRAPPER: sccache` for the
|
||||||
|
# `rust` runner (where fmt/clippy/test live and sccache is
|
||||||
|
# installed). The `cuda-13.0` runner doesn't have sccache on
|
||||||
|
# PATH, so inheriting the wrapper makes cargo bail with
|
||||||
|
# `could not execute process `sccache rustc -vV` (never executed)`
|
||||||
|
# before borrow-check even starts. Clear it locally. Also clear
|
||||||
|
# SCCACHE_* so cargo doesn't try to contact the cache (the
|
||||||
|
# remote auth headers come from secrets that aren't present on
|
||||||
|
# this runner either). Lose the cache, keep the gate.
|
||||||
|
env:
|
||||||
|
RUSTC_WRAPPER: ""
|
||||||
|
SCCACHE_BUCKET: ""
|
||||||
|
SCCACHE_ENDPOINT: ""
|
||||||
|
SCCACHE_REGION: ""
|
||||||
|
SCCACHE_S3_USE_SSL: ""
|
||||||
|
AWS_ACCESS_KEY_ID: ""
|
||||||
|
AWS_SECRET_ACCESS_KEY: ""
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- name: cargo check --features cuda (with retry)
|
||||||
|
run: |
|
||||||
|
# act launches the step shell without /etc/profile, so the
|
||||||
|
# gitea_runner user's inherited PATH lacks /usr/local/cuda-13.0/bin.
|
||||||
|
# cudarc's build.rs:157 shells out to `nvcc --version` (because
|
||||||
|
# the neuron crate enables cuda-version-from-build-system) and
|
||||||
|
# panics with ENOENT if nvcc isn't resolvable. build-prerelease.yml
|
||||||
|
# does the same export — keep them in sync.
|
||||||
|
export PATH="/usr/local/cuda-13.0/bin:${PATH}"
|
||||||
|
export LD_LIBRARY_PATH="/usr/local/cuda-13.0/targets/x86_64-linux/lib:/usr/local/cuda-13.0/lib64:${LD_LIBRARY_PATH:-}"
|
||||||
|
export LIBRARY_PATH="/usr/local/cuda-13.0/targets/x86_64-linux/lib:/usr/local/cuda-13.0/lib64:${LIBRARY_PATH:-}"
|
||||||
|
for attempt in 1 2 3; do
|
||||||
|
echo "::group::cuda-check attempt ${attempt}"
|
||||||
|
if cargo check -p neuron --features cuda --all-targets; then
|
||||||
|
echo "::endgroup::"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
echo "::endgroup::"
|
||||||
|
echo "cuda-check failed on attempt ${attempt}"
|
||||||
|
if [ "${attempt}" -lt 3 ]; then
|
||||||
|
sleep 5
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
echo "cuda-check failed after 3 attempts"
|
||||||
|
exit 1
|
||||||
|
|
||||||
- name: Test
|
srpm-cortex:
|
||||||
run: cargo test --workspace
|
name: Build cortex SRPM
|
||||||
|
runs-on: rpm
|
||||||
rpm:
|
needs: [fmt, clippy, test, cuda-check]
|
||||||
name: Build SRPM
|
|
||||||
runs-on: fedora
|
|
||||||
needs: check
|
|
||||||
if: startsWith(github.ref, 'refs/tags/v')
|
if: startsWith(github.ref, 'refs/tags/v')
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
- name: Determine version
|
- name: Determine version
|
||||||
id: version
|
id: version
|
||||||
run: |
|
run: |
|
||||||
VERSION="${GITHUB_REF#refs/tags/v}"
|
VERSION="${GITHUB_REF#refs/tags/v}"
|
||||||
echo "VERSION=${VERSION}" >> "$GITHUB_OUTPUT"
|
echo "VERSION=${VERSION}" >> "$GITHUB_OUTPUT"
|
||||||
echo "Building version: ${VERSION}"
|
|
||||||
|
|
||||||
- name: Stamp version into spec
|
- name: Stamp version
|
||||||
run: |
|
run: |
|
||||||
VERSION="${{ steps.version.outputs.VERSION }}"
|
VERSION="${{ steps.version.outputs.VERSION }}"
|
||||||
sed -i '/\[workspace\.package\]/,/\[/{ s/^version = ".*"/version = "'"${VERSION}"'"/ }' Cargo.toml
|
sed -i '/\[workspace\.package\]/,/\[/{ s/^version = ".*"/version = "'"${VERSION}"'"/ }' Cargo.toml
|
||||||
sed -i "s/^Version:.*/Version: ${VERSION}/" cortex.spec
|
sed -i "s/^Version:.*/Version: ${VERSION}/" cortex.spec
|
||||||
echo "Stamped version ${VERSION}"
|
|
||||||
|
- name: Generate changelog entry
|
||||||
|
uses: https://git.lair.cafe/actions/rpm-changelog@v1
|
||||||
|
with:
|
||||||
|
spec: cortex.spec
|
||||||
|
version: ${{ steps.version.outputs.VERSION }}
|
||||||
|
|
||||||
- name: Generate source tarball
|
- name: Generate source tarball
|
||||||
run: |
|
run: |
|
||||||
@@ -77,24 +207,148 @@ jobs:
|
|||||||
- name: Upload SRPM artifact
|
- name: Upload SRPM artifact
|
||||||
uses: actions/upload-artifact@v3
|
uses: actions/upload-artifact@v3
|
||||||
with:
|
with:
|
||||||
name: srpm
|
name: srpm-cortex
|
||||||
path: '*.src.rpm'
|
path: "*.src.rpm"
|
||||||
|
|
||||||
copr:
|
srpm-neuron:
|
||||||
name: Publish to COPR
|
name: Build neuron SRPM
|
||||||
runs-on: fedora
|
runs-on: rpm
|
||||||
needs: rpm
|
needs: [fmt, clippy, test, cuda-check]
|
||||||
if: startsWith(github.ref, 'refs/tags/v')
|
if: startsWith(github.ref, 'refs/tags/v')
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Determine version
|
||||||
|
id: version
|
||||||
|
run: |
|
||||||
|
VERSION="${GITHUB_REF#refs/tags/v}"
|
||||||
|
echo "VERSION=${VERSION}" >> "$GITHUB_OUTPUT"
|
||||||
|
|
||||||
|
- name: Stamp version
|
||||||
|
run: |
|
||||||
|
VERSION="${{ steps.version.outputs.VERSION }}"
|
||||||
|
sed -i '/\[workspace\.package\]/,/\[/{ s/^version = ".*"/version = "'"${VERSION}"'"/ }' Cargo.toml
|
||||||
|
sed -i "s/^Version:.*/Version: ${VERSION}/" helexa-neuron.spec
|
||||||
|
|
||||||
|
- name: Generate changelog entry
|
||||||
|
uses: https://git.lair.cafe/actions/rpm-changelog@v1
|
||||||
|
with:
|
||||||
|
spec: helexa-neuron.spec
|
||||||
|
version: ${{ steps.version.outputs.VERSION }}
|
||||||
|
|
||||||
|
- name: Generate source tarball
|
||||||
|
run: |
|
||||||
|
set -ex
|
||||||
|
VERSION="${{ steps.version.outputs.VERSION }}"
|
||||||
|
tar czf /tmp/helexa-neuron-${VERSION}.tar.gz \
|
||||||
|
--transform "s,^\.,helexa-neuron-${VERSION}," \
|
||||||
|
--exclude='./target' \
|
||||||
|
--exclude='./.git' \
|
||||||
|
--exclude='*.tar.gz' \
|
||||||
|
--exclude='*.src.rpm' \
|
||||||
|
.
|
||||||
|
mv /tmp/helexa-neuron-${VERSION}.tar.gz .
|
||||||
|
|
||||||
|
- name: Vendor Rust dependencies
|
||||||
|
run: |
|
||||||
|
VERSION="${{ steps.version.outputs.VERSION }}"
|
||||||
|
cargo vendor vendor/
|
||||||
|
tar czf helexa-neuron-${VERSION}-vendor.tar.gz vendor/
|
||||||
|
rm -rf vendor/
|
||||||
|
|
||||||
|
- name: Build SRPM
|
||||||
|
run: |
|
||||||
|
rpmbuild -bs helexa-neuron.spec \
|
||||||
|
--define "_sourcedir $(pwd)" \
|
||||||
|
--define "_srcrpmdir $(pwd)"
|
||||||
|
|
||||||
|
- name: Upload SRPM artifact
|
||||||
|
uses: actions/upload-artifact@v3
|
||||||
|
with:
|
||||||
|
name: srpm-neuron
|
||||||
|
path: "*.src.rpm"
|
||||||
|
|
||||||
|
copr-cortex:
|
||||||
|
name: Publish cortex to COPR
|
||||||
|
runs-on: fedora-43
|
||||||
|
needs: srpm-cortex
|
||||||
steps:
|
steps:
|
||||||
- name: Download SRPM
|
- name: Download SRPM
|
||||||
uses: actions/download-artifact@v3
|
uses: actions/download-artifact@v3
|
||||||
with:
|
with:
|
||||||
name: srpm
|
name: srpm-cortex
|
||||||
|
|
||||||
- name: Configure copr-cli
|
- name: Publish to COPR
|
||||||
|
uses: https://git.lair.cafe/actions/copr-publish@v1
|
||||||
|
with:
|
||||||
|
project: helexa/helexa
|
||||||
|
srpm: "*.src.rpm"
|
||||||
|
copr-config: ${{ secrets.COPR_CONFIG }}
|
||||||
|
|
||||||
|
copr-neuron:
|
||||||
|
name: Publish neuron to COPR
|
||||||
|
runs-on: fedora-43
|
||||||
|
needs: srpm-neuron
|
||||||
|
steps:
|
||||||
|
- name: Download SRPM
|
||||||
|
uses: actions/download-artifact@v3
|
||||||
|
with:
|
||||||
|
name: srpm-neuron
|
||||||
|
|
||||||
|
- name: Publish to COPR
|
||||||
|
uses: https://git.lair.cafe/actions/copr-publish@v1
|
||||||
|
with:
|
||||||
|
project: helexa/helexa
|
||||||
|
srpm: "*.src.rpm"
|
||||||
|
copr-config: ${{ secrets.COPR_CONFIG }}
|
||||||
|
|
||||||
|
bump-version:
|
||||||
|
name: Bump version in source
|
||||||
|
runs-on: rust
|
||||||
|
needs: [copr-cortex, copr-neuron]
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Determine version
|
||||||
|
id: version
|
||||||
|
run: echo "VERSION=${GITHUB_REF#refs/tags/v}" >> "$GITHUB_OUTPUT"
|
||||||
|
|
||||||
|
- name: Stamp version
|
||||||
run: |
|
run: |
|
||||||
mkdir -p ~/.config
|
VERSION="${{ steps.version.outputs.VERSION }}"
|
||||||
echo "${{ secrets.COPR_CONFIG }}" > ~/.config/copr
|
sed -i '/\[workspace\.package\]/,/\[/{ s/^version = ".*"/version = "'"${VERSION}"'"/ }' Cargo.toml
|
||||||
|
sed -i "s/^Version:.*/Version: ${VERSION}/" cortex.spec
|
||||||
|
sed -i "s/^Version:.*/Version: ${VERSION}/" helexa-neuron.spec
|
||||||
|
cargo check --workspace 2>/dev/null || true
|
||||||
|
|
||||||
- name: Submit build to COPR
|
- name: Generate cortex changelog entry
|
||||||
run: copr-cli build cortex *.src.rpm
|
uses: https://git.lair.cafe/actions/rpm-changelog@v1
|
||||||
|
with:
|
||||||
|
spec: cortex.spec
|
||||||
|
version: ${{ steps.version.outputs.VERSION }}
|
||||||
|
|
||||||
|
- name: Generate helexa-neuron changelog entry
|
||||||
|
uses: https://git.lair.cafe/actions/rpm-changelog@v1
|
||||||
|
with:
|
||||||
|
spec: helexa-neuron.spec
|
||||||
|
version: ${{ steps.version.outputs.VERSION }}
|
||||||
|
|
||||||
|
- name: Commit and push
|
||||||
|
env:
|
||||||
|
GITEA_TOKEN: ${{ secrets.GITEA_TOKEN }}
|
||||||
|
run: |
|
||||||
|
VERSION="${{ steps.version.outputs.VERSION }}"
|
||||||
|
git config user.name "Gitea Actions"
|
||||||
|
git config user.email "actions@git.lair.cafe"
|
||||||
|
git add Cargo.toml Cargo.lock cortex.spec helexa-neuron.spec
|
||||||
|
if git diff --cached --quiet; then
|
||||||
|
echo "Nothing to commit for ${VERSION}"
|
||||||
|
else
|
||||||
|
git commit -m "chore: bump version to ${VERSION}"
|
||||||
|
git remote set-url origin "https://gitea-actions:${GITEA_TOKEN}@git.lair.cafe/helexa/cortex.git"
|
||||||
|
git push origin HEAD:main
|
||||||
|
fi
|
||||||
|
|||||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -4,3 +4,6 @@
|
|||||||
.idea/
|
.idea/
|
||||||
.vscode/
|
.vscode/
|
||||||
cortex.toml
|
cortex.toml
|
||||||
|
models.toml
|
||||||
|
doc/plan/*
|
||||||
|
/target-cuda/
|
||||||
|
|||||||
485
CLAUDE.md
485
CLAUDE.md
@@ -84,6 +84,63 @@ Per-request: model, node, prompt_tokens, completion_tokens, total_tokens,
|
|||||||
tok_per_sec, time_to_first_token_ms, total_latency_ms.
|
tok_per_sec, time_to_first_token_ms, total_latency_ms.
|
||||||
Exposed as Prometheus histograms/counters on a separate port.
|
Exposed as Prometheus histograms/counters on a separate port.
|
||||||
|
|
||||||
|
### Per-device worker thread (neuron)
|
||||||
|
The neuron daemon dedicates one OS thread per CUDA device it loads
|
||||||
|
onto. That thread binds the device's `CudaContext` once at startup and
|
||||||
|
owns it for the daemon's lifetime; every model load, forward step,
|
||||||
|
KV-cache reset, VRAM query, NCCL init/sanity, NCCL all_reduce, and
|
||||||
|
model drop on that device routes through this thread via a
|
||||||
|
`std::sync::mpsc` job channel. Replies cross back via
|
||||||
|
`tokio::sync::oneshot`.
|
||||||
|
|
||||||
|
Three properties this gives us, in order of weight:
|
||||||
|
|
||||||
|
1. **Context locality.** cudarc binds the CUDA context per OS thread
|
||||||
|
via `cuCtxSetCurrent`. Before this refactor, ad-hoc
|
||||||
|
`tokio::task::spawn_blocking` calls bound the context onto a
|
||||||
|
different thread per request — and `device_vram_mb()` from an
|
||||||
|
async task bound it onto whichever tokio worker happened to be
|
||||||
|
running. Pinning the context to one named thread ends that.
|
||||||
|
2. **Drop safety.** Every `CudaSlice` in a `Tensor`, every
|
||||||
|
`cudarc::nccl::Comm`, and the `CudaContext` itself call `cuMemFree` /
|
||||||
|
`ncclCommDestroy` / `cuCtxDestroy` during `Drop` — and require the
|
||||||
|
right context current. With the worker owning the model slab,
|
||||||
|
`Drop` always runs on the right thread. The cudarc Drop constraint
|
||||||
|
is structurally enforced.
|
||||||
|
3. **Poisoning blast radius.** When a CUDA driver error makes the
|
||||||
|
context unrecoverable, the poison flag lives on the
|
||||||
|
`DeviceWorkerHandle` itself. Subsequent `submit()` calls fast-reject
|
||||||
|
at the channel boundary with a clear "device worker is poisoned"
|
||||||
|
error before any further CUDA work is attempted. The thread doesn't
|
||||||
|
exit (dropping the slab would re-touch the broken context) — it
|
||||||
|
enters a drain-only mode and replies error to everything until the
|
||||||
|
daemon restarts.
|
||||||
|
|
||||||
|
Tensors never escape the worker thread alive. Inference replies carry
|
||||||
|
`Vec<f32>` CPU-side logits; the async caller wraps them in a CPU
|
||||||
|
candle tensor and runs `apply_repeat_penalty` + `LogitsProcessor::sample`
|
||||||
|
without ever rebinding the device context. Sampled tokens come back as
|
||||||
|
`u32`; VRAM queries as `(u64, u64)`. The opaque `ArchHandle(u64)` and
|
||||||
|
`TpHandle(u64)` are the only "references" callers hold to loaded
|
||||||
|
models — they're indices into the worker's state slab, not pointers.
|
||||||
|
|
||||||
|
The TP worker subprocesses in `harness/tp/worker.rs` are the same
|
||||||
|
pattern out-of-process — a dedicated context-owning process per
|
||||||
|
non-zero NCCL rank. The in-process worker in `harness/device_worker/`
|
||||||
|
brings the discipline to rank 0.
|
||||||
|
|
||||||
|
CPU loads (`Device::Cpu` fallback when CUDA is unavailable) keep the
|
||||||
|
legacy `tokio::task::spawn_blocking + Arc<Mutex<ModelArch>>` path —
|
||||||
|
there's no context to own and the channel hop would only add latency.
|
||||||
|
Four `spawn_blocking` references in `harness/candle.rs` are deliberate
|
||||||
|
CPU fallback.
|
||||||
|
|
||||||
|
Canonical narrative lives in
|
||||||
|
`crates/neuron/src/harness/device_worker/mod.rs`'s module
|
||||||
|
doc-comment; touch points (the `Job` enum, the dispatch handlers, the
|
||||||
|
`DeviceWorkerState` struct) are in the sibling `jobs.rs` and
|
||||||
|
`dispatch.rs`.
|
||||||
|
|
||||||
## Tech stack
|
## Tech stack
|
||||||
|
|
||||||
- **Rust 2024 edition** — workspace with 4 crates
|
- **Rust 2024 edition** — workspace with 4 crates
|
||||||
@@ -125,7 +182,8 @@ automatically. Clippy warnings must be resolved, not suppressed with
|
|||||||
- One or more GPU nodes running mistral.rs on port 8080
|
- One or more GPU nodes running mistral.rs on port 8080
|
||||||
- Optionally a metrics-only node (no GPU) for Prometheus/Grafana
|
- Optionally a metrics-only node (no GPU) for Prometheus/Grafana
|
||||||
- Each node runs `mistralrs serve` on port 8080
|
- Each node runs `mistralrs serve` on port 8080
|
||||||
- Gateway listens on port 8000 (API) and 9100 (metrics)
|
- Gateway listens on port 31313 (API) and 31314 (metrics)
|
||||||
|
- neuron listens on port 13131 on each GPU host
|
||||||
- TLS terminated at gateway or via nginx; internal traffic is plaintext over WireGuard
|
- TLS terminated at gateway or via nginx; internal traffic is plaintext over WireGuard
|
||||||
|
|
||||||
## Conventions
|
## Conventions
|
||||||
@@ -277,15 +335,422 @@ histograms appear after a proxied request.
|
|||||||
Token-level metrics (tok/s, TTFT) deferred — requires parsing the
|
Token-level metrics (tok/s, TTFT) deferred — requires parsing the
|
||||||
response body or final SSE chunk, which is Phase 6b work.
|
response body or final SSE chunk, which is Phase 6b work.
|
||||||
|
|
||||||
### Phase 7 (lower priority): Agent sidecar
|
## 2026-04-15 addendum
|
||||||
|
|
||||||
**Goal:** Per-node binary that handles VRAM defrag restarts and
|
**Phases 1–6 complete.** The gateway proxies requests (streaming and
|
||||||
reports real VRAM usage via `nvidia-smi`.
|
non-streaming), routes by model name to the correct node, polls node
|
||||||
|
`/v1/models` for live state, evicts LRU models with pinning, translates
|
||||||
|
Anthropic ↔ OpenAI envelopes, and emits Prometheus metrics. CI is green.
|
||||||
|
|
||||||
This is deferred. The gateway handles the critical path (model
|
**Phase 7 onward** introduces `neuron` — the per-node daemon that replaces
|
||||||
lifecycle) entirely via the mistral.rs HTTP API. The agent adds
|
the placeholder `cortex-agent` crate — along with hardware discovery,
|
||||||
operational polish: automatic process restart when `lifecycle_cycles`
|
a harness abstraction (so cortex is not permanently wedded to mistral.rs),
|
||||||
exceeds threshold, real VRAM reporting (vs. estimates), and
|
and a model catalogue for placement decisions.
|
||||||
potentially GPU temperature/power monitoring.
|
|
||||||
|
|
||||||
**Defer until:** Phases 1-6 are merged and running in production.
|
|
||||||
|
### Architecture: cortex + neuron
|
||||||
|
|
||||||
|
cortex is the **control plane**. It exposes the unified API, routes
|
||||||
|
requests, manages model lifecycle across the fleet, and collects metrics.
|
||||||
|
|
||||||
|
neuron is the **node plane**. One instance runs on every GPU host. It:
|
||||||
|
- **Discovers** local hardware (GPU count, types, VRAM, CUDA compute
|
||||||
|
capability, driver version) and reports it to cortex.
|
||||||
|
- **Manages harnesses** — inference engines like mistral.rs, llama.cpp,
|
||||||
|
or ComfyUI. Each harness is a trait implementation. neuron starts,
|
||||||
|
stops, health-checks, and proxies to whichever harness is serving a
|
||||||
|
given model.
|
||||||
|
- **Manages model lifecycle** — load, unload, status — abstracting the
|
||||||
|
differences between harnesses (mistral.rs has HTTP lifecycle endpoints;
|
||||||
|
llama.cpp may need process management).
|
||||||
|
- **Reports runtime state** — per-device VRAM usage, GPU utilisation,
|
||||||
|
temperature, loaded models with actual VRAM consumption.
|
||||||
|
|
||||||
|
cortex never shells out to `nvidia-smi`, never touches systemd units,
|
||||||
|
and never talks directly to a harness. It talks only to neurons.
|
||||||
|
|
||||||
|
```
|
||||||
|
┌─────────────────────┐
|
||||||
|
│ cortex │
|
||||||
|
│ (cortex-gateway) │
|
||||||
|
│ Router · Evictor │
|
||||||
|
│ Metrics · Translate│
|
||||||
|
└──┬──────┬────────┬──┘
|
||||||
|
│ │ │
|
||||||
|
┌──────────▼┐ ┌──▼─────┐ ┌▼──────────┐
|
||||||
|
│ neuron │ │ neuron │ │ neuron │
|
||||||
|
│ beast │ │ benjy │ │ quadbrat │
|
||||||
|
│ │ │ │ │ │
|
||||||
|
│ harness: │ │harness:│ │ harness: │
|
||||||
|
│ mistralrs │ │mistral │ │ mistralrs │
|
||||||
|
│ (+ comfy) │ │rs │ │ │
|
||||||
|
└───────────┘ └────────┘ └───────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## The Harness trait
|
||||||
|
|
||||||
|
Defined in `cortex-core` so both cortex and neuron share the type
|
||||||
|
definitions. neuron provides the runtime implementations.
|
||||||
|
|
||||||
|
```rust
|
||||||
|
/// What an inference harness must do, from neuron's perspective.
|
||||||
|
#[async_trait]
|
||||||
|
pub trait Harness: Send + Sync {
|
||||||
|
/// Human-readable name (e.g. "mistralrs", "llamacpp", "comfyui").
|
||||||
|
fn name(&self) -> &str;
|
||||||
|
|
||||||
|
/// Start the harness process if it is not already running.
|
||||||
|
async fn start(&self, config: &HarnessConfig) -> Result<()>;
|
||||||
|
|
||||||
|
/// Stop the harness process gracefully.
|
||||||
|
async fn stop(&self) -> Result<()>;
|
||||||
|
|
||||||
|
/// Health check. Returns the harness process status.
|
||||||
|
async fn health(&self) -> HarnessHealth;
|
||||||
|
|
||||||
|
/// List models the harness knows about (loaded + unloaded).
|
||||||
|
async fn list_models(&self) -> Result<Vec<ModelInfo>>;
|
||||||
|
|
||||||
|
/// Load a model with the given spec (quant, TP, device assignment).
|
||||||
|
async fn load_model(&self, spec: &ModelSpec) -> Result<()>;
|
||||||
|
|
||||||
|
/// Unload a model, freeing device memory.
|
||||||
|
async fn unload_model(&self, model_id: &str) -> Result<()>;
|
||||||
|
|
||||||
|
/// Return the URL where inference requests for this model should
|
||||||
|
/// be sent. None if the model is not loaded.
|
||||||
|
async fn inference_endpoint(&self, model_id: &str) -> Option<String>;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
The mistral.rs implementation wraps the HTTP API:
|
||||||
|
- `list_models` → `GET /v1/models`
|
||||||
|
- `load_model` → `POST /v1/models/reload`
|
||||||
|
- `unload_model` → `POST /v1/models/unload`
|
||||||
|
- `inference_endpoint` → returns the base URL (the model name routes
|
||||||
|
internally within mistral.rs)
|
||||||
|
- `start`/`stop` → manage the `mistralrs.service` systemd unit
|
||||||
|
|
||||||
|
A future llama.cpp implementation would manage per-model `llama-server`
|
||||||
|
processes (one process per loaded model, each on its own port).
|
||||||
|
|
||||||
|
|
||||||
|
## neuron API
|
||||||
|
|
||||||
|
neuron exposes an HTTP API on port 13131 that cortex polls and calls.
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /discovery
|
||||||
|
→ {
|
||||||
|
hostname, os, kernel,
|
||||||
|
cuda_version, driver_version,
|
||||||
|
devices: [{ index, name, vram_total_mb, compute_capability }],
|
||||||
|
harnesses: ["mistralrs", ...]
|
||||||
|
}
|
||||||
|
|
||||||
|
GET /health
|
||||||
|
→ {
|
||||||
|
uptime_secs,
|
||||||
|
devices: [{ index, vram_used_mb, vram_free_mb, utilization_pct, temp_c }]
|
||||||
|
}
|
||||||
|
|
||||||
|
GET /models
|
||||||
|
→ [{ id, harness, status, devices: [int], vram_used_mb }]
|
||||||
|
|
||||||
|
POST /models/load
|
||||||
|
← { model_id, harness, quant, tensor_parallel, devices: [int] }
|
||||||
|
→ { status: "loaded" | "loading" }
|
||||||
|
|
||||||
|
POST /models/unload
|
||||||
|
← { model_id }
|
||||||
|
→ { status: "unloaded" }
|
||||||
|
|
||||||
|
GET /models/{model_id}/endpoint
|
||||||
|
→ { url: "http://localhost:8080" }
|
||||||
|
```
|
||||||
|
|
||||||
|
cortex never constructs a harness-specific URL. It asks neuron for the
|
||||||
|
inference endpoint and proxies there.
|
||||||
|
|
||||||
|
|
||||||
|
## Discovery replaces static device config
|
||||||
|
|
||||||
|
cortex.toml no longer contains device types, VRAM sizes, or CUDA
|
||||||
|
architectures. That information comes from neuron's `/discovery`
|
||||||
|
endpoint. cortex.toml shrinks to:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[gateway]
|
||||||
|
listen = "0.0.0.0:31313"
|
||||||
|
metrics_listen = "0.0.0.0:31314"
|
||||||
|
|
||||||
|
[eviction]
|
||||||
|
strategy = "lru"
|
||||||
|
defrag_after_cycles = 50
|
||||||
|
|
||||||
|
[[neurons]]
|
||||||
|
name = "beast"
|
||||||
|
endpoint = "http://beast.hanzalova.internal:13131"
|
||||||
|
|
||||||
|
[[neurons]]
|
||||||
|
name = "benjy"
|
||||||
|
endpoint = "http://benjy.hanzalova.internal:13131"
|
||||||
|
|
||||||
|
[[neurons]]
|
||||||
|
name = "quadbrat"
|
||||||
|
endpoint = "http://quadbrat.hanzalova.internal:13131"
|
||||||
|
```
|
||||||
|
|
||||||
|
On startup and periodically, cortex calls `GET /discovery` and
|
||||||
|
`GET /health` on each neuron to build its topology map. The router
|
||||||
|
uses this topology — not config — to make placement decisions.
|
||||||
|
|
||||||
|
|
||||||
|
## Model catalogue
|
||||||
|
|
||||||
|
Model serving profiles live in a separate file (`models.toml`) because
|
||||||
|
they describe how to serve a model, not where. cortex matches these
|
||||||
|
profiles against the discovered topology to determine valid placements.
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[[models]]
|
||||||
|
id = "Qwen/Qwen3-Coder-30B-A3B-Instruct"
|
||||||
|
harness = "mistralrs"
|
||||||
|
quant = "Q4_K_M"
|
||||||
|
vram_mb = 19000
|
||||||
|
min_devices = 2
|
||||||
|
min_device_vram_mb = 10000
|
||||||
|
pinned_on = ["beast"] # optional: never evict from these neurons
|
||||||
|
|
||||||
|
[[models]]
|
||||||
|
id = "Qwen/Qwen3-VL-8B"
|
||||||
|
harness = "mistralrs"
|
||||||
|
quant = "Q8_0"
|
||||||
|
vram_mb = 10000
|
||||||
|
min_devices = 1
|
||||||
|
|
||||||
|
[[models]]
|
||||||
|
id = "Qwen/Qwen2.5-Coder-14B-Instruct"
|
||||||
|
harness = "mistralrs"
|
||||||
|
quant = "Q6_K"
|
||||||
|
vram_mb = 12000
|
||||||
|
min_devices = 1
|
||||||
|
pinned_on = ["benjy"]
|
||||||
|
```
|
||||||
|
|
||||||
|
The router consults the catalogue to answer: "model X needs 2 devices
|
||||||
|
with ≥10GB each; beast has 2× RTX 5090 at 32GB each; that's a valid
|
||||||
|
placement." This replaces the current per-node `pinned` list in config
|
||||||
|
and the hardcoded `vram_mb` per node.
|
||||||
|
|
||||||
|
|
||||||
|
## Revised repository layout
|
||||||
|
|
||||||
|
```
|
||||||
|
cortex/
|
||||||
|
├── Cargo.toml
|
||||||
|
├── cortex.toml # gateway config (neurons only)
|
||||||
|
├── models.toml # model catalogue
|
||||||
|
├── README.md
|
||||||
|
├── CLAUDE.md
|
||||||
|
├── crates/
|
||||||
|
│ ├── cortex-core/ # shared types
|
||||||
|
│ │ └── src/
|
||||||
|
│ │ ├── lib.rs
|
||||||
|
│ │ ├── config.rs # GatewayConfig, NeuronEndpoint
|
||||||
|
│ │ ├── catalogue.rs # ModelProfile, placement matching
|
||||||
|
│ │ ├── discovery.rs # DeviceInfo, DiscoveryResponse
|
||||||
|
│ │ ├── harness.rs # Harness trait, HarnessConfig, HarnessHealth
|
||||||
|
│ │ ├── node.rs # NodeState, ModelEntry, ModelStatus
|
||||||
|
│ │ ├── openai.rs # OpenAI envelope types
|
||||||
|
│ │ ├── anthropic.rs # Anthropic envelope types
|
||||||
|
│ │ ├── translate.rs # OpenAI <-> Anthropic translation
|
||||||
|
│ │ └── metrics.rs # RequestMetrics
|
||||||
|
│ ├── cortex-gateway/ # control plane (existing, modified)
|
||||||
|
│ │ └── src/
|
||||||
|
│ │ ├── lib.rs
|
||||||
|
│ │ ├── state.rs # CortexState (updated: discovery topology)
|
||||||
|
│ │ ├── router.rs # updated: catalogue + discovery placement
|
||||||
|
│ │ ├── proxy.rs # streaming proxy (unchanged)
|
||||||
|
│ │ ├── evictor.rs # updated: talks to neuron, not mistralrs
|
||||||
|
│ │ ├── poller.rs # updated: polls neuron, not mistralrs
|
||||||
|
│ │ ├── handlers.rs # axum handlers (unchanged API surface)
|
||||||
|
│ │ └── metrics.rs # prometheus exporter (unchanged)
|
||||||
|
│ ├── neuron/ # node plane (replaces cortex-agent)
|
||||||
|
│ │ └── src/
|
||||||
|
│ │ ├── main.rs # binary entrypoint, axum server on :13131
|
||||||
|
│ │ ├── discovery.rs # nvidia-smi, device enumeration
|
||||||
|
│ │ ├── health.rs # runtime GPU polling
|
||||||
|
│ │ ├── api.rs # HTTP handlers for /discovery, /models, etc.
|
||||||
|
│ │ ├── harness/
|
||||||
|
│ │ │ ├── mod.rs # Harness trait re-export, registry
|
||||||
|
│ │ │ ├── mistralrs.rs # mistral.rs HTTP API wrapper
|
||||||
|
│ │ │ └── llamacpp.rs # stub for future llama.cpp support
|
||||||
|
│ │ └── models.rs # local model lifecycle orchestration
|
||||||
|
│ └── cortex-cli/ # CLI entrypoint (unchanged)
|
||||||
|
│ └── src/
|
||||||
|
│ └── main.rs
|
||||||
|
└── tests/
|
||||||
|
```
|
||||||
|
|
||||||
|
The `cortex-agent` crate is deleted. Its replacement is `neuron/`.
|
||||||
|
|
||||||
|
|
||||||
|
## Implementation plan (phases 7+)
|
||||||
|
|
||||||
|
Phases 1–6 are merged and passing CI. Each subsequent phase is a
|
||||||
|
branch → PR. CI (fmt, clippy, test) must pass before merge.
|
||||||
|
|
||||||
|
### Phase 7: neuron scaffold and discovery ✅
|
||||||
|
|
||||||
|
Completed. Deleted `cortex-agent`, created `crates/neuron/` (binary:
|
||||||
|
`neuron`). Added shared types to cortex-core: `discovery.rs`
|
||||||
|
(DeviceInfo, DiscoveryResponse, DeviceHealth, HealthResponse) and
|
||||||
|
`harness.rs` (Harness async trait, HarnessConfig, ModelSpec, ModelInfo).
|
||||||
|
|
||||||
|
neuron discovers GPUs via nvidia-smi, caches health readings, and
|
||||||
|
serves `GET /discovery` and `GET /health`. Pure parsing functions
|
||||||
|
separated from command execution for testability. 9 unit tests for
|
||||||
|
nvidia-smi CSV parsing, 3 integration tests for the HTTP endpoints.
|
||||||
|
|
||||||
|
### Phase 8: neuron harness — mistral.rs implementation ✅
|
||||||
|
|
||||||
|
Completed. Full `Harness` trait implementation for mistral.rs in
|
||||||
|
`neuron/src/harness/mistralrs.rs`: list_models, load_model, unload_model,
|
||||||
|
inference_endpoint, health, start/stop (systemd). `HarnessRegistry` in
|
||||||
|
`harness/mod.rs` maps harness name → `Box<dyn Harness>`, built from
|
||||||
|
`neuron.toml` config. Four new neuron API endpoints: `GET /models`,
|
||||||
|
`POST /models/load`, `POST /models/unload`, `GET /models/:id/endpoint`.
|
||||||
|
|
||||||
|
Config via `neuron.toml` (figment + env override). Integration test
|
||||||
|
covers full model lifecycle through neuron → mock mistral.rs backend.
|
||||||
|
|
||||||
|
### Phase 9: cortex talks to neurons ✅
|
||||||
|
|
||||||
|
Completed. Full refactor of cortex-gateway to talk to neurons:
|
||||||
|
|
||||||
|
- **Config**: `NodeConfig { endpoint, vram_mb, pinned }` replaced with
|
||||||
|
`NeuronEndpoint { name, endpoint }`. Hardware info comes from neuron
|
||||||
|
discovery, pinning from `models.toml` catalogue.
|
||||||
|
- **catalogue.rs**: `ModelProfile` with `pinned_on`, `ModelCatalogue`
|
||||||
|
with `is_pinned()` for eviction decisions.
|
||||||
|
- **Poller**: polls neuron's `GET /models` (ModelInfo format) instead
|
||||||
|
of mistralrs `/v1/models`.
|
||||||
|
- **Router**: asks neuron `GET /models/{id}/endpoint` for the inference
|
||||||
|
URL before proxying. Decouples cortex from knowing harness ports.
|
||||||
|
- **Evictor**: calls `POST {neuron}/models/unload` instead of
|
||||||
|
mistralrs directly. Uses catalogue for pinning.
|
||||||
|
- **Tests**: all 22 gateway tests updated to mock neuron API instead
|
||||||
|
of raw mistralrs. 36 total tests passing.
|
||||||
|
|
||||||
|
Topology-aware placement (min_devices, min_device_vram_mb) deferred —
|
||||||
|
the router currently routes based on polled model status. Catalogue
|
||||||
|
placement matching can be added incrementally.
|
||||||
|
|
||||||
|
### Phase 10: RPM packaging ✅
|
||||||
|
|
||||||
|
Completed. Both packages have RPM specs, systemd units, and example configs.
|
||||||
|
CI builds parallel SRPMs on tag push and publishes to separate COPR repos.
|
||||||
|
|
||||||
|
- `cortex.spec` — installs the `cortex` binary. Package name keeps the
|
||||||
|
short `cortex` because no Fedora package collides with it.
|
||||||
|
- `helexa-neuron.spec` — installs the `neuron` binary under package name
|
||||||
|
`helexa-neuron`. Renamed from bare `neuron` to avoid collision with
|
||||||
|
Fedora's NEURON neural-simulation package
|
||||||
|
(https://src.fedoraproject.org/rpms/neuron); binary, systemd unit,
|
||||||
|
system user, and config dir all stay named `neuron` since those are
|
||||||
|
project-local contexts.
|
||||||
|
- `data/cortex.service`, `data/neuron.service` — systemd units
|
||||||
|
- `cortex.example.toml`, `neuron.example.toml`, `models.example.toml`
|
||||||
|
- CI: parallel `srpm-cortex` + `srpm-neuron` jobs, then parallel COPR
|
||||||
|
publish to a single project `helexa/helexa` hosting both packages.
|
||||||
|
|
||||||
|
Install:
|
||||||
|
```sh
|
||||||
|
dnf copr enable helexa/helexa
|
||||||
|
dnf install cortex # gateway host
|
||||||
|
dnf install helexa-neuron # GPU nodes
|
||||||
|
```
|
||||||
|
|
||||||
|
## 2026-05-18 addendum: candle-native pivot
|
||||||
|
|
||||||
|
Phases 11 (llama.cpp harness) and 12 (mistral.rs COPR) below are
|
||||||
|
**superseded**. The project no longer treats mistral.rs or llama.cpp as
|
||||||
|
dependencies — both are conceptually out of scope. neuron becomes a
|
||||||
|
candle-native inference daemon, with `Harness` retained as an
|
||||||
|
internal seam for adding future engines (vision/audio/diffusion) but
|
||||||
|
its only implementation being in-process candle.
|
||||||
|
|
||||||
|
The full staged plan for this pivot lives at
|
||||||
|
`~/.claude/plans/create-a-more-aggressive-calm-naur.md`. Summary:
|
||||||
|
|
||||||
|
- **Stage 1 (this commit):** delete `mistralrs.rs` and `llamacpp.rs`,
|
||||||
|
scaffold inert `CandleHarness`, drop `endpoint`/`systemd_unit` from
|
||||||
|
`HarnessConfig`, default no-op `start`/`stop` on the `Harness` trait.
|
||||||
|
- **Stages 2–4:** wire up candle model load/unload (quantized Qwen3
|
||||||
|
first), add OpenAI-compatible inference endpoint in neuron, then SSE
|
||||||
|
streaming.
|
||||||
|
- **Stages 5–6:** load-on-activation (default models in config) and
|
||||||
|
unload-on-deactivation (graceful shutdown).
|
||||||
|
- **Stages 7–8:** multi-GPU tensor parallelism and broader model/quant
|
||||||
|
coverage.
|
||||||
|
|
||||||
|
Sections of this document that describe mistral.rs HTTP behaviour
|
||||||
|
("mistral.rs API gotchas") are retained as historical context for
|
||||||
|
Phases 1–10 — they document what was true while the project depended
|
||||||
|
on mistral.rs. They do not describe current behaviour.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Phase 11 (superseded): llama.cpp harness stub
|
||||||
|
|
||||||
|
~~Originally planned as a second engine to prove the harness
|
||||||
|
abstraction.~~ Replaced by the candle harness work in the 2026-05-18
|
||||||
|
addendum above. llama.cpp's any-model/any-hardware breadth is no
|
||||||
|
longer in scope for helexa.
|
||||||
|
|
||||||
|
### Phase 12 (superseded): mistral.rs COPR packaging
|
||||||
|
|
||||||
|
~~Originally planned to ship CUDA-versioned mistral.rs RPMs.~~ Replaced
|
||||||
|
by the candle harness work in the 2026-05-18 addendum above. With
|
||||||
|
mistral.rs out of the dependency tree, there is nothing to package.
|
||||||
|
|
||||||
|
## 2026-05-27 addendum: per-device worker thread
|
||||||
|
|
||||||
|
Replaced the ad-hoc `tokio::task::spawn_blocking` pattern that drove
|
||||||
|
every leader-side CUDA op with one dedicated OS thread per CUDA device,
|
||||||
|
permanently bound to that device's `CudaContext`. All leader-side
|
||||||
|
inference work (GGUF + dense + TP shard load, forward, kv-cache clear,
|
||||||
|
NCCL init/sanity, NCCL all_reduce, VRAM query, model drop) routes
|
||||||
|
through the worker via a `std::sync::mpsc` channel; tensors never
|
||||||
|
escape the worker thread alive. See "Per-device worker thread (neuron)"
|
||||||
|
above and `crates/neuron/src/harness/device_worker/mod.rs` for the
|
||||||
|
canonical narrative.
|
||||||
|
|
||||||
|
Motivated by the 2026-05-26 silent-hang on beast: a CUDA OOM cascade
|
||||||
|
poisoned the device context on whichever spawn_blocking thread caught
|
||||||
|
it, and subsequent requests stalled invisibly on the pool lock. After
|
||||||
|
the refactor, the same failure mode shows up in journalctl as
|
||||||
|
`prefill sample failed; logits unhealthy nan: 248320/248320` followed
|
||||||
|
by `failed, model marked poisoned`. The thread stays alive and rejects
|
||||||
|
subsequent requests at the channel boundary.
|
||||||
|
|
||||||
|
Landed in four PRs:
|
||||||
|
|
||||||
|
- **Phase 1** (`081b532`) — device_worker module + 8 VRAM-query sites
|
||||||
|
route through the worker. CPU build only; smoke on beast confirmed
|
||||||
|
a persistent `cuda-dev-0` thread.
|
||||||
|
- **Phase 2** (`b179204`) — single-GPU forward + clear_kv + drop via
|
||||||
|
the worker. `LoadedModel.arch_handle: Option<ArchHandle>` replaces
|
||||||
|
`Arc<Mutex<ModelArch>>` for CUDA loads. CPU keeps the legacy path.
|
||||||
|
- **Phase 3** (`76ab24d`) — TP forward + NCCL init/sanity + leader
|
||||||
|
KV-clear routed through the worker. `WorkerPool.leader_nccl` moves
|
||||||
|
into the worker's state. `TpLoadedModel.leader_handle: TpHandle`
|
||||||
|
replaces `Arc<Mutex<TpLeaderModel>>`. CUDA-only TP smoke deferred to
|
||||||
|
next deploy.
|
||||||
|
- **Phase 4** (`b4f3576`) — GGUF + dense + TP shard loads move onto
|
||||||
|
the worker. The `Job::TransferIn` / `Job::CloneLeaderComm` bridges
|
||||||
|
from Phases 2/3 deleted; `SendComm` newtype no longer needed in the
|
||||||
|
load path. `grep -rn spawn_blocking crates/neuron/src/harness/`
|
||||||
|
returns only deliberate CPU-fallback hits after this PR.
|
||||||
|
|||||||
2469
Cargo.lock
generated
2469
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
14
Cargo.toml
14
Cargo.toml
@@ -3,12 +3,13 @@ resolver = "2"
|
|||||||
members = [
|
members = [
|
||||||
"crates/cortex-core",
|
"crates/cortex-core",
|
||||||
"crates/cortex-gateway",
|
"crates/cortex-gateway",
|
||||||
"crates/cortex-agent",
|
|
||||||
"crates/cortex-cli",
|
"crates/cortex-cli",
|
||||||
|
"crates/neuron",
|
||||||
|
"crates/helexa-acp",
|
||||||
]
|
]
|
||||||
|
|
||||||
[workspace.package]
|
[workspace.package]
|
||||||
version = "0.1.0"
|
version = "0.1.16"
|
||||||
edition = "2024"
|
edition = "2024"
|
||||||
license = "GPL-3.0-or-later"
|
license = "GPL-3.0-or-later"
|
||||||
repository = "https://git.lair.cafe/helexa/cortex"
|
repository = "https://git.lair.cafe/helexa/cortex"
|
||||||
@@ -27,7 +28,7 @@ serde = { version = "1", features = ["derive"] }
|
|||||||
serde_json = "1"
|
serde_json = "1"
|
||||||
toml = "0.8"
|
toml = "0.8"
|
||||||
|
|
||||||
# http client (for proxying to mistralrs backends)
|
# http client (for proxying to neuron backends)
|
||||||
reqwest = { version = "0.12", features = ["json", "stream"] }
|
reqwest = { version = "0.12", features = ["json", "stream"] }
|
||||||
|
|
||||||
# observability
|
# observability
|
||||||
@@ -46,6 +47,12 @@ figment = { version = "0.10", features = ["toml", "env"] }
|
|||||||
anyhow = "1"
|
anyhow = "1"
|
||||||
thiserror = "2"
|
thiserror = "2"
|
||||||
|
|
||||||
|
# async traits
|
||||||
|
async-trait = "0.1"
|
||||||
|
|
||||||
|
# CLI
|
||||||
|
clap = { version = "4", features = ["derive"] }
|
||||||
|
|
||||||
# futures / streams (for SSE proxying)
|
# futures / streams (for SSE proxying)
|
||||||
futures = "0.3"
|
futures = "0.3"
|
||||||
tokio-stream = "0.1"
|
tokio-stream = "0.1"
|
||||||
@@ -54,4 +61,3 @@ eventsource-stream = "0.2"
|
|||||||
# workspace crates
|
# workspace crates
|
||||||
cortex-core = { path = "crates/cortex-core" }
|
cortex-core = { path = "crates/cortex-core" }
|
||||||
cortex-gateway = { path = "crates/cortex-gateway" }
|
cortex-gateway = { path = "crates/cortex-gateway" }
|
||||||
cortex-agent = { path = "crates/cortex-agent" }
|
|
||||||
|
|||||||
109
README.md
109
README.md
@@ -1,22 +1,23 @@
|
|||||||
# cortex
|
# cortex
|
||||||
|
|
||||||
A Rust reverse-proxy and fleet management layer for multi-node
|
A Rust reverse-proxy and fleet management layer for multi-node GPU inference
|
||||||
[mistral.rs](https://github.com/EricLBuehler/mistral.rs) inference clusters.
|
clusters. Cortex sits in front of one or more `neuron` daemons (each running
|
||||||
|
candle-based inference on a local GPU host) and presents a unified OpenAI +
|
||||||
|
Anthropic compatible API surface.
|
||||||
|
|
||||||
## Problem
|
## Problem
|
||||||
|
|
||||||
Running local LLMs across multiple GPU nodes (different VRAM tiers, different
|
Running local LLMs across multiple GPU nodes (different VRAM tiers, different
|
||||||
model affinities) requires a unified API surface that:
|
model affinities) requires a unified API surface that:
|
||||||
|
|
||||||
- Presents a **single `/v1/models` catalogue** merging every model across every
|
- Presents a **single `/v1/models` catalogue** merging every model that can be
|
||||||
node.
|
served by any neuron in the fleet.
|
||||||
- **Routes requests** to the correct node based on where a model is loaded (or
|
- **Routes requests** to the correct node based on where a model is loaded
|
||||||
*can* be loaded).
|
(or can be loaded), handling cold-load and eviction transparently.
|
||||||
- Manages **model lifecycle** — unload cold models, reload on demand, pin
|
- Manages **model lifecycle** — load on demand, unload cold models, pin
|
||||||
critical ones — using the mistral.rs
|
critical ones — by calling each neuron's `/models/{load,unload}` API.
|
||||||
`/v1/models/{unload,reload,status}` HTTP API (PR #1828+).
|
|
||||||
- Translates between **OpenAI and Anthropic** request/response envelopes so
|
- Translates between **OpenAI and Anthropic** request/response envelopes so
|
||||||
every client in the homelab speaks whichever dialect it prefers.
|
every client speaks whichever dialect it prefers.
|
||||||
- Captures **per-request metrics** (tokens, tok/s, TTFT, latency) and exposes
|
- Captures **per-request metrics** (tokens, tok/s, TTFT, latency) and exposes
|
||||||
them as Prometheus counters/histograms.
|
them as Prometheus counters/histograms.
|
||||||
|
|
||||||
@@ -38,10 +39,9 @@ model affinities) requires a unified API surface that:
|
|||||||
└──┬──────┬────────┬──┘
|
└──┬──────┬────────┬──┘
|
||||||
│ │ │
|
│ │ │
|
||||||
┌──────────▼┐ ┌──▼─────┐ ┌▼──────────┐
|
┌──────────▼┐ ┌──▼─────┐ ┌▼──────────┐
|
||||||
│ gpu-large │ │gpu-med │ │ gpu-small │
|
│ neuron │ │ neuron │ │ neuron │
|
||||||
│ mistralrs │ │mistral │ │ mistralrs │
|
│ :13131 │ │ :13131 │ │ :13131 │
|
||||||
│ serve │ │rs serve│ │ serve │
|
│ candle │ │ candle │ │ candle │
|
||||||
│ :8080 │ │ :8080 │ │ :8080 │
|
|
||||||
└───────────┘ └────────┘ └───────────┘
|
└───────────┘ └────────┘ └───────────┘
|
||||||
private network (.internal)
|
private network (.internal)
|
||||||
```
|
```
|
||||||
@@ -50,70 +50,58 @@ model affinities) requires a unified API surface that:
|
|||||||
|
|
||||||
| Crate | Purpose |
|
| Crate | Purpose |
|
||||||
|---|---|
|
|---|---|
|
||||||
| `cortex-core` | Shared types: config, node/model state, metrics, OpenAI/Anthropic request/response envelopes |
|
| `cortex-core` | Shared types: config, node/model state, metrics, OpenAI/Anthropic envelopes, harness trait, discovery types |
|
||||||
| `cortex-gateway` | Axum HTTP server: proxy, router, evictor, metrics exporter |
|
| `cortex-gateway` | Axum HTTP server: proxy, router, evictor, poller, metrics exporter |
|
||||||
| `cortex-agent` | Per-node sidecar: polls local mistralrs, reports to gateway, handles restart/defrag |
|
| `neuron` | Per-node daemon: GPU discovery, in-process candle inference, model lifecycle API |
|
||||||
| `cortex-cli` | CLI entrypoint (`cortex serve`, `cortex status`, etc.) |
|
| `cortex-cli` | CLI entrypoint (`cortex serve`, `cortex status`, etc.) |
|
||||||
|
|
||||||
## Node setup
|
## Node setup
|
||||||
|
|
||||||
Each GPU node runs `mistralrs serve` with a multi-model config. Models are
|
Each GPU node runs `neuron` (listening on `:13131`). Neuron uses
|
||||||
declared but start **unloaded** — mistral.rs lazy-loads on first request and
|
huggingface/candle for in-process inference — there is no external
|
||||||
the gateway can explicitly unload/reload via the HTTP API.
|
inference subprocess to manage.
|
||||||
|
|
||||||
Example node systemd unit:
|
Inside the daemon, every CUDA device gets one dedicated OS thread
|
||||||
|
(named `cuda-dev-N`) that owns the device's CUDA context for the
|
||||||
|
daemon's lifetime. Model loads, forward passes, KV-cache resets,
|
||||||
|
NCCL collectives, VRAM queries, and unloads all route through that
|
||||||
|
thread via a job channel; tensors never escape it alive. This pins
|
||||||
|
context binding to a known thread, makes the CUDA Drop contract
|
||||||
|
structurally safe, and isolates driver-error poisoning to one worker
|
||||||
|
rather than the whole process. See `CLAUDE.md` for the design
|
||||||
|
rationale and `crates/neuron/src/harness/device_worker/` for the code.
|
||||||
|
|
||||||
```ini
|
The neuron RPM (`helexa-neuron`) ships a systemd unit:
|
||||||
# /etc/systemd/system/mistralrs.service
|
|
||||||
[Unit]
|
|
||||||
Description=mistral.rs inference server
|
|
||||||
After=network-online.target
|
|
||||||
Wants=network-online.target
|
|
||||||
|
|
||||||
[Service]
|
```sh
|
||||||
Type=simple
|
dnf copr enable helexa/helexa
|
||||||
ExecStart=/usr/local/bin/mistralrs serve \
|
dnf install helexa-neuron
|
||||||
--from-config /etc/mistralrs/config.toml \
|
systemctl enable --now neuron
|
||||||
--port 8080
|
|
||||||
Restart=on-failure
|
|
||||||
RestartSec=5
|
|
||||||
Environment=CUDA_VISIBLE_DEVICES=0,1
|
|
||||||
|
|
||||||
[Install]
|
|
||||||
WantedBy=multi-user.target
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## Gateway config
|
## Gateway config
|
||||||
|
|
||||||
```toml
|
```toml
|
||||||
# cortex.toml
|
# /etc/cortex/cortex.toml
|
||||||
[gateway]
|
[gateway]
|
||||||
listen = "0.0.0.0:8000"
|
listen = "0.0.0.0:31313"
|
||||||
metrics_listen = "0.0.0.0:9100"
|
metrics_listen = "0.0.0.0:31314"
|
||||||
|
|
||||||
[eviction]
|
[eviction]
|
||||||
strategy = "lru" # lru | priority
|
strategy = "lru" # lru | priority
|
||||||
defrag_after_cycles = 50
|
defrag_after_cycles = 50
|
||||||
|
|
||||||
[[nodes]]
|
[[neurons]]
|
||||||
name = "gpu-large"
|
name = "beast"
|
||||||
endpoint = "http://gpu-large.internal:8080"
|
endpoint = "http://beast.internal:13131"
|
||||||
vram_mb = 49_152 # e.g. 2x RTX 4090
|
|
||||||
pinned = ["your-org/large-model"]
|
|
||||||
|
|
||||||
[[nodes]]
|
[[neurons]]
|
||||||
name = "gpu-medium"
|
name = "benjy"
|
||||||
endpoint = "http://gpu-medium.internal:8080"
|
endpoint = "http://benjy.internal:13131"
|
||||||
vram_mb = 24_576 # e.g. RTX 4090
|
|
||||||
pinned = ["your-org/medium-model"]
|
|
||||||
|
|
||||||
[[nodes]]
|
|
||||||
name = "gpu-small"
|
|
||||||
endpoint = "http://gpu-small.internal:8080"
|
|
||||||
vram_mb = 12_288 # e.g. RTX 3060
|
|
||||||
pinned = ["your-org/embedding-model"]
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Model placement profiles live in `models.toml` — see `models.example.toml`.
|
||||||
|
|
||||||
## Building
|
## Building
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
@@ -131,19 +119,20 @@ cargo clippy --workspace -- -D warnings # warnings are errors
|
|||||||
cargo test --workspace # all tests must pass
|
cargo test --workspace # all tests must pass
|
||||||
```
|
```
|
||||||
|
|
||||||
Tagged releases (`v*`) additionally build an SRPM and publish to COPR.
|
Tagged releases (`v*`) additionally build SRPMs for both `cortex` and
|
||||||
|
`helexa-neuron` and publish to COPR.
|
||||||
|
|
||||||
## Running
|
## Running
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
# start the gateway
|
# start the gateway
|
||||||
cortex serve --config cortex.toml
|
cortex serve --config /etc/cortex/cortex.toml
|
||||||
|
|
||||||
# check fleet status
|
# check fleet status
|
||||||
cortex status
|
cortex status
|
||||||
|
|
||||||
# list all models across nodes
|
# list all models across nodes
|
||||||
curl http://localhost:8000/v1/models
|
curl http://localhost:31313/v1/models
|
||||||
```
|
```
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|||||||
30
asset/manifest.yml
Normal file
30
asset/manifest.yml
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
# Helexa fleet manifest.
|
||||||
|
#
|
||||||
|
# Drives rolling deploys via script/deploy.sh and serves as the source
|
||||||
|
# of truth for which hosts run cortex vs neuron, and which CUDA
|
||||||
|
# compute-capability flavour each neuron host needs.
|
||||||
|
#
|
||||||
|
# Flavour ↔ NVIDIA generation ↔ compute cap:
|
||||||
|
# ampere sm_86 (RTX 30 series — e.g. 3060)
|
||||||
|
# ada sm_89 (RTX 40 series — e.g. 4090)
|
||||||
|
# blackwell sm_120 (RTX 50 series — e.g. 5090)
|
||||||
|
#
|
||||||
|
# The flavour determines which RPM is installed on a given neuron host:
|
||||||
|
# helexa-neuron-<flavour>. Only one flavour may be installed at a time
|
||||||
|
# (the packages Conflict: with each other).
|
||||||
|
|
||||||
|
cortex:
|
||||||
|
host: hanzalova.internal
|
||||||
|
|
||||||
|
neurons:
|
||||||
|
- host: beast.hanzalova.internal
|
||||||
|
flavour: blackwell
|
||||||
|
gpu: "2x RTX 5090"
|
||||||
|
|
||||||
|
- host: benjy.hanzalova.internal
|
||||||
|
flavour: ada
|
||||||
|
gpu: "RTX 4090"
|
||||||
|
|
||||||
|
- host: quadbrat.hanzalova.internal
|
||||||
|
flavour: ampere
|
||||||
|
gpu: "RTX 3060"
|
||||||
24
asset/neuron/beast.toml
Normal file
24
asset/neuron/beast.toml
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
# neuron.toml for beast.hanzalova.internal
|
||||||
|
#
|
||||||
|
# 2x RTX 5090 (32 GB each) — TP-2 capable. Pre-warms Qwen3.6-27B with
|
||||||
|
# q5k ISQ across both GPUs at activation, matching the validate-neuron
|
||||||
|
# invocation: `validate-neuron.sh beast.hanzalova.internal
|
||||||
|
# Qwen/Qwen3.6-27B q5k 2`.
|
||||||
|
#
|
||||||
|
# Synced by script/deploy.sh from asset/neuron/<short-host>.toml. Edits
|
||||||
|
# take effect on the next deploy.sh run (which stops + restarts the
|
||||||
|
# service so default_models is re-read at activation).
|
||||||
|
|
||||||
|
port = 13131
|
||||||
|
|
||||||
|
[[harnesses]]
|
||||||
|
name = "candle"
|
||||||
|
|
||||||
|
[harness.candle]
|
||||||
|
|
||||||
|
[[default_models]]
|
||||||
|
model_id = "Qwen/Qwen3.6-27B"
|
||||||
|
harness = "candle"
|
||||||
|
quant = "q6k"
|
||||||
|
tensor_parallel = 2
|
||||||
|
devices = [0, 1]
|
||||||
19
asset/neuron/benjy.toml
Normal file
19
asset/neuron/benjy.toml
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
# neuron.toml for benjy.hanzalova.internal
|
||||||
|
#
|
||||||
|
# 1x RTX 4090 (24 GB) — largest single-GPU host on the fleet. Pre-warms
|
||||||
|
# Qwen3-8B (bf16, ~18 GB), leaving ~6 GB for KV cache + activations on
|
||||||
|
# moderate-length contexts.
|
||||||
|
#
|
||||||
|
# Synced by script/deploy.sh from asset/neuron/<short-host>.toml.
|
||||||
|
|
||||||
|
port = 13131
|
||||||
|
|
||||||
|
[[harnesses]]
|
||||||
|
name = "candle"
|
||||||
|
|
||||||
|
[harness.candle]
|
||||||
|
|
||||||
|
[[default_models]]
|
||||||
|
model_id = "Qwen/Qwen3-8B"
|
||||||
|
harness = "candle"
|
||||||
|
devices = [0]
|
||||||
19
asset/neuron/quadbrat.toml
Normal file
19
asset/neuron/quadbrat.toml
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
# neuron.toml for quadbrat.hanzalova.internal
|
||||||
|
#
|
||||||
|
# 1x RTX 3060 (12 GB) — small / quantised tier. Pre-warms Qwen3-1.7B
|
||||||
|
# (bf16, ~4 GB), leaving ~7 GB for KV cache so long contexts on a small
|
||||||
|
# model still have plenty of room.
|
||||||
|
#
|
||||||
|
# Synced by script/deploy.sh from asset/neuron/<short-host>.toml.
|
||||||
|
|
||||||
|
port = 13131
|
||||||
|
|
||||||
|
[[harnesses]]
|
||||||
|
name = "candle"
|
||||||
|
|
||||||
|
[harness.candle]
|
||||||
|
|
||||||
|
[[default_models]]
|
||||||
|
model_id = "Qwen/Qwen3-1.7B"
|
||||||
|
harness = "candle"
|
||||||
|
devices = [0]
|
||||||
@@ -3,22 +3,22 @@
|
|||||||
# Copy to cortex.toml and adjust for your environment.
|
# Copy to cortex.toml and adjust for your environment.
|
||||||
#
|
#
|
||||||
# Environment variable overrides use CORTEX_ prefix with __ separators:
|
# Environment variable overrides use CORTEX_ prefix with __ separators:
|
||||||
# CORTEX_GATEWAY__LISTEN=0.0.0.0:9000
|
# CORTEX_GATEWAY__LISTEN=0.0.0.0:31313
|
||||||
|
|
||||||
[gateway]
|
[gateway]
|
||||||
listen = "0.0.0.0:8000"
|
listen = "0.0.0.0:31313"
|
||||||
metrics_listen = "0.0.0.0:9100"
|
metrics_listen = "0.0.0.0:31314"
|
||||||
|
|
||||||
[eviction]
|
[eviction]
|
||||||
strategy = "lru"
|
strategy = "lru"
|
||||||
# Restart mistralrs after this many load/unload cycles to defragment VRAM.
|
# Restart neurons after this many load/unload cycles to defragment VRAM.
|
||||||
# Set to 0 to disable.
|
# Set to 0 to disable.
|
||||||
defrag_after_cycles = 50
|
defrag_after_cycles = 50
|
||||||
|
|
||||||
# -- Nodes ---------------------------------------------------------------
|
# -- Nodes ---------------------------------------------------------------
|
||||||
# Each [[nodes]] entry declares a mistral.rs instance in the fleet.
|
# Each [[nodes]] entry declares a neuron daemon in the fleet.
|
||||||
# Models are discovered by polling the node's /v1/models endpoint.
|
# Models are discovered by polling the neuron's /models endpoint.
|
||||||
# Pinned models are never evicted.
|
# Pinned models (see models.toml) are never evicted.
|
||||||
|
|
||||||
[[nodes]]
|
[[nodes]]
|
||||||
name = "gpu-large"
|
name = "gpu-large"
|
||||||
|
|||||||
91
cortex.spec
91
cortex.spec
@@ -1,7 +1,7 @@
|
|||||||
Name: cortex
|
Name: cortex
|
||||||
Version: 0.1.0
|
Version: 0.1.16
|
||||||
Release: 1%{?dist}
|
Release: 1%{?dist}
|
||||||
Summary: Inference gateway for multi-node mistral.rs clusters
|
Summary: Inference gateway for multi-node GPU clusters
|
||||||
|
|
||||||
License: GPL-3.0-or-later
|
License: GPL-3.0-or-later
|
||||||
URL: https://git.lair.cafe/helexa/cortex
|
URL: https://git.lair.cafe/helexa/cortex
|
||||||
@@ -13,13 +13,30 @@ ExclusiveArch: x86_64
|
|||||||
BuildRequires: rust >= 1.85
|
BuildRequires: rust >= 1.85
|
||||||
BuildRequires: cargo
|
BuildRequires: cargo
|
||||||
BuildRequires: gcc
|
BuildRequires: gcc
|
||||||
|
BuildRequires: gcc-c++
|
||||||
|
BuildRequires: cmake
|
||||||
|
BuildRequires: perl-interpreter
|
||||||
|
BuildRequires: pkgconfig(openssl)
|
||||||
BuildRequires: systemd-rpm-macros
|
BuildRequires: systemd-rpm-macros
|
||||||
|
|
||||||
|
Requires(pre): shadow-utils
|
||||||
|
Requires: systemd
|
||||||
|
Requires: firewalld-filesystem
|
||||||
|
|
||||||
|
# systemd-rpm-macros ships a unit dep generator that parses User=/Group=
|
||||||
|
# from our .service file and emits Requires: user(cortex)/group(cortex).
|
||||||
|
# rpm's sysusers provides-generator emits the unversioned form for groups
|
||||||
|
# but only a versioned user(cortex) = <base64> for users with GECOS/home/
|
||||||
|
# shell. Provide the unversioned user(cortex) explicitly so dnf can resolve
|
||||||
|
# the auto-generated Requires. Without this, dnf5 silently filters the
|
||||||
|
# package and reports "Nothing to do".
|
||||||
|
Provides: user(cortex)
|
||||||
|
|
||||||
%description
|
%description
|
||||||
Cortex is a Rust reverse-proxy that sits in front of multiple mistral.rs
|
Cortex is a Rust reverse-proxy that sits in front of multiple inference
|
||||||
inference nodes and presents a unified OpenAI and Anthropic compatible
|
nodes (via neuron daemons) and presents a unified OpenAI and Anthropic
|
||||||
API surface. It handles model routing, lifecycle management, request
|
compatible API surface. It handles model routing, lifecycle management,
|
||||||
translation, and metrics collection.
|
request translation, and metrics collection.
|
||||||
|
|
||||||
%prep
|
%prep
|
||||||
%autosetup
|
%autosetup
|
||||||
@@ -38,12 +55,72 @@ cargo build --release -p cortex-cli
|
|||||||
|
|
||||||
%install
|
%install
|
||||||
install -Dm755 target/release/cortex %{buildroot}%{_bindir}/cortex
|
install -Dm755 target/release/cortex %{buildroot}%{_bindir}/cortex
|
||||||
|
install -Dm644 data/cortex.service %{buildroot}%{_unitdir}/cortex.service
|
||||||
|
install -Dm644 data/cortex-sysusers.conf %{buildroot}%{_sysusersdir}/cortex.conf
|
||||||
|
install -Dm644 data/cortex-firewalld.xml %{buildroot}%{_prefix}/lib/firewalld/services/cortex.xml
|
||||||
|
install -dm755 %{buildroot}%{_sysconfdir}/cortex
|
||||||
|
install -Dm644 cortex.example.toml %{buildroot}%{_sysconfdir}/cortex/cortex.toml
|
||||||
|
install -Dm644 models.example.toml %{buildroot}%{_sysconfdir}/cortex/models.toml
|
||||||
|
|
||||||
|
%pre
|
||||||
|
%sysusers_create_compat %{_builddir}/%{name}-%{version}/data/cortex-sysusers.conf
|
||||||
|
|
||||||
|
%post
|
||||||
|
%systemd_post cortex.service
|
||||||
|
|
||||||
|
%preun
|
||||||
|
%systemd_preun cortex.service
|
||||||
|
|
||||||
|
%postun
|
||||||
|
%systemd_postun_with_restart cortex.service
|
||||||
|
|
||||||
|
%posttrans
|
||||||
|
# Migration: older cortex packages shipped the firewalld service as
|
||||||
|
# `helexa-cortex` and (in some build streams) with wrong port numbers
|
||||||
|
# (9301/9302/9304). Operators who enabled that legacy service in their
|
||||||
|
# zone end up with the wrong-port override taking precedence over the
|
||||||
|
# vendor `cortex.xml` now in /usr/lib/firewalld/services/. Clean up the
|
||||||
|
# stale /etc/ override here and migrate any zone bindings to the new
|
||||||
|
# service name.
|
||||||
|
if [ -f /etc/firewalld/services/helexa-cortex.xml ]; then
|
||||||
|
rm -f /etc/firewalld/services/helexa-cortex.xml
|
||||||
|
fi
|
||||||
|
if [ -x /usr/bin/firewall-cmd ] && /usr/bin/firewall-cmd --state >/dev/null 2>&1; then
|
||||||
|
# Drop the legacy service name from every zone where it was enabled
|
||||||
|
# and add the new `cortex` service in its place. Operators who never
|
||||||
|
# ran firewall-cmd against either name see no zone change.
|
||||||
|
for zone in $(/usr/bin/firewall-cmd --get-active-zones 2>/dev/null \
|
||||||
|
| awk '!/^[[:space:]]/ {print $1}'); do
|
||||||
|
if /usr/bin/firewall-cmd --permanent --zone="$zone" --query-service=helexa-cortex >/dev/null 2>&1; then
|
||||||
|
/usr/bin/firewall-cmd --permanent --zone="$zone" --remove-service=helexa-cortex >/dev/null 2>&1 || :
|
||||||
|
/usr/bin/firewall-cmd --permanent --zone="$zone" --add-service=cortex >/dev/null 2>&1 || :
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
/usr/bin/firewall-cmd --reload >/dev/null 2>&1 || :
|
||||||
|
fi
|
||||||
|
:
|
||||||
|
|
||||||
%files
|
%files
|
||||||
%license LICENSE
|
%license LICENSE
|
||||||
%doc README.md
|
%doc README.md
|
||||||
%{_bindir}/cortex
|
%{_bindir}/cortex
|
||||||
|
%{_unitdir}/cortex.service
|
||||||
|
%{_sysusersdir}/cortex.conf
|
||||||
|
%{_prefix}/lib/firewalld/services/cortex.xml
|
||||||
|
%dir %{_sysconfdir}/cortex
|
||||||
|
%config(noreplace) %{_sysconfdir}/cortex/cortex.toml
|
||||||
|
%config(noreplace) %{_sysconfdir}/cortex/models.toml
|
||||||
|
|
||||||
%changelog
|
%changelog
|
||||||
* Mon Apr 14 2026 Rob Thijssen <grenade@rob.tn> - 0.1.0-1
|
* Thu Apr 16 2026 Gitea Actions <actions@git.lair.cafe> - 0.1.16-1
|
||||||
|
- chore: ignore local deploy script
|
||||||
|
- chore: move default ports out of common-collision ranges
|
||||||
|
- ci: drop actions/cache for cargo registry and target
|
||||||
|
|
||||||
|
* Thu Apr 16 2026 Gitea Actions <actions@git.lair.cafe> - 0.1.14-1
|
||||||
|
- ci: publish both packages to a single helexa/helexa COPR project
|
||||||
|
- fix(rpm): rename neuron package to helexa-neuron
|
||||||
|
- ci: commit generated %changelog entries back to main
|
||||||
|
|
||||||
|
* Wed Apr 15 2026 Rob Thijssen <grenade@rob.tn> - 0.1.0-1
|
||||||
- Initial package
|
- Initial package
|
||||||
|
|||||||
@@ -1,14 +0,0 @@
|
|||||||
[package]
|
|
||||||
name = "cortex-agent"
|
|
||||||
version.workspace = true
|
|
||||||
edition.workspace = true
|
|
||||||
license.workspace = true
|
|
||||||
|
|
||||||
[dependencies]
|
|
||||||
cortex-core.workspace = true
|
|
||||||
tokio.workspace = true
|
|
||||||
serde.workspace = true
|
|
||||||
serde_json.workspace = true
|
|
||||||
reqwest.workspace = true
|
|
||||||
tracing.workspace = true
|
|
||||||
anyhow.workspace = true
|
|
||||||
@@ -1,72 +0,0 @@
|
|||||||
//! Per-node agent sidecar.
|
|
||||||
//!
|
|
||||||
//! This is a future component that runs on each GPU node alongside mistralrs.
|
|
||||||
//! It handles:
|
|
||||||
//! - VRAM defragmentation (restarting the mistralrs systemd unit when the
|
|
||||||
//! gateway signals that lifecycle_cycles has exceeded the threshold)
|
|
||||||
//! - Local nvidia-smi polling for actual VRAM usage reporting
|
|
||||||
//! - Systemd unit management for mistralrs process restarts
|
|
||||||
//!
|
|
||||||
//! For now this is a stub. The gateway's poller + evictor handle the critical
|
|
||||||
//! path (model lifecycle via the mistralrs HTTP API). The agent adds
|
|
||||||
//! operational niceties that can be built incrementally.
|
|
||||||
|
|
||||||
/// Placeholder for agent configuration.
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct AgentConfig {
|
|
||||||
/// The local mistralrs endpoint to monitor.
|
|
||||||
pub mistralrs_endpoint: String,
|
|
||||||
/// The systemd unit name for mistralrs (e.g. "mistralrs.service").
|
|
||||||
pub systemd_unit: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Restart the local mistralrs process via systemd.
|
|
||||||
/// This is the nuclear option for VRAM defragmentation.
|
|
||||||
pub async fn restart_mistralrs(config: &AgentConfig) -> anyhow::Result<()> {
|
|
||||||
tracing::warn!(
|
|
||||||
unit = %config.systemd_unit,
|
|
||||||
"restarting mistralrs for VRAM defragmentation"
|
|
||||||
);
|
|
||||||
|
|
||||||
let output = tokio::process::Command::new("systemctl")
|
|
||||||
.args(["restart", &config.systemd_unit])
|
|
||||||
.output()
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
if output.status.success() {
|
|
||||||
tracing::info!(unit = %config.systemd_unit, "mistralrs restarted successfully");
|
|
||||||
Ok(())
|
|
||||||
} else {
|
|
||||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
|
||||||
anyhow::bail!("systemctl restart failed: {stderr}");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Query nvidia-smi for current VRAM usage on this node.
|
|
||||||
/// Returns (used_mb, total_mb) for each GPU.
|
|
||||||
pub async fn query_vram() -> anyhow::Result<Vec<(u64, u64)>> {
|
|
||||||
let output = tokio::process::Command::new("nvidia-smi")
|
|
||||||
.args([
|
|
||||||
"--query-gpu=memory.used,memory.total",
|
|
||||||
"--format=csv,noheader,nounits",
|
|
||||||
])
|
|
||||||
.output()
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
if !output.status.success() {
|
|
||||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
|
||||||
anyhow::bail!("nvidia-smi failed: {stderr}");
|
|
||||||
}
|
|
||||||
|
|
||||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
|
||||||
let mut gpus = Vec::new();
|
|
||||||
for line in stdout.lines() {
|
|
||||||
let parts: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
|
|
||||||
if parts.len() == 2 {
|
|
||||||
let used: u64 = parts[0].parse().unwrap_or(0);
|
|
||||||
let total: u64 = parts[1].parse().unwrap_or(0);
|
|
||||||
gpus.push((used, total));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(gpus)
|
|
||||||
}
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
pub mod agent;
|
|
||||||
@@ -17,4 +17,4 @@ tracing-subscriber.workspace = true
|
|||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
reqwest.workspace = true
|
reqwest.workspace = true
|
||||||
serde_json.workspace = true
|
serde_json.workspace = true
|
||||||
clap = { version = "4", features = ["derive"] }
|
clap.workspace = true
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ use tracing_subscriber::EnvFilter;
|
|||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
#[command(name = "cortex")]
|
#[command(name = "cortex")]
|
||||||
#[command(about = "Unified inference gateway for multi-node mistral.rs clusters")]
|
#[command(about = "Unified inference gateway for multi-node GPU clusters")]
|
||||||
#[command(version)]
|
#[command(version)]
|
||||||
struct Cli {
|
struct Cli {
|
||||||
#[command(subcommand)]
|
#[command(subcommand)]
|
||||||
@@ -23,7 +23,7 @@ enum Commands {
|
|||||||
/// Print the fleet status (models, nodes, health).
|
/// Print the fleet status (models, nodes, health).
|
||||||
Status {
|
Status {
|
||||||
/// Gateway API endpoint to query.
|
/// Gateway API endpoint to query.
|
||||||
#[arg(short, long, default_value = "http://localhost:8000")]
|
#[arg(short, long, default_value = "http://localhost:31313")]
|
||||||
endpoint: String,
|
endpoint: String,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -46,7 +46,7 @@ async fn main() -> Result<()> {
|
|||||||
.map_err(|e| anyhow::anyhow!("failed to load config from '{config}': {e}"))?;
|
.map_err(|e| anyhow::anyhow!("failed to load config from '{config}': {e}"))?;
|
||||||
|
|
||||||
tracing::info!(
|
tracing::info!(
|
||||||
nodes = cfg.nodes.len(),
|
neurons = cfg.neurons.len(),
|
||||||
listen = %cfg.gateway.listen,
|
listen = %cfg.gateway.listen,
|
||||||
"starting cortex"
|
"starting cortex"
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -13,3 +13,4 @@ chrono.workspace = true
|
|||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
thiserror.workspace = true
|
thiserror.workspace = true
|
||||||
tracing.workspace = true
|
tracing.workspace = true
|
||||||
|
async-trait.workspace = true
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
//!
|
//!
|
||||||
//! These mirror the `/v1/messages` format used by the Anthropic API.
|
//! These mirror the `/v1/messages` format used by the Anthropic API.
|
||||||
//! The gateway accepts these, translates to OpenAI format, proxies to
|
//! The gateway accepts these, translates to OpenAI format, proxies to
|
||||||
//! mistral.rs, then translates the response back.
|
//! the inference backend (neuron), then translates the response back.
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
|||||||
211
crates/cortex-core/src/catalogue.rs
Normal file
211
crates/cortex-core/src/catalogue.rs
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
//! Model catalogue — profiles describing how to serve each model.
|
||||||
|
|
||||||
|
use crate::discovery::DeviceInfo;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
/// A model serving profile loaded from models.toml.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ModelProfile {
|
||||||
|
pub id: String,
|
||||||
|
pub harness: String,
|
||||||
|
#[serde(default)]
|
||||||
|
pub quant: Option<String>,
|
||||||
|
/// Estimated VRAM usage in MB when loaded.
|
||||||
|
#[serde(default)]
|
||||||
|
pub vram_mb: Option<u64>,
|
||||||
|
/// Minimum number of GPU devices required.
|
||||||
|
#[serde(default = "default_min_devices")]
|
||||||
|
pub min_devices: u32,
|
||||||
|
/// Minimum VRAM per device in MB.
|
||||||
|
#[serde(default)]
|
||||||
|
pub min_device_vram_mb: Option<u64>,
|
||||||
|
/// Neurons where this model should never be evicted.
|
||||||
|
#[serde(default)]
|
||||||
|
pub pinned_on: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_min_devices() -> u32 {
|
||||||
|
1
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The full model catalogue.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||||
|
pub struct ModelCatalogue {
|
||||||
|
#[serde(default)]
|
||||||
|
pub models: Vec<ModelProfile>,
|
||||||
|
/// Tier aliases — clients can send a request with `model: "helexa/small"`
|
||||||
|
/// and the gateway transparently rewrites + routes to the concrete
|
||||||
|
/// model id this maps to. Lets operators define latency/quality
|
||||||
|
/// tiers (`small`/`balanced`/`large`, `fast`/`thinking`, etc.)
|
||||||
|
/// without imposing knowledge of specific model ids on clients.
|
||||||
|
/// Loaded from the `[aliases]` table in models.toml.
|
||||||
|
#[serde(default)]
|
||||||
|
pub aliases: HashMap<String, String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ModelCatalogue {
|
||||||
|
/// Load the catalogue from a TOML file. Returns empty catalogue if file doesn't exist.
|
||||||
|
pub fn load(path: impl AsRef<Path>) -> Self {
|
||||||
|
let path = path.as_ref();
|
||||||
|
if !path.exists() {
|
||||||
|
tracing::info!(path = %path.display(), "no model catalogue found, using empty");
|
||||||
|
return Self::default();
|
||||||
|
}
|
||||||
|
match std::fs::read_to_string(path) {
|
||||||
|
Ok(contents) => match toml::from_str(&contents) {
|
||||||
|
Ok(cat) => cat,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(path = %path.display(), error = %e, "failed to parse model catalogue");
|
||||||
|
Self::default()
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(path = %path.display(), error = %e, "failed to read model catalogue");
|
||||||
|
Self::default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if a model is pinned on a given neuron.
|
||||||
|
pub fn is_pinned(&self, model_id: &str, neuron_name: &str) -> bool {
|
||||||
|
self.models
|
||||||
|
.iter()
|
||||||
|
.any(|p| p.id == model_id && p.pinned_on.contains(&neuron_name.to_string()))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Find a profile by model id.
|
||||||
|
pub fn get(&self, model_id: &str) -> Option<&ModelProfile> {
|
||||||
|
self.models.iter().find(|p| p.id == model_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Resolve an alias to its concrete model id. Returns `id` verbatim
|
||||||
|
/// when it isn't an alias. Aliases never chain — operator config
|
||||||
|
/// is treated as flat — so this is a single lookup.
|
||||||
|
pub fn resolve_alias<'a>(&'a self, id: &'a str) -> &'a str {
|
||||||
|
self.aliases.get(id).map(String::as_str).unwrap_or(id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ModelProfile {
|
||||||
|
/// True iff this profile's placement constraints can be satisfied
|
||||||
|
/// by the named neuron with the given device topology.
|
||||||
|
///
|
||||||
|
/// Constraints checked:
|
||||||
|
/// - `pinned_on`: non-empty → neuron must be on the list.
|
||||||
|
/// - `min_devices`: neuron must have at least this many devices.
|
||||||
|
/// - `min_device_vram_mb`: at least `min_devices` of the neuron's
|
||||||
|
/// devices must each meet this VRAM floor.
|
||||||
|
pub fn is_feasible_on(&self, neuron_name: &str, devices: &[DeviceInfo]) -> bool {
|
||||||
|
if !self.pinned_on.is_empty() && !self.pinned_on.iter().any(|n| n == neuron_name) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (devices.len() as u32) < self.min_devices {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if let Some(min_vram) = self.min_device_vram_mb {
|
||||||
|
let big_enough = devices
|
||||||
|
.iter()
|
||||||
|
.filter(|d| d.vram_total_mb >= min_vram)
|
||||||
|
.count() as u32;
|
||||||
|
if big_enough < self.min_devices {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::discovery::DeviceInfo;
|
||||||
|
|
||||||
|
fn device(idx: u32, vram_mb: u64) -> DeviceInfo {
|
||||||
|
DeviceInfo {
|
||||||
|
index: idx,
|
||||||
|
name: format!("DEV-{idx}"),
|
||||||
|
vram_total_mb: vram_mb,
|
||||||
|
compute_capability: "8.6".into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn profile() -> ModelProfile {
|
||||||
|
ModelProfile {
|
||||||
|
id: "Qwen/Qwen3.6-27B".into(),
|
||||||
|
harness: "candle".into(),
|
||||||
|
quant: None,
|
||||||
|
vram_mb: Some(45_000),
|
||||||
|
min_devices: 2,
|
||||||
|
min_device_vram_mb: Some(24_000),
|
||||||
|
pinned_on: vec![],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn feasible_when_two_devices_meet_vram_floor() {
|
||||||
|
let p = profile();
|
||||||
|
let devices = [device(0, 32_000), device(1, 32_000)];
|
||||||
|
assert!(p.is_feasible_on("beast", &devices));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn infeasible_when_only_one_device() {
|
||||||
|
let p = profile();
|
||||||
|
let devices = [device(0, 64_000)];
|
||||||
|
assert!(!p.is_feasible_on("benjy", &devices));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn infeasible_when_one_device_underspec() {
|
||||||
|
let p = profile();
|
||||||
|
let devices = [device(0, 32_000), device(1, 12_000)];
|
||||||
|
assert!(!p.is_feasible_on("mixed", &devices));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn pinned_on_excludes_other_neurons() {
|
||||||
|
let mut p = profile();
|
||||||
|
p.pinned_on = vec!["beast".into()];
|
||||||
|
let devices = [device(0, 32_000), device(1, 32_000)];
|
||||||
|
assert!(p.is_feasible_on("beast", &devices));
|
||||||
|
assert!(!p.is_feasible_on("benjy", &devices));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn no_vram_floor_just_needs_min_devices() {
|
||||||
|
let mut p = profile();
|
||||||
|
p.min_device_vram_mb = None;
|
||||||
|
let devices = [device(0, 1_000), device(1, 1_000)];
|
||||||
|
assert!(p.is_feasible_on("anywhere", &devices));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn resolve_alias_returns_target_when_alias_present() {
|
||||||
|
let mut cat = ModelCatalogue::default();
|
||||||
|
cat.aliases
|
||||||
|
.insert("helexa/small".into(), "Qwen/Qwen3-1.7B".into());
|
||||||
|
assert_eq!(cat.resolve_alias("helexa/small"), "Qwen/Qwen3-1.7B");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn resolve_alias_passes_through_when_not_an_alias() {
|
||||||
|
let mut cat = ModelCatalogue::default();
|
||||||
|
cat.aliases
|
||||||
|
.insert("helexa/small".into(), "Qwen/Qwen3-1.7B".into());
|
||||||
|
assert_eq!(cat.resolve_alias("Qwen/Qwen3-8B"), "Qwen/Qwen3-8B");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn aliases_table_round_trips_through_toml() {
|
||||||
|
let src = r#"
|
||||||
|
[aliases]
|
||||||
|
"helexa/small" = "Qwen/Qwen3-1.7B"
|
||||||
|
"helexa/large" = "Qwen/Qwen3.6-27B"
|
||||||
|
"#;
|
||||||
|
let cat: ModelCatalogue = toml::from_str(src).expect("parse aliases table");
|
||||||
|
assert_eq!(cat.resolve_alias("helexa/small"), "Qwen/Qwen3-1.7B");
|
||||||
|
assert_eq!(cat.resolve_alias("helexa/large"), "Qwen/Qwen3.6-27B");
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -9,14 +9,22 @@ use std::path::Path;
|
|||||||
pub struct GatewayConfig {
|
pub struct GatewayConfig {
|
||||||
pub gateway: GatewaySettings,
|
pub gateway: GatewaySettings,
|
||||||
pub eviction: EvictionSettings,
|
pub eviction: EvictionSettings,
|
||||||
pub nodes: Vec<NodeConfig>,
|
/// Neuron endpoints (replaces old NodeConfig with static vram_mb/pinned).
|
||||||
|
pub neurons: Vec<NeuronEndpoint>,
|
||||||
|
/// Path to the model catalogue file (default: "models.toml").
|
||||||
|
#[serde(default = "default_models_path")]
|
||||||
|
pub models_config: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_models_path() -> String {
|
||||||
|
"models.toml".into()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct GatewaySettings {
|
pub struct GatewaySettings {
|
||||||
/// Address to listen on for API requests (e.g. "0.0.0.0:8000")
|
/// Address to listen on for API requests (e.g. "0.0.0.0:31313")
|
||||||
pub listen: String,
|
pub listen: String,
|
||||||
/// Address to listen on for Prometheus metrics (e.g. "0.0.0.0:9100")
|
/// Address to listen on for Prometheus metrics (e.g. "0.0.0.0:31314")
|
||||||
pub metrics_listen: String,
|
pub metrics_listen: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -24,8 +32,7 @@ pub struct GatewaySettings {
|
|||||||
pub struct EvictionSettings {
|
pub struct EvictionSettings {
|
||||||
/// Eviction strategy: "lru" or "priority"
|
/// Eviction strategy: "lru" or "priority"
|
||||||
pub strategy: EvictionStrategy,
|
pub strategy: EvictionStrategy,
|
||||||
/// Restart the mistralrs process after this many load/unload cycles
|
/// Number of load/unload cycles before flagging for defrag. 0 = never.
|
||||||
/// to reclaim fragmented VRAM. 0 = never.
|
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub defrag_after_cycles: u32,
|
pub defrag_after_cycles: u32,
|
||||||
}
|
}
|
||||||
@@ -37,23 +44,19 @@ pub enum EvictionStrategy {
|
|||||||
Priority,
|
Priority,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// A neuron endpoint in the fleet. Hardware details come from
|
||||||
|
/// neuron's /discovery endpoint, not from config.
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct NodeConfig {
|
pub struct NeuronEndpoint {
|
||||||
/// Human-readable node name (e.g. "gpu-large")
|
/// Human-readable node name (e.g. "beast")
|
||||||
pub name: String,
|
pub name: String,
|
||||||
/// Base URL of the mistralrs HTTP server (e.g. "http://gpu-large.internal:8080")
|
/// Base URL of the neuron daemon (e.g. "http://beast.internal:13131")
|
||||||
pub endpoint: String,
|
pub endpoint: String,
|
||||||
/// Total VRAM in MB across all GPUs on this node
|
|
||||||
pub vram_mb: u64,
|
|
||||||
/// Model IDs that should never be evicted from this node
|
|
||||||
#[serde(default)]
|
|
||||||
pub pinned: Vec<String>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl GatewayConfig {
|
impl GatewayConfig {
|
||||||
/// Load configuration from a TOML file, with environment variable overrides.
|
/// Load configuration from a TOML file, with environment variable overrides.
|
||||||
/// Env vars are prefixed with `CORTEX_` and use `__` as a separator
|
/// Env vars are prefixed with `CORTEX_` and use `__` as a separator.
|
||||||
/// (e.g. `CORTEX_GATEWAY__LISTEN=0.0.0.0:9000`).
|
|
||||||
pub fn load(path: impl AsRef<Path>) -> Result<Self, Box<figment::Error>> {
|
pub fn load(path: impl AsRef<Path>) -> Result<Self, Box<figment::Error>> {
|
||||||
Figment::new()
|
Figment::new()
|
||||||
.merge(Toml::file(path))
|
.merge(Toml::file(path))
|
||||||
@@ -67,14 +70,15 @@ impl Default for GatewayConfig {
|
|||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
gateway: GatewaySettings {
|
gateway: GatewaySettings {
|
||||||
listen: "0.0.0.0:8000".into(),
|
listen: "0.0.0.0:31313".into(),
|
||||||
metrics_listen: "0.0.0.0:9100".into(),
|
metrics_listen: "0.0.0.0:31314".into(),
|
||||||
},
|
},
|
||||||
eviction: EvictionSettings {
|
eviction: EvictionSettings {
|
||||||
strategy: EvictionStrategy::Lru,
|
strategy: EvictionStrategy::Lru,
|
||||||
defrag_after_cycles: 50,
|
defrag_after_cycles: 50,
|
||||||
},
|
},
|
||||||
nodes: vec![],
|
neurons: vec![],
|
||||||
|
models_config: default_models_path(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
107
crates/cortex-core/src/discovery.rs
Normal file
107
crates/cortex-core/src/discovery.rs
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
//! Hardware discovery and health types shared between cortex and neuron.
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
/// Information about a single GPU device discovered on a node.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct DeviceInfo {
|
||||||
|
pub index: u32,
|
||||||
|
pub name: String,
|
||||||
|
pub vram_total_mb: u64,
|
||||||
|
pub compute_capability: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Full discovery response from a neuron endpoint.
|
||||||
|
/// Returned by `GET /discovery`.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct DiscoveryResponse {
|
||||||
|
pub hostname: String,
|
||||||
|
pub os: String,
|
||||||
|
pub kernel: String,
|
||||||
|
pub cuda_version: Option<String>,
|
||||||
|
pub driver_version: Option<String>,
|
||||||
|
pub devices: Vec<DeviceInfo>,
|
||||||
|
pub harnesses: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Runtime health metrics for a single GPU device.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct DeviceHealth {
|
||||||
|
pub index: u32,
|
||||||
|
pub vram_used_mb: u64,
|
||||||
|
pub vram_free_mb: u64,
|
||||||
|
pub utilization_pct: u32,
|
||||||
|
pub temp_c: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Runtime health response from a neuron endpoint.
|
||||||
|
/// Returned by `GET /health`.
|
||||||
|
///
|
||||||
|
/// `activation` was added in 2026-05-26 to distinguish "process is up
|
||||||
|
/// and reachable" from "process is ready to serve traffic". A `Type=simple`
|
||||||
|
/// systemd unit reports `active` the moment the binary starts — but a
|
||||||
|
/// neuron whose `default_models` list takes minutes to materialise
|
||||||
|
/// won't bind its listener (or, in the new flow, won't have any models
|
||||||
|
/// loaded) until pre-warm completes. The new field is `#[serde(default)]`
|
||||||
|
/// so a pre-2026-05-26 gateway polling a new neuron — or vice versa —
|
||||||
|
/// keeps working.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct HealthResponse {
|
||||||
|
pub uptime_secs: u64,
|
||||||
|
pub devices: Vec<DeviceHealth>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub activation: ActivationStatus,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// High-level activation state of the neuron daemon. The HTTP listener
|
||||||
|
/// is bound during both states; what differs is whether the configured
|
||||||
|
/// `default_models` have finished loading.
|
||||||
|
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, PartialEq, Eq)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum ActivationState {
|
||||||
|
/// At least one `default_models` entry is still loading. The
|
||||||
|
/// neuron's other endpoints work, but inference against
|
||||||
|
/// not-yet-loaded models will 404.
|
||||||
|
PreWarming,
|
||||||
|
/// Every `default_models` entry has either loaded or failed; the
|
||||||
|
/// neuron is steady-state. Subsequent on-demand loads via
|
||||||
|
/// `/models/load` don't flip back to PreWarming — that field
|
||||||
|
/// reflects the activation-time set only.
|
||||||
|
#[default]
|
||||||
|
Ready,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Per-model failure record surfaced in [`ActivationStatus::failed`].
|
||||||
|
/// The error string is the rendered anyhow chain at the time of the
|
||||||
|
/// failure; operators read it from `/health` to decide whether to
|
||||||
|
/// retry, edit the spec, or unload+reload.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct PreWarmFailure {
|
||||||
|
pub model_id: String,
|
||||||
|
pub error: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Activation-time progress snapshot. All four lists are populated by
|
||||||
|
/// the neuron's pre-warm task and read by the `/health` handler. The
|
||||||
|
/// snapshot is consistent: a model id appears in exactly one of
|
||||||
|
/// `pending`, `in_progress` (as `Option<String>`), `completed`, or
|
||||||
|
/// `failed` at any point in time.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||||
|
pub struct ActivationStatus {
|
||||||
|
pub state: ActivationState,
|
||||||
|
/// Model ids queued but not yet started. Empty in `Ready` state.
|
||||||
|
#[serde(default)]
|
||||||
|
pub pending: Vec<String>,
|
||||||
|
/// Model id currently materialising. None when between models or
|
||||||
|
/// in `Ready` state.
|
||||||
|
#[serde(default)]
|
||||||
|
pub in_progress: Option<String>,
|
||||||
|
/// Model ids that finished loading successfully during this
|
||||||
|
/// activation. Cleared on process restart.
|
||||||
|
#[serde(default)]
|
||||||
|
pub completed: Vec<String>,
|
||||||
|
/// Model ids that failed during this activation, with the rendered
|
||||||
|
/// error chain. Cleared on process restart.
|
||||||
|
#[serde(default)]
|
||||||
|
pub failed: Vec<PreWarmFailure>,
|
||||||
|
}
|
||||||
84
crates/cortex-core/src/harness.rs
Normal file
84
crates/cortex-core/src/harness.rs
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
//! Harness trait and supporting types for inference engine management.
|
||||||
|
//!
|
||||||
|
//! Defined in cortex-core so both cortex (control plane) and neuron
|
||||||
|
//! (node plane) share the type definitions. neuron provides the
|
||||||
|
//! runtime implementations.
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
/// Configuration for a harness instance on a neuron.
|
||||||
|
///
|
||||||
|
/// All current harnesses are in-process (candle); per-harness tuning
|
||||||
|
/// (cache paths, device policies, etc.) lives in dedicated config
|
||||||
|
/// blocks rather than on this struct.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct HarnessConfig {
|
||||||
|
pub name: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Health status of a harness process.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct HarnessHealth {
|
||||||
|
pub name: String,
|
||||||
|
pub running: bool,
|
||||||
|
pub uptime_secs: Option<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Specification for loading a model through a harness.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ModelSpec {
|
||||||
|
pub model_id: String,
|
||||||
|
pub harness: String,
|
||||||
|
pub quant: Option<String>,
|
||||||
|
pub tensor_parallel: Option<u32>,
|
||||||
|
pub devices: Option<Vec<u32>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A model as reported by a harness.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ModelInfo {
|
||||||
|
pub id: String,
|
||||||
|
pub harness: String,
|
||||||
|
pub status: String,
|
||||||
|
pub devices: Vec<u32>,
|
||||||
|
pub vram_used_mb: Option<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// What an inference harness must do, from neuron's perspective.
|
||||||
|
///
|
||||||
|
/// All current harnesses are in-process — they share neuron's address
|
||||||
|
/// space and lifecycle. `start`/`stop` therefore default to no-ops; a
|
||||||
|
/// future process-supervising harness would override them.
|
||||||
|
#[async_trait]
|
||||||
|
pub trait Harness: Send + Sync {
|
||||||
|
/// Human-readable name (e.g. "candle").
|
||||||
|
fn name(&self) -> &str;
|
||||||
|
|
||||||
|
/// Start the harness. Default no-op for in-process harnesses.
|
||||||
|
async fn start(&self, _config: &HarnessConfig) -> Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Stop the harness. Default no-op for in-process harnesses.
|
||||||
|
async fn stop(&self) -> Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Health check. Returns the harness process status.
|
||||||
|
async fn health(&self) -> HarnessHealth;
|
||||||
|
|
||||||
|
/// List models the harness knows about (loaded + unloaded).
|
||||||
|
async fn list_models(&self) -> Result<Vec<ModelInfo>>;
|
||||||
|
|
||||||
|
/// Load a model with the given spec (quant, TP, device assignment).
|
||||||
|
async fn load_model(&self, spec: &ModelSpec) -> Result<()>;
|
||||||
|
|
||||||
|
/// Unload a model, freeing device memory.
|
||||||
|
async fn unload_model(&self, model_id: &str) -> Result<()>;
|
||||||
|
|
||||||
|
/// Return the URL where inference requests for this model should
|
||||||
|
/// be sent. None if the model is not loaded.
|
||||||
|
async fn inference_endpoint(&self, model_id: &str) -> Option<String>;
|
||||||
|
}
|
||||||
@@ -1,6 +1,10 @@
|
|||||||
pub mod anthropic;
|
pub mod anthropic;
|
||||||
|
pub mod catalogue;
|
||||||
pub mod config;
|
pub mod config;
|
||||||
|
pub mod discovery;
|
||||||
|
pub mod harness;
|
||||||
pub mod metrics;
|
pub mod metrics;
|
||||||
pub mod node;
|
pub mod node;
|
||||||
pub mod openai;
|
pub mod openai;
|
||||||
|
pub mod responses;
|
||||||
pub mod translate;
|
pub mod translate;
|
||||||
|
|||||||
@@ -1,19 +1,31 @@
|
|||||||
|
use crate::discovery::{ActivationStatus, DiscoveryResponse};
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
/// Runtime state of a single node in the fleet.
|
/// Runtime state of a single neuron in the fleet.
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct NodeState {
|
pub struct NodeState {
|
||||||
pub name: String,
|
pub name: String,
|
||||||
|
/// Base URL of the neuron daemon (e.g. "http://beast.internal:13131").
|
||||||
pub endpoint: String,
|
pub endpoint: String,
|
||||||
pub vram_mb: u64,
|
|
||||||
pub pinned: Vec<String>,
|
|
||||||
pub healthy: bool,
|
pub healthy: bool,
|
||||||
pub models: HashMap<String, ModelEntry>,
|
pub models: HashMap<String, ModelEntry>,
|
||||||
/// Number of load/unload cycles since last process restart.
|
/// Number of load/unload cycles since last process restart.
|
||||||
pub lifecycle_cycles: u32,
|
pub lifecycle_cycles: u32,
|
||||||
pub last_poll: Option<DateTime<Utc>>,
|
pub last_poll: Option<DateTime<Utc>>,
|
||||||
|
/// Result of the most recent successful `GET /discovery` against
|
||||||
|
/// this neuron. Cached forever once obtained — device topology is
|
||||||
|
/// invariant for a given neuron process. `None` until the first
|
||||||
|
/// successful poll. Used by the router and `/v1/models` to do
|
||||||
|
/// catalogue × topology feasibility checks.
|
||||||
|
pub discovery: Option<DiscoveryResponse>,
|
||||||
|
/// Last-seen pre-warm progress from this neuron's `/health`
|
||||||
|
/// endpoint. `None` until the first /health poll succeeds. The
|
||||||
|
/// `/v1/models` handler reads `in_progress` + `pending` from here
|
||||||
|
/// to synthesize `Loading` locations so clients see a catalogued
|
||||||
|
/// model that's mid-prewarm as "loading", not "missing".
|
||||||
|
pub activation: Option<ActivationStatus>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A model registered on a node, with its runtime status.
|
/// A model registered on a node, with its runtime status.
|
||||||
@@ -27,22 +39,51 @@ pub struct ModelEntry {
|
|||||||
pub vram_estimate_mb: Option<u64>,
|
pub vram_estimate_mb: Option<u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Model lifecycle status, matching the mistral.rs API.
|
/// Model lifecycle status.
|
||||||
|
///
|
||||||
|
/// `Loading` is a gateway-side synthetic status: neurons never emit it
|
||||||
|
/// on `/models` (that endpoint only knows about already-loaded handles).
|
||||||
|
/// The gateway populates it from a neuron's `/health` activation
|
||||||
|
/// snapshot so the unified `/v1/models` can distinguish "model is
|
||||||
|
/// catalogued but no one has it" from "model is materialising on
|
||||||
|
/// neuron N right now". Other status values are reported verbatim by
|
||||||
|
/// neurons.
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "lowercase")]
|
#[serde(rename_all = "lowercase")]
|
||||||
pub enum ModelStatus {
|
pub enum ModelStatus {
|
||||||
Loaded,
|
Loaded,
|
||||||
Unloaded,
|
Unloaded,
|
||||||
Reloading,
|
Reloading,
|
||||||
|
Loading,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Unified model entry as exposed by the gateway's `/v1/models` endpoint.
|
/// Unified model entry as exposed by the gateway's `/v1/models` endpoint.
|
||||||
/// Includes which node(s) host this model and their status.
|
///
|
||||||
|
/// The first four fields (`id`, `object`, `created`, `owned_by`) match
|
||||||
|
/// OpenAI's `/v1/models` shape verbatim, so existing OpenAI-aware
|
||||||
|
/// tooling deserialises this without custom code. The remaining fields
|
||||||
|
/// are helexa-specific extensions — OpenAI clients ignore unknown
|
||||||
|
/// fields and other consumers can read them for placement / debugging.
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct CortexModelEntry {
|
pub struct CortexModelEntry {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
|
/// Always `"model"` per OpenAI's contract.
|
||||||
pub object: String,
|
pub object: String,
|
||||||
/// Which nodes have this model (and their status).
|
/// Unix-second timestamp; cortex stamps this at response time.
|
||||||
|
pub created: u64,
|
||||||
|
/// OpenAI's "publisher" field — `"helexa"` for everything we serve.
|
||||||
|
pub owned_by: String,
|
||||||
|
/// True if any neuron currently has this model loaded. False for
|
||||||
|
/// catalogue entries that are feasible but not yet loaded.
|
||||||
|
pub loaded: bool,
|
||||||
|
/// Neurons whose discovered topology can satisfy this model's
|
||||||
|
/// catalogue placement constraints. Empty for models that are
|
||||||
|
/// loaded somewhere but not present in the catalogue (cortex has
|
||||||
|
/// no feasibility opinion on those).
|
||||||
|
pub feasible_on: Vec<String>,
|
||||||
|
/// Where this model is actually loaded right now. Subset of (or
|
||||||
|
/// disjoint from) `feasible_on` depending on whether the catalogue
|
||||||
|
/// covers this model.
|
||||||
pub locations: Vec<ModelLocation>,
|
pub locations: Vec<ModelLocation>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -52,23 +93,3 @@ pub struct ModelLocation {
|
|||||||
pub status: ModelStatus,
|
pub status: ModelStatus,
|
||||||
pub vram_estimate_mb: Option<u64>,
|
pub vram_estimate_mb: Option<u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Response from mistral.rs `GET /v1/models`.
|
|
||||||
/// This is the upstream format we parse when polling nodes.
|
|
||||||
#[derive(Debug, Clone, Deserialize)]
|
|
||||||
pub struct MistralModelsResponse {
|
|
||||||
pub data: Vec<MistralModelEntry>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize)]
|
|
||||||
pub struct MistralModelEntry {
|
|
||||||
pub id: String,
|
|
||||||
#[serde(default)]
|
|
||||||
pub status: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Request body for mistral.rs model lifecycle endpoints.
|
|
||||||
#[derive(Debug, Clone, Serialize)]
|
|
||||||
pub struct ModelLifecycleRequest {
|
|
||||||
pub model_id: String,
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
//! These are a subset sufficient for chat completions (streaming + non-streaming).
|
//! These are a subset sufficient for chat completions (streaming + non-streaming).
|
||||||
//! Fields not relevant to proxying are captured as `serde_json::Value` via
|
//! Fields not relevant to proxying are captured as `serde_json::Value` via
|
||||||
//! `#[serde(flatten)]` so we forward them without needing to enumerate every
|
//! `#[serde(flatten)]` so we forward them without needing to enumerate every
|
||||||
//! extension field mistral.rs supports.
|
//! extension field a backend might support.
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
@@ -22,7 +22,7 @@ pub struct ChatCompletionRequest {
|
|||||||
pub max_tokens: Option<u64>,
|
pub max_tokens: Option<u64>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub stream: Option<bool>,
|
pub stream: Option<bool>,
|
||||||
/// All other fields (tools, response_format, mistral.rs extensions, etc.)
|
/// All other fields (tools, response_format, backend extensions, etc.)
|
||||||
#[serde(flatten)]
|
#[serde(flatten)]
|
||||||
pub extra: Value,
|
pub extra: Value,
|
||||||
}
|
}
|
||||||
|
|||||||
346
crates/cortex-core/src/responses.rs
Normal file
346
crates/cortex-core/src/responses.rs
Normal file
@@ -0,0 +1,346 @@
|
|||||||
|
//! OpenAI Responses API (`POST /v1/responses`) envelope types.
|
||||||
|
//!
|
||||||
|
//! This is OpenAI's newer chat surface, distinct from
|
||||||
|
//! `/v1/chat/completions` in three ways that matter for us:
|
||||||
|
//!
|
||||||
|
//! 1. **Input shape**. Instead of a `messages` array, the request
|
||||||
|
//! carries `input` — either a plain string (single user turn)
|
||||||
|
//! or an array of typed items (messages, function calls,
|
||||||
|
//! function-call outputs, reasoning blocks, …).
|
||||||
|
//! 2. **Output shape**. The response carries a single `output`
|
||||||
|
//! array of items, each typed. We always emit one
|
||||||
|
//! `OutputItem::Message` containing the assistant's reply (plus,
|
||||||
|
//! when we get there, separate `function_call` items).
|
||||||
|
//! 3. **Streaming events**. Where chat completions stream
|
||||||
|
//! structurally-identical `chat.completion.chunk` frames over
|
||||||
|
//! `data:` lines, Responses streams *named* events
|
||||||
|
//! (`response.created`, `response.output_text.delta`,
|
||||||
|
//! `response.completed`, …) over `event:` + `data:` SSE pairs.
|
||||||
|
//! The wire projector in `neuron::wire::openai_responses` builds
|
||||||
|
//! these from the same [`crate::openai`]-shaped
|
||||||
|
//! `InferenceEvent` stream the chat projector consumes.
|
||||||
|
//!
|
||||||
|
//! Scope cuts for this first cut:
|
||||||
|
//!
|
||||||
|
//! - **`previous_response_id` is rejected at parse time**. Stateful
|
||||||
|
//! chained conversations need a persistence layer we don't have.
|
||||||
|
//! - **Reasoning items are accepted-and-ignored** (no Qwen3
|
||||||
|
//! `<think>` routing yet). Audio and embedded resources are
|
||||||
|
//! rejected as unsupported.
|
||||||
|
//! - **Tool calls** (function_call / function_call_output) are
|
||||||
|
//! carried as round-trip types but the candle harness doesn't
|
||||||
|
//! emit them yet — wired so the surface is in place for the
|
||||||
|
//! day we add proper tool-call extraction.
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
// ── Request ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Body of a `POST /v1/responses` request.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ResponsesRequest {
|
||||||
|
pub model: String,
|
||||||
|
pub input: ResponsesInput,
|
||||||
|
/// System-prompt-style instructions. The Responses API
|
||||||
|
/// separates these from input so a caller doesn't have to
|
||||||
|
/// build a `system` message item by hand.
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub instructions: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub stream: bool,
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub max_output_tokens: Option<u64>,
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub temperature: Option<f64>,
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_p: Option<f64>,
|
||||||
|
/// Chained-conversation identifier. We don't store responses
|
||||||
|
/// server-side yet; if this is `Some`, the handler returns 400.
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub previous_response_id: Option<String>,
|
||||||
|
/// Catch-all for anything we don't model yet (tools, tool_choice,
|
||||||
|
/// reasoning, response_format, …). Lets a client send a
|
||||||
|
/// forward-compatible request without our parser rejecting it.
|
||||||
|
#[serde(flatten)]
|
||||||
|
pub extra: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `input` is either a single string or an array of typed items.
|
||||||
|
/// `#[serde(untagged)]` so the wire shape `"input": "hi"` and
|
||||||
|
/// `"input": [{...}]` both deserialize.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum ResponsesInput {
|
||||||
|
Text(String),
|
||||||
|
Items(Vec<ResponsesInputItem>),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
|
pub enum ResponsesInputItem {
|
||||||
|
/// A user / assistant / system turn.
|
||||||
|
Message {
|
||||||
|
role: String,
|
||||||
|
content: ResponsesMessageContent,
|
||||||
|
},
|
||||||
|
/// Assistant emitted a tool call. Round-trip only — neuron
|
||||||
|
/// doesn't synthesise these yet.
|
||||||
|
FunctionCall {
|
||||||
|
call_id: String,
|
||||||
|
name: String,
|
||||||
|
arguments: String,
|
||||||
|
},
|
||||||
|
/// User is feeding a tool result back into the model.
|
||||||
|
FunctionCallOutput { call_id: String, output: String },
|
||||||
|
/// Reasoning items emitted by o-series models. Accepted but
|
||||||
|
/// not forwarded to the model — neuron's candle path doesn't
|
||||||
|
/// surface reasoning separately yet.
|
||||||
|
Reasoning {
|
||||||
|
#[serde(default)]
|
||||||
|
content: Vec<Value>,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Inside a `Message` item, content is either a plain string or an
|
||||||
|
/// array of typed parts. Mirrors the chat-completions Parts shape.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum ResponsesMessageContent {
|
||||||
|
Text(String),
|
||||||
|
Parts(Vec<ResponsesContentPart>),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
|
pub enum ResponsesContentPart {
|
||||||
|
/// Plain text inside a user / system turn.
|
||||||
|
InputText { text: String },
|
||||||
|
/// An image. `image_url` is either a remote URL or a
|
||||||
|
/// `data:image/png;base64,…` URI; the request translator just
|
||||||
|
/// forwards the string.
|
||||||
|
InputImage {
|
||||||
|
image_url: String,
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
detail: Option<String>,
|
||||||
|
},
|
||||||
|
/// Returned text inside an assistant turn — only relevant when
|
||||||
|
/// the caller is feeding an assistant turn back in to continue
|
||||||
|
/// a conversation manually (no `previous_response_id`).
|
||||||
|
OutputText {
|
||||||
|
text: String,
|
||||||
|
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||||
|
annotations: Vec<Value>,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Response (non-streaming) ─────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Body of a `POST /v1/responses` response.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ResponsesResponse {
|
||||||
|
pub id: String,
|
||||||
|
/// Always `"response"`.
|
||||||
|
pub object: String,
|
||||||
|
pub created_at: u64,
|
||||||
|
/// `"completed"`, `"incomplete"`, or — for the initial event of
|
||||||
|
/// a streaming response — `"in_progress"`.
|
||||||
|
pub status: String,
|
||||||
|
pub model: String,
|
||||||
|
pub output: Vec<ResponsesOutputItem>,
|
||||||
|
/// Populated on completion; `None` while streaming.
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub usage: Option<ResponsesUsage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
|
pub enum ResponsesOutputItem {
|
||||||
|
Message {
|
||||||
|
id: String,
|
||||||
|
/// Always `"assistant"` for model output.
|
||||||
|
role: String,
|
||||||
|
/// Output content parts. We always emit a single
|
||||||
|
/// `OutputText` today; multi-part output would land here
|
||||||
|
/// once we have e.g. image generation.
|
||||||
|
content: Vec<ResponsesOutputContent>,
|
||||||
|
/// Item-level status. `"in_progress"` while streaming the
|
||||||
|
/// content parts, `"completed"` when done.
|
||||||
|
#[serde(default = "default_item_status")]
|
||||||
|
status: String,
|
||||||
|
},
|
||||||
|
/// Reserved for the day tool-call extraction lands. The wire
|
||||||
|
/// shape mirrors `ResponsesInputItem::FunctionCall`.
|
||||||
|
FunctionCall {
|
||||||
|
id: String,
|
||||||
|
call_id: String,
|
||||||
|
name: String,
|
||||||
|
arguments: String,
|
||||||
|
#[serde(default = "default_item_status")]
|
||||||
|
status: String,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_item_status() -> String {
|
||||||
|
"completed".into()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
|
pub enum ResponsesOutputContent {
|
||||||
|
OutputText {
|
||||||
|
text: String,
|
||||||
|
/// Citations / inline annotations. Empty today; reserved
|
||||||
|
/// for the day we wire in web search / file search.
|
||||||
|
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||||
|
annotations: Vec<Value>,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ResponsesUsage {
|
||||||
|
pub input_tokens: u64,
|
||||||
|
pub output_tokens: u64,
|
||||||
|
pub total_tokens: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Streaming event names ────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Event names the SSE projector emits, hoisted as constants so
|
||||||
|
/// the projector and the wire shape stay in sync without
|
||||||
|
/// string-typos. The strings are dictated by OpenAI's published
|
||||||
|
/// Responses API.
|
||||||
|
pub mod events {
|
||||||
|
pub const CREATED: &str = "response.created";
|
||||||
|
/// Fired between `response.created` and the first output-item
|
||||||
|
/// event. Marks "request validated, model is generating" —
|
||||||
|
/// some clients use it to differentiate the "warming up" state
|
||||||
|
/// from "streaming tokens" in their UI.
|
||||||
|
pub const IN_PROGRESS: &str = "response.in_progress";
|
||||||
|
pub const OUTPUT_ITEM_ADDED: &str = "response.output_item.added";
|
||||||
|
pub const CONTENT_PART_ADDED: &str = "response.content_part.added";
|
||||||
|
pub const OUTPUT_TEXT_DELTA: &str = "response.output_text.delta";
|
||||||
|
pub const OUTPUT_TEXT_DONE: &str = "response.output_text.done";
|
||||||
|
pub const CONTENT_PART_DONE: &str = "response.content_part.done";
|
||||||
|
pub const OUTPUT_ITEM_DONE: &str = "response.output_item.done";
|
||||||
|
pub const COMPLETED: &str = "response.completed";
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn deserialises_input_string_form() {
|
||||||
|
let raw = r#"{"model": "m", "input": "hello"}"#;
|
||||||
|
let req: ResponsesRequest = serde_json::from_str(raw).unwrap();
|
||||||
|
match req.input {
|
||||||
|
ResponsesInput::Text(s) => assert_eq!(s, "hello"),
|
||||||
|
other => panic!("expected Text, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn deserialises_input_items_form() {
|
||||||
|
let raw = r#"{
|
||||||
|
"model": "m",
|
||||||
|
"input": [
|
||||||
|
{"type": "message", "role": "user", "content": "hi"}
|
||||||
|
]
|
||||||
|
}"#;
|
||||||
|
let req: ResponsesRequest = serde_json::from_str(raw).unwrap();
|
||||||
|
match req.input {
|
||||||
|
ResponsesInput::Items(items) => {
|
||||||
|
assert_eq!(items.len(), 1);
|
||||||
|
match &items[0] {
|
||||||
|
ResponsesInputItem::Message { role, content } => {
|
||||||
|
assert_eq!(role, "user");
|
||||||
|
match content {
|
||||||
|
ResponsesMessageContent::Text(t) => assert_eq!(t, "hi"),
|
||||||
|
other => panic!("expected Text content, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
other => panic!("expected Message item, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
other => panic!("expected Items, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn deserialises_input_with_image() {
|
||||||
|
let raw = r#"{
|
||||||
|
"model": "m",
|
||||||
|
"input": [
|
||||||
|
{"type": "message", "role": "user", "content": [
|
||||||
|
{"type": "input_text", "text": "what is this"},
|
||||||
|
{"type": "input_image", "image_url": "data:image/png;base64,AAA="}
|
||||||
|
]}
|
||||||
|
]
|
||||||
|
}"#;
|
||||||
|
let req: ResponsesRequest = serde_json::from_str(raw).unwrap();
|
||||||
|
let items = match req.input {
|
||||||
|
ResponsesInput::Items(i) => i,
|
||||||
|
other => panic!("expected Items, got {other:?}"),
|
||||||
|
};
|
||||||
|
let parts = match &items[0] {
|
||||||
|
ResponsesInputItem::Message {
|
||||||
|
content: ResponsesMessageContent::Parts(p),
|
||||||
|
..
|
||||||
|
} => p,
|
||||||
|
other => panic!("expected Parts, got {other:?}"),
|
||||||
|
};
|
||||||
|
assert_eq!(parts.len(), 2);
|
||||||
|
assert!(matches!(
|
||||||
|
&parts[0],
|
||||||
|
ResponsesContentPart::InputText { text } if text == "what is this"
|
||||||
|
));
|
||||||
|
assert!(matches!(
|
||||||
|
&parts[1],
|
||||||
|
ResponsesContentPart::InputImage { image_url, .. }
|
||||||
|
if image_url == "data:image/png;base64,AAA="
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn unknown_fields_round_trip_via_extra() {
|
||||||
|
let raw = r#"{
|
||||||
|
"model": "m",
|
||||||
|
"input": "hi",
|
||||||
|
"tools": [{"type": "web_search"}],
|
||||||
|
"reasoning": {"effort": "medium"}
|
||||||
|
}"#;
|
||||||
|
let req: ResponsesRequest = serde_json::from_str(raw).unwrap();
|
||||||
|
assert!(req.extra.get("tools").is_some());
|
||||||
|
assert!(req.extra.get("reasoning").is_some());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn response_round_trips_through_serde() {
|
||||||
|
let r = ResponsesResponse {
|
||||||
|
id: "resp_1".into(),
|
||||||
|
object: "response".into(),
|
||||||
|
created_at: 1700,
|
||||||
|
status: "completed".into(),
|
||||||
|
model: "m".into(),
|
||||||
|
output: vec![ResponsesOutputItem::Message {
|
||||||
|
id: "msg_1".into(),
|
||||||
|
role: "assistant".into(),
|
||||||
|
content: vec![ResponsesOutputContent::OutputText {
|
||||||
|
text: "hi there".into(),
|
||||||
|
annotations: vec![],
|
||||||
|
}],
|
||||||
|
status: "completed".into(),
|
||||||
|
}],
|
||||||
|
usage: Some(ResponsesUsage {
|
||||||
|
input_tokens: 5,
|
||||||
|
output_tokens: 3,
|
||||||
|
total_tokens: 8,
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
let json = serde_json::to_string(&r).unwrap();
|
||||||
|
let parsed: ResponsesResponse = serde_json::from_str(&json).unwrap();
|
||||||
|
assert_eq!(parsed.id, "resp_1");
|
||||||
|
assert_eq!(parsed.output.len(), 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -23,6 +23,8 @@ futures.workspace = true
|
|||||||
tokio-stream.workspace = true
|
tokio-stream.workspace = true
|
||||||
eventsource-stream.workspace = true
|
eventsource-stream.workspace = true
|
||||||
bytes = "1"
|
bytes = "1"
|
||||||
|
urlencoding = "2"
|
||||||
|
url = "2"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tokio = { workspace = true, features = ["test-util"] }
|
tokio = { workspace = true, features = ["test-util"] }
|
||||||
|
|||||||
@@ -1,29 +1,19 @@
|
|||||||
//! Model eviction logic.
|
//! Model eviction logic.
|
||||||
//!
|
//!
|
||||||
//! The evictor runs as a background task. When the router determines that a
|
//! The evictor identifies the LRU model on a node (excluding pinned models),
|
||||||
//! model needs to be loaded on a node but VRAM is tight, it can request
|
//! calls neuron's `POST /models/unload` to free the model, and updates
|
||||||
//! eviction via a channel. The evictor then:
|
//! local state.
|
||||||
//! 1. Identifies the LRU model on that node (excluding pinned models)
|
|
||||||
//! 2. Calls `POST /v1/models/unload` on the node
|
|
||||||
//! 3. Increments the lifecycle cycle counter (for defrag tracking)
|
|
||||||
|
|
||||||
use crate::state::CortexState;
|
use crate::state::CortexState;
|
||||||
use cortex_core::node::{ModelLifecycleRequest, ModelStatus};
|
use cortex_core::node::ModelStatus;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
/// Runs forever. Currently a placeholder that periodically checks for
|
/// Runs forever. Placeholder for future channel-driven eviction.
|
||||||
/// eviction opportunities. In the future, this will be driven by a
|
|
||||||
/// channel from the router when VRAM pressure is detected.
|
|
||||||
pub async fn eviction_loop(fleet: Arc<CortexState>) {
|
pub async fn eviction_loop(fleet: Arc<CortexState>) {
|
||||||
// TODO: Replace this polling approach with a channel-driven design
|
|
||||||
// where the router sends eviction requests when it detects that a
|
|
||||||
// model load would exceed available VRAM.
|
|
||||||
loop {
|
loop {
|
||||||
tokio::time::sleep(Duration::from_secs(30)).await;
|
tokio::time::sleep(Duration::from_secs(30)).await;
|
||||||
// Placeholder: the actual eviction logic is in `evict_lru_on_node`,
|
let _ = &fleet;
|
||||||
// called on demand by the router.
|
|
||||||
let _ = &fleet; // suppress unused warning
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -33,18 +23,19 @@ pub async fn evict_lru_on_node(
|
|||||||
fleet: &CortexState,
|
fleet: &CortexState,
|
||||||
node_name: &str,
|
node_name: &str,
|
||||||
) -> anyhow::Result<Option<String>> {
|
) -> anyhow::Result<Option<String>> {
|
||||||
let (endpoint, candidate) = {
|
let (neuron_endpoint, candidate) = {
|
||||||
let nodes = fleet.nodes.read().await;
|
let nodes = fleet.nodes.read().await;
|
||||||
let Some(node) = nodes.get(node_name) else {
|
let Some(node) = nodes.get(node_name) else {
|
||||||
anyhow::bail!("node '{node_name}' not found");
|
anyhow::bail!("node '{node_name}' not found");
|
||||||
};
|
};
|
||||||
|
|
||||||
// Find the loaded model with the oldest last_accessed, excluding pinned.
|
// Find the loaded model with the oldest last_accessed,
|
||||||
|
// excluding models pinned on this neuron (from catalogue).
|
||||||
let candidate = node
|
let candidate = node
|
||||||
.models
|
.models
|
||||||
.values()
|
.values()
|
||||||
.filter(|m| m.status == ModelStatus::Loaded)
|
.filter(|m| m.status == ModelStatus::Loaded)
|
||||||
.filter(|m| !node.pinned.contains(&m.id))
|
.filter(|m| !fleet.catalogue.is_pinned(&m.id, node_name))
|
||||||
.min_by_key(|m| m.last_accessed)
|
.min_by_key(|m| m.last_accessed)
|
||||||
.map(|m| m.id.clone());
|
.map(|m| m.id.clone());
|
||||||
|
|
||||||
@@ -58,18 +49,16 @@ pub async fn evict_lru_on_node(
|
|||||||
|
|
||||||
tracing::info!(node = node_name, model = %model_id, "evicting model");
|
tracing::info!(node = node_name, model = %model_id, "evicting model");
|
||||||
|
|
||||||
let url = format!("{endpoint}/v1/models/unload");
|
// Call neuron's unload endpoint.
|
||||||
|
let url = format!("{neuron_endpoint}/models/unload");
|
||||||
let resp = fleet
|
let resp = fleet
|
||||||
.http_client
|
.http_client
|
||||||
.post(&url)
|
.post(&url)
|
||||||
.json(&ModelLifecycleRequest {
|
.json(&serde_json::json!({ "model_id": model_id }))
|
||||||
model_id: model_id.clone(),
|
|
||||||
})
|
|
||||||
.send()
|
.send()
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
if resp.status().is_success() {
|
if resp.status().is_success() {
|
||||||
// Update local state.
|
|
||||||
let mut nodes = fleet.nodes.write().await;
|
let mut nodes = fleet.nodes.write().await;
|
||||||
if let Some(node) = nodes.get_mut(node_name) {
|
if let Some(node) = nodes.get_mut(node_name) {
|
||||||
if let Some(entry) = node.models.get_mut(&model_id) {
|
if let Some(entry) = node.models.get_mut(&model_id) {
|
||||||
@@ -77,14 +66,13 @@ pub async fn evict_lru_on_node(
|
|||||||
}
|
}
|
||||||
node.lifecycle_cycles += 1;
|
node.lifecycle_cycles += 1;
|
||||||
|
|
||||||
// Check if we should flag for defrag.
|
|
||||||
if fleet.eviction.defrag_after_cycles > 0
|
if fleet.eviction.defrag_after_cycles > 0
|
||||||
&& node.lifecycle_cycles >= fleet.eviction.defrag_after_cycles
|
&& node.lifecycle_cycles >= fleet.eviction.defrag_after_cycles
|
||||||
{
|
{
|
||||||
tracing::warn!(
|
tracing::warn!(
|
||||||
node = node_name,
|
node = node_name,
|
||||||
cycles = node.lifecycle_cycles,
|
cycles = node.lifecycle_cycles,
|
||||||
"VRAM fragmentation threshold reached — consider restarting mistralrs"
|
"VRAM fragmentation threshold reached — consider restarting harness"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ pub fn api_routes() -> Router<Arc<CortexState>> {
|
|||||||
Router::new()
|
Router::new()
|
||||||
.route("/v1/chat/completions", post(chat_completions))
|
.route("/v1/chat/completions", post(chat_completions))
|
||||||
.route("/v1/completions", post(completions))
|
.route("/v1/completions", post(completions))
|
||||||
|
.route("/v1/responses", post(responses))
|
||||||
.route("/v1/models", get(list_models))
|
.route("/v1/models", get(list_models))
|
||||||
.route("/v1/messages", post(anthropic_messages))
|
.route("/v1/messages", post(anthropic_messages))
|
||||||
.route("/health", get(health))
|
.route("/health", get(health))
|
||||||
@@ -34,23 +35,94 @@ async fn chat_completions(
|
|||||||
) -> Response {
|
) -> Response {
|
||||||
let model_id = match extract_model(&body) {
|
let model_id = match extract_model(&body) {
|
||||||
Some(m) => m,
|
Some(m) => m,
|
||||||
None => return error_response(400, "missing 'model' field in request body"),
|
None => {
|
||||||
|
tracing::warn!(
|
||||||
|
handler = "chat_completions",
|
||||||
|
"rejected: missing 'model' field in request body"
|
||||||
|
);
|
||||||
|
return error_response(400, "missing 'model' field in request body");
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let route = match router::resolve(&fleet, &model_id).await {
|
let route = match router::resolve(&fleet, &model_id).await {
|
||||||
Ok(r) => r,
|
Ok(r) => r,
|
||||||
Err(e) => return error_response(404, &e.to_string()),
|
Err(e) => {
|
||||||
|
tracing::warn!(
|
||||||
|
handler = "chat_completions",
|
||||||
|
model = %model_id,
|
||||||
|
error = %e,
|
||||||
|
"route resolve failed"
|
||||||
|
);
|
||||||
|
// RouteError's Display strings are short and informative
|
||||||
|
// ("model 'X' not found...", "no healthy nodes available")
|
||||||
|
// — fine to surface to the caller. The warn above carries
|
||||||
|
// any extra context for operators.
|
||||||
|
return error_response(404, &e.to_string());
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
touch_model(&fleet, &route.node_name, &model_id).await;
|
touch_model(&fleet, &route.node_name, &route.resolved_model_id).await;
|
||||||
|
|
||||||
|
let body = rewrite_model_in_body(body, &route.resolved_model_id);
|
||||||
proxy_with_metrics(
|
proxy_with_metrics(
|
||||||
&fleet,
|
&fleet,
|
||||||
&route,
|
&route,
|
||||||
"/v1/chat/completions",
|
"/v1/chat/completions",
|
||||||
headers,
|
headers,
|
||||||
body,
|
body,
|
||||||
&model_id,
|
&route.resolved_model_id,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `POST /v1/responses` — proxy to the appropriate backend node.
|
||||||
|
///
|
||||||
|
/// Same routing shape as [`chat_completions`]: extract `model` from
|
||||||
|
/// the body, resolve to a node, forward verbatim. No translation —
|
||||||
|
/// neuron speaks the Responses API natively (see
|
||||||
|
/// `crates/neuron/src/wire/openai_responses.rs`), so the gateway is
|
||||||
|
/// a pass-through. Streaming and non-streaming are handled
|
||||||
|
/// identically; the upstream `Content-Type` (text/event-stream vs.
|
||||||
|
/// application/json) propagates through the proxy.
|
||||||
|
async fn responses(
|
||||||
|
State(fleet): State<Arc<CortexState>>,
|
||||||
|
headers: HeaderMap,
|
||||||
|
body: Bytes,
|
||||||
|
) -> Response {
|
||||||
|
let model_id = match extract_model(&body) {
|
||||||
|
Some(m) => m,
|
||||||
|
None => {
|
||||||
|
tracing::warn!(
|
||||||
|
handler = "responses",
|
||||||
|
"rejected: missing 'model' field in request body"
|
||||||
|
);
|
||||||
|
return error_response(400, "missing 'model' field in request body");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let route = match router::resolve(&fleet, &model_id).await {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(
|
||||||
|
handler = "responses",
|
||||||
|
model = %model_id,
|
||||||
|
error = %e,
|
||||||
|
"route resolve failed"
|
||||||
|
);
|
||||||
|
return error_response(404, &e.to_string());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
touch_model(&fleet, &route.node_name, &route.resolved_model_id).await;
|
||||||
|
|
||||||
|
let body = rewrite_model_in_body(body, &route.resolved_model_id);
|
||||||
|
proxy_with_metrics(
|
||||||
|
&fleet,
|
||||||
|
&route,
|
||||||
|
"/v1/responses",
|
||||||
|
headers,
|
||||||
|
body,
|
||||||
|
&route.resolved_model_id,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
@@ -63,17 +135,44 @@ async fn completions(
|
|||||||
) -> Response {
|
) -> Response {
|
||||||
let model_id = match extract_model(&body) {
|
let model_id = match extract_model(&body) {
|
||||||
Some(m) => m,
|
Some(m) => m,
|
||||||
None => return error_response(400, "missing 'model' field in request body"),
|
None => {
|
||||||
|
tracing::warn!(
|
||||||
|
handler = "completions",
|
||||||
|
"rejected: missing 'model' field in request body"
|
||||||
|
);
|
||||||
|
return error_response(400, "missing 'model' field in request body");
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let route = match router::resolve(&fleet, &model_id).await {
|
let route = match router::resolve(&fleet, &model_id).await {
|
||||||
Ok(r) => r,
|
Ok(r) => r,
|
||||||
Err(e) => return error_response(404, &e.to_string()),
|
Err(e) => {
|
||||||
|
tracing::warn!(
|
||||||
|
handler = "completions",
|
||||||
|
model = %model_id,
|
||||||
|
error = %e,
|
||||||
|
"route resolve failed"
|
||||||
|
);
|
||||||
|
// RouteError's Display strings are short and informative
|
||||||
|
// ("model 'X' not found...", "no healthy nodes available")
|
||||||
|
// — fine to surface to the caller. The warn above carries
|
||||||
|
// any extra context for operators.
|
||||||
|
return error_response(404, &e.to_string());
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
touch_model(&fleet, &route.node_name, &model_id).await;
|
touch_model(&fleet, &route.node_name, &route.resolved_model_id).await;
|
||||||
|
|
||||||
proxy_with_metrics(&fleet, &route, "/v1/completions", headers, body, &model_id).await
|
let body = rewrite_model_in_body(body, &route.resolved_model_id);
|
||||||
|
proxy_with_metrics(
|
||||||
|
&fleet,
|
||||||
|
&route,
|
||||||
|
"/v1/completions",
|
||||||
|
headers,
|
||||||
|
body,
|
||||||
|
&route.resolved_model_id,
|
||||||
|
)
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// `POST /v1/messages` — accept Anthropic format, translate, proxy, translate back.
|
/// `POST /v1/messages` — accept Anthropic format, translate, proxy, translate back.
|
||||||
@@ -85,7 +184,14 @@ async fn anthropic_messages(
|
|||||||
// Parse as Anthropic request.
|
// Parse as Anthropic request.
|
||||||
let anth_req: cortex_core::anthropic::MessagesRequest = match serde_json::from_slice(&body) {
|
let anth_req: cortex_core::anthropic::MessagesRequest = match serde_json::from_slice(&body) {
|
||||||
Ok(r) => r,
|
Ok(r) => r,
|
||||||
Err(e) => return error_response(400, &format!("invalid Anthropic request: {e}")),
|
Err(e) => {
|
||||||
|
tracing::warn!(
|
||||||
|
handler = "anthropic_messages",
|
||||||
|
error = %e,
|
||||||
|
"rejected: invalid Anthropic request body"
|
||||||
|
);
|
||||||
|
return error_response(400, "invalid Anthropic request body");
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let model_id = anth_req.model.clone();
|
let model_id = anth_req.model.clone();
|
||||||
@@ -95,18 +201,43 @@ async fn anthropic_messages(
|
|||||||
let openai_req = cortex_core::translate::anthropic_to_openai(anth_req);
|
let openai_req = cortex_core::translate::anthropic_to_openai(anth_req);
|
||||||
let openai_body = match serde_json::to_vec(&openai_req) {
|
let openai_body = match serde_json::to_vec(&openai_req) {
|
||||||
Ok(b) => Bytes::from(b),
|
Ok(b) => Bytes::from(b),
|
||||||
Err(e) => return error_response(500, &format!("translation error: {e}")),
|
Err(e) => {
|
||||||
|
tracing::error!(
|
||||||
|
handler = "anthropic_messages",
|
||||||
|
model = %model_id,
|
||||||
|
error = %e,
|
||||||
|
"internal: failed to serialise translated OpenAI request"
|
||||||
|
);
|
||||||
|
return error_response(500, "internal translation error");
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let route = match router::resolve(&fleet, &model_id).await {
|
let route = match router::resolve(&fleet, &model_id).await {
|
||||||
Ok(r) => r,
|
Ok(r) => r,
|
||||||
Err(e) => return error_response(404, &e.to_string()),
|
Err(e) => {
|
||||||
|
tracing::warn!(
|
||||||
|
handler = "anthropic_messages",
|
||||||
|
model = %model_id,
|
||||||
|
error = %e,
|
||||||
|
"route resolve failed"
|
||||||
|
);
|
||||||
|
// RouteError's Display strings are short and informative
|
||||||
|
// ("model 'X' not found...", "no healthy nodes available")
|
||||||
|
// — fine to surface to the caller. The warn above carries
|
||||||
|
// any extra context for operators.
|
||||||
|
return error_response(404, &e.to_string());
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
touch_model(&fleet, &route.node_name, &model_id).await;
|
touch_model(&fleet, &route.node_name, &route.resolved_model_id).await;
|
||||||
|
|
||||||
|
// Swap the alias for the concrete id in the translated body so
|
||||||
|
// neuron's harness sees a model name that matches what it has
|
||||||
|
// loaded.
|
||||||
|
let openai_body = rewrite_model_in_body(openai_body, &route.resolved_model_id);
|
||||||
|
|
||||||
let labels = [
|
let labels = [
|
||||||
("model", model_id.clone()),
|
("model", route.resolved_model_id.clone()),
|
||||||
("node", route.node_name.clone()),
|
("node", route.node_name.clone()),
|
||||||
];
|
];
|
||||||
metrics::counter!("cortex_requests_total", &labels).increment(1);
|
metrics::counter!("cortex_requests_total", &labels).increment(1);
|
||||||
@@ -133,14 +264,25 @@ async fn anthropic_messages(
|
|||||||
Ok(resp) => resp,
|
Ok(resp) => resp,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
||||||
|
// forward_request already warn'd with the wire-level
|
||||||
|
// detail; no need to log again here.
|
||||||
e.into_response()
|
e.into_response()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Non-streaming: proxy, buffer full response, translate back to Anthropic.
|
// Non-streaming: proxy, buffer full response, translate back to Anthropic.
|
||||||
|
let target_url = format!("{}/v1/chat/completions", route.endpoint);
|
||||||
|
tracing::info!(
|
||||||
|
handler = "anthropic_messages",
|
||||||
|
model = %model_id,
|
||||||
|
node = %route.node_name,
|
||||||
|
url = %target_url,
|
||||||
|
cold_start = route.cold_start,
|
||||||
|
"proxying request"
|
||||||
|
);
|
||||||
let upstream_resp = fleet
|
let upstream_resp = fleet
|
||||||
.http_client
|
.http_client
|
||||||
.post(format!("{}/v1/chat/completions", route.endpoint))
|
.post(&target_url)
|
||||||
.body(openai_body)
|
.body(openai_body)
|
||||||
.header("content-type", "application/json")
|
.header("content-type", "application/json")
|
||||||
.send()
|
.send()
|
||||||
@@ -150,22 +292,49 @@ async fn anthropic_messages(
|
|||||||
Ok(r) => r,
|
Ok(r) => r,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
||||||
return error_response(502, &format!("upstream request failed: {e}"));
|
tracing::warn!(
|
||||||
|
handler = "anthropic_messages",
|
||||||
|
model = %model_id,
|
||||||
|
node = %route.node_name,
|
||||||
|
url = %target_url,
|
||||||
|
error = %e,
|
||||||
|
"upstream request failed (network)"
|
||||||
|
);
|
||||||
|
return error_response(502, "upstream request failed");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if !upstream_resp.status().is_success() {
|
let upstream_status = upstream_resp.status();
|
||||||
|
if !upstream_status.is_success() {
|
||||||
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
||||||
let status = upstream_resp.status().as_u16();
|
let status = upstream_status.as_u16();
|
||||||
let body = upstream_resp.text().await.unwrap_or_default();
|
let body = upstream_resp.text().await.unwrap_or_default();
|
||||||
return error_response(status, &format!("upstream error: {body}"));
|
let body_snippet = body.chars().take(512).collect::<String>();
|
||||||
|
tracing::warn!(
|
||||||
|
handler = "anthropic_messages",
|
||||||
|
model = %model_id,
|
||||||
|
node = %route.node_name,
|
||||||
|
url = %target_url,
|
||||||
|
status,
|
||||||
|
body = %body_snippet,
|
||||||
|
"upstream returned non-2xx"
|
||||||
|
);
|
||||||
|
return error_response(status, &format!("upstream returned {status}"));
|
||||||
}
|
}
|
||||||
|
|
||||||
let body_bytes = match upstream_resp.bytes().await {
|
let body_bytes = match upstream_resp.bytes().await {
|
||||||
Ok(b) => b,
|
Ok(b) => b,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
||||||
return error_response(502, &format!("failed to read upstream response: {e}"));
|
tracing::warn!(
|
||||||
|
handler = "anthropic_messages",
|
||||||
|
model = %model_id,
|
||||||
|
node = %route.node_name,
|
||||||
|
url = %target_url,
|
||||||
|
error = %e,
|
||||||
|
"failed to read upstream response body"
|
||||||
|
);
|
||||||
|
return error_response(502, "failed to read upstream response");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -174,7 +343,20 @@ async fn anthropic_messages(
|
|||||||
Ok(r) => r,
|
Ok(r) => r,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
||||||
return error_response(502, &format!("failed to parse upstream response: {e}"));
|
let body_snippet = String::from_utf8_lossy(&body_bytes)
|
||||||
|
.chars()
|
||||||
|
.take(512)
|
||||||
|
.collect::<String>();
|
||||||
|
tracing::warn!(
|
||||||
|
handler = "anthropic_messages",
|
||||||
|
model = %model_id,
|
||||||
|
node = %route.node_name,
|
||||||
|
url = %target_url,
|
||||||
|
error = %e,
|
||||||
|
body = %body_snippet,
|
||||||
|
"failed to parse upstream response as OpenAI ChatCompletionResponse"
|
||||||
|
);
|
||||||
|
return error_response(502, "malformed upstream response");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -185,12 +367,62 @@ async fn anthropic_messages(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// `GET /v1/models` — aggregate models from all nodes.
|
/// `GET /v1/models` — union of (catalogue × topology feasibility) and
|
||||||
|
/// (currently loaded somewhere). The result is what the fleet *could*
|
||||||
|
/// serve, not just what's already loaded — so OpenAI-compatible tools
|
||||||
|
/// see every model the operator has provisioned, and cortex
|
||||||
|
/// transparently cold-loads the first time one is requested.
|
||||||
async fn list_models(State(fleet): State<Arc<CortexState>>) -> Json<Value> {
|
async fn list_models(State(fleet): State<Arc<CortexState>>) -> Json<Value> {
|
||||||
|
use std::collections::HashMap;
|
||||||
|
let now = Utc::now().timestamp() as u64;
|
||||||
let nodes = fleet.nodes.read().await;
|
let nodes = fleet.nodes.read().await;
|
||||||
let mut model_map: std::collections::HashMap<String, CortexModelEntry> =
|
let catalogue = &fleet.catalogue;
|
||||||
std::collections::HashMap::new();
|
|
||||||
|
|
||||||
|
let mut entries: HashMap<String, CortexModelEntry> = HashMap::new();
|
||||||
|
|
||||||
|
// Pass 1: catalogue × topology. For every catalogue profile, find
|
||||||
|
// healthy neurons whose discovered devices satisfy the profile.
|
||||||
|
// Catalogue-defined models surface here even if nothing has loaded
|
||||||
|
// them yet — that's the point of the unified endpoint.
|
||||||
|
for profile in &catalogue.models {
|
||||||
|
let mut feasible_on = Vec::new();
|
||||||
|
for node in nodes.values() {
|
||||||
|
if !node.healthy {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let Some(disc) = node.discovery.as_ref() else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
if profile.is_feasible_on(&node.name, &disc.devices) {
|
||||||
|
feasible_on.push(node.name.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if feasible_on.is_empty() {
|
||||||
|
// The catalogue lists this model but no neuron's topology
|
||||||
|
// matches — surface it as not-loaded with no feasible
|
||||||
|
// location. Hides nothing; lets operators see why a
|
||||||
|
// configured model isn't reachable.
|
||||||
|
feasible_on.clear();
|
||||||
|
}
|
||||||
|
entries.insert(
|
||||||
|
profile.id.clone(),
|
||||||
|
CortexModelEntry {
|
||||||
|
id: profile.id.clone(),
|
||||||
|
object: "model".into(),
|
||||||
|
created: now,
|
||||||
|
owned_by: "helexa".into(),
|
||||||
|
loaded: false,
|
||||||
|
feasible_on,
|
||||||
|
locations: Vec::new(),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pass 2: layer the actually-loaded state on top. For each
|
||||||
|
// (node, model) entry, attach a ModelLocation. If the model isn't
|
||||||
|
// in the catalogue, create a new CortexModelEntry from scratch —
|
||||||
|
// cortex doesn't refuse to surface a manually-loaded model just
|
||||||
|
// because the operator didn't enumerate it in models.toml.
|
||||||
for node in nodes.values() {
|
for node in nodes.values() {
|
||||||
for (model_id, entry) in &node.models {
|
for (model_id, entry) in &node.models {
|
||||||
let location = ModelLocation {
|
let location = ModelLocation {
|
||||||
@@ -198,19 +430,108 @@ async fn list_models(State(fleet): State<Arc<CortexState>>) -> Json<Value> {
|
|||||||
status: entry.status,
|
status: entry.status,
|
||||||
vram_estimate_mb: entry.vram_estimate_mb,
|
vram_estimate_mb: entry.vram_estimate_mb,
|
||||||
};
|
};
|
||||||
model_map
|
let was_loaded = matches!(entry.status, cortex_core::node::ModelStatus::Loaded);
|
||||||
|
entries
|
||||||
.entry(model_id.clone())
|
.entry(model_id.clone())
|
||||||
.and_modify(|e| e.locations.push(location.clone()))
|
.and_modify(|e| {
|
||||||
|
e.locations.push(location.clone());
|
||||||
|
if was_loaded {
|
||||||
|
e.loaded = true;
|
||||||
|
}
|
||||||
|
})
|
||||||
.or_insert_with(|| CortexModelEntry {
|
.or_insert_with(|| CortexModelEntry {
|
||||||
id: model_id.clone(),
|
id: model_id.clone(),
|
||||||
object: "model".into(),
|
object: "model".into(),
|
||||||
|
created: now,
|
||||||
|
owned_by: "helexa".into(),
|
||||||
|
loaded: was_loaded,
|
||||||
|
// Not in catalogue — cortex has no opinion on
|
||||||
|
// feasibility; leave empty.
|
||||||
|
feasible_on: Vec::new(),
|
||||||
locations: vec![location],
|
locations: vec![location],
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let data: Vec<Value> = model_map.values().map(|e| json!(e)).collect();
|
// Pass 3: surface pre-warming models. Each neuron's `/health`
|
||||||
|
// activation snapshot (polled separately from /models) reports
|
||||||
|
// `in_progress` (the model currently materialising) and `pending`
|
||||||
|
// (queued behind it). Neither appears on the neuron's `/models`
|
||||||
|
// yet — that endpoint only knows about fully-loaded handles — so
|
||||||
|
// without this pass a client polling `/v1/models` during pre-warm
|
||||||
|
// sees Qwen3.6-27B with no location and concludes "not there".
|
||||||
|
// Synthesising a Loading location instead tells clients the model
|
||||||
|
// is on its way. Idempotent against Pass 2: if a Loading location
|
||||||
|
// for this node already exists (shouldn't, but be safe) we skip.
|
||||||
|
for node in nodes.values() {
|
||||||
|
let Some(activation) = node.activation.as_ref() else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
let mut loading_ids: Vec<&str> = Vec::new();
|
||||||
|
if let Some(id) = activation.in_progress.as_deref() {
|
||||||
|
loading_ids.push(id);
|
||||||
|
}
|
||||||
|
for id in &activation.pending {
|
||||||
|
loading_ids.push(id.as_str());
|
||||||
|
}
|
||||||
|
for model_id in loading_ids {
|
||||||
|
let location = ModelLocation {
|
||||||
|
node: node.name.clone(),
|
||||||
|
status: cortex_core::node::ModelStatus::Loading,
|
||||||
|
vram_estimate_mb: None,
|
||||||
|
};
|
||||||
|
entries
|
||||||
|
.entry(model_id.to_string())
|
||||||
|
.and_modify(|e| {
|
||||||
|
let already = e.locations.iter().any(|l| {
|
||||||
|
l.node == node.name && l.status == cortex_core::node::ModelStatus::Loading
|
||||||
|
});
|
||||||
|
if !already {
|
||||||
|
e.locations.push(location.clone());
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.or_insert_with(|| CortexModelEntry {
|
||||||
|
id: model_id.to_string(),
|
||||||
|
object: "model".into(),
|
||||||
|
created: now,
|
||||||
|
owned_by: "helexa".into(),
|
||||||
|
loaded: false,
|
||||||
|
feasible_on: Vec::new(),
|
||||||
|
locations: vec![location],
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pass 4: surface aliases as their own entries pointing at the
|
||||||
|
// same locations as the target id, so a client browsing /v1/models
|
||||||
|
// sees "helexa/small" / "helexa/balanced" / "helexa/large" (or
|
||||||
|
// whatever the operator defined) and can request inference
|
||||||
|
// against them directly. Aliases that point at unknown targets
|
||||||
|
// are skipped — surfacing a dead alias would be misleading.
|
||||||
|
for (alias, target) in &catalogue.aliases {
|
||||||
|
let Some(target_entry) = entries.get(target).cloned() else {
|
||||||
|
tracing::warn!(
|
||||||
|
alias = alias,
|
||||||
|
target = target,
|
||||||
|
"alias points at a model not present in catalogue or fleet; skipping"
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
entries.insert(
|
||||||
|
alias.clone(),
|
||||||
|
CortexModelEntry {
|
||||||
|
id: alias.clone(),
|
||||||
|
object: "model".into(),
|
||||||
|
created: now,
|
||||||
|
owned_by: "helexa".into(),
|
||||||
|
loaded: target_entry.loaded,
|
||||||
|
feasible_on: target_entry.feasible_on,
|
||||||
|
locations: target_entry.locations,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let data: Vec<Value> = entries.values().map(|e| json!(e)).collect();
|
||||||
Json(json!({
|
Json(json!({
|
||||||
"object": "list",
|
"object": "list",
|
||||||
"data": data,
|
"data": data,
|
||||||
@@ -265,6 +586,9 @@ async fn proxy_with_metrics(
|
|||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
||||||
|
// proxy::forward_request already warn'd with wire-level
|
||||||
|
// detail (target URL, error, status). ProxyError::into_response
|
||||||
|
// now returns a generic message — no body leak.
|
||||||
e.into_response()
|
e.into_response()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -285,6 +609,38 @@ fn extract_model(body: &[u8]) -> Option<String> {
|
|||||||
v.get("model")?.as_str().map(|s| s.to_string())
|
v.get("model")?.as_str().map(|s| s.to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Rewrite the `model` field of an OpenAI-style JSON request body to
|
||||||
|
/// the resolved concrete id. Returns the original bytes if `new_model`
|
||||||
|
/// matches what's already there or the body fails to parse — the
|
||||||
|
/// caller has already extracted `model` via `extract_model`, so a
|
||||||
|
/// parse failure here would only happen on a body the client crafted
|
||||||
|
/// to defeat us, and we'd rather proxy it unchanged than 500.
|
||||||
|
///
|
||||||
|
/// Needed because neuron rejects requests whose `model` field doesn't
|
||||||
|
/// match a loaded model, so a client that sends `model: "helexa/small"`
|
||||||
|
/// would hit a 404 at the harness unless we swap it for the concrete
|
||||||
|
/// id the alias resolved to.
|
||||||
|
fn rewrite_model_in_body(body: Bytes, new_model: &str) -> Bytes {
|
||||||
|
let Ok(mut v) = serde_json::from_slice::<Value>(&body) else {
|
||||||
|
return body;
|
||||||
|
};
|
||||||
|
let needs_rewrite = v
|
||||||
|
.get("model")
|
||||||
|
.and_then(|m| m.as_str())
|
||||||
|
.map(|m| m != new_model)
|
||||||
|
.unwrap_or(false);
|
||||||
|
if !needs_rewrite {
|
||||||
|
return body;
|
||||||
|
}
|
||||||
|
if let Value::Object(obj) = &mut v {
|
||||||
|
obj.insert("model".into(), Value::String(new_model.to_string()));
|
||||||
|
}
|
||||||
|
match serde_json::to_vec(&v) {
|
||||||
|
Ok(bytes) => Bytes::from(bytes),
|
||||||
|
Err(_) => body,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn error_response(status: u16, message: &str) -> Response {
|
fn error_response(status: u16, message: &str) -> Response {
|
||||||
let code = axum::http::StatusCode::from_u16(status)
|
let code = axum::http::StatusCode::from_u16(status)
|
||||||
.unwrap_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR);
|
.unwrap_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||||||
|
|||||||
@@ -1,15 +1,17 @@
|
|||||||
//! Background poller that periodically queries each node's `/v1/models`
|
//! Background poller that periodically queries each neuron's API
|
||||||
//! endpoint to refresh the fleet state.
|
//! to refresh the fleet state.
|
||||||
|
|
||||||
use crate::state::CortexState;
|
use crate::state::CortexState;
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
use cortex_core::node::{MistralModelsResponse, ModelEntry, ModelStatus};
|
use cortex_core::discovery::{DiscoveryResponse, HealthResponse};
|
||||||
|
use cortex_core::harness::ModelInfo;
|
||||||
|
use cortex_core::node::{ModelEntry, ModelStatus};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
const POLL_INTERVAL: Duration = Duration::from_secs(10);
|
const POLL_INTERVAL: Duration = Duration::from_secs(10);
|
||||||
|
|
||||||
/// Runs forever, polling all nodes on a fixed interval.
|
/// Runs forever, polling all neurons on a fixed interval.
|
||||||
pub async fn poll_loop(fleet: Arc<CortexState>) {
|
pub async fn poll_loop(fleet: Arc<CortexState>) {
|
||||||
loop {
|
loop {
|
||||||
poll_once(&fleet).await;
|
poll_once(&fleet).await;
|
||||||
@@ -17,15 +19,67 @@ pub async fn poll_loop(fleet: Arc<CortexState>) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Poll all nodes once. Used by `poll_loop` and available for testing.
|
/// Poll all neurons once. Used by `poll_loop` and available for testing.
|
||||||
pub async fn poll_once(fleet: &CortexState) {
|
pub async fn poll_once(fleet: &CortexState) {
|
||||||
for nc in &fleet.node_configs {
|
for nc in &fleet.neuron_configs {
|
||||||
poll_node(fleet, &nc.name, &nc.endpoint).await;
|
poll_neuron(fleet, &nc.name, &nc.endpoint).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn poll_node(fleet: &CortexState, name: &str, endpoint: &str) {
|
/// One-shot fetch of `GET /discovery`. Cached on the NodeState forever
|
||||||
let url = format!("{endpoint}/v1/models");
|
/// after the first success — topology is invariant for a given neuron
|
||||||
|
/// process. Skipped when the cache is already populated.
|
||||||
|
async fn maybe_poll_discovery(fleet: &CortexState, name: &str, endpoint: &str) {
|
||||||
|
{
|
||||||
|
let nodes = fleet.nodes.read().await;
|
||||||
|
match nodes.get(name) {
|
||||||
|
Some(n) if n.discovery.is_some() => return,
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let url = format!("{endpoint}/discovery");
|
||||||
|
let resp = match fleet
|
||||||
|
.http_client
|
||||||
|
.get(&url)
|
||||||
|
.timeout(Duration::from_secs(5))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(r) if r.status().is_success() => r,
|
||||||
|
Ok(r) => {
|
||||||
|
tracing::debug!(node = name, status = %r.status(), "discovery probe non-success");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::debug!(node = name, error = %e, "discovery probe unreachable");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
match resp.json::<DiscoveryResponse>().await {
|
||||||
|
Ok(d) => {
|
||||||
|
let mut nodes = fleet.nodes.write().await;
|
||||||
|
if let Some(node) = nodes.get_mut(name) {
|
||||||
|
tracing::info!(
|
||||||
|
node = name,
|
||||||
|
hostname = %d.hostname,
|
||||||
|
devices = d.devices.len(),
|
||||||
|
"discovery cached"
|
||||||
|
);
|
||||||
|
node.discovery = Some(d);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(node = name, error = %e, "failed to parse /discovery response");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn poll_neuron(fleet: &CortexState, name: &str, endpoint: &str) {
|
||||||
|
// Topology first — cheap once cached, and the router needs it to
|
||||||
|
// route requests against catalogue entries that aren't loaded yet.
|
||||||
|
maybe_poll_discovery(fleet, name, endpoint).await;
|
||||||
|
|
||||||
|
let url = format!("{endpoint}/models");
|
||||||
|
|
||||||
let result = fleet
|
let result = fleet
|
||||||
.http_client
|
.http_client
|
||||||
@@ -41,38 +95,36 @@ async fn poll_node(fleet: &CortexState, name: &str, endpoint: &str) {
|
|||||||
|
|
||||||
match result {
|
match result {
|
||||||
Ok(resp) if resp.status().is_success() => {
|
Ok(resp) if resp.status().is_success() => {
|
||||||
match resp.json::<MistralModelsResponse>().await {
|
match resp.json::<Vec<ModelInfo>>().await {
|
||||||
Ok(models_resp) => {
|
Ok(models) => {
|
||||||
// Merge upstream model list into our state, preserving
|
|
||||||
// our local metadata (last_accessed, vram_estimate).
|
|
||||||
let mut seen = std::collections::HashSet::new();
|
let mut seen = std::collections::HashSet::new();
|
||||||
for upstream in &models_resp.data {
|
for upstream in &models {
|
||||||
seen.insert(upstream.id.clone());
|
seen.insert(upstream.id.clone());
|
||||||
let status = parse_status(upstream.status.as_deref());
|
let status = parse_status(&upstream.status);
|
||||||
|
|
||||||
node.models
|
node.models
|
||||||
.entry(upstream.id.clone())
|
.entry(upstream.id.clone())
|
||||||
.and_modify(|e| {
|
.and_modify(|e| {
|
||||||
e.status = status;
|
e.status = status;
|
||||||
|
e.vram_estimate_mb = upstream.vram_used_mb;
|
||||||
})
|
})
|
||||||
.or_insert_with(|| ModelEntry {
|
.or_insert_with(|| ModelEntry {
|
||||||
id: upstream.id.clone(),
|
id: upstream.id.clone(),
|
||||||
status,
|
status,
|
||||||
last_accessed: None,
|
last_accessed: None,
|
||||||
vram_estimate_mb: None,
|
vram_estimate_mb: upstream.vram_used_mb,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove models that are no longer reported by the node
|
// Remove models no longer reported by the neuron.
|
||||||
// (e.g. after a config change / restart).
|
|
||||||
node.models.retain(|id, _| seen.contains(id));
|
node.models.retain(|id, _| seen.contains(id));
|
||||||
|
|
||||||
node.healthy = true;
|
node.healthy = true;
|
||||||
node.last_poll = Some(Utc::now());
|
node.last_poll = Some(Utc::now());
|
||||||
tracing::debug!(node = name, models = models_resp.data.len(), "poll ok");
|
tracing::debug!(node = name, models = models.len(), "poll ok");
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::warn!(node = name, error = %e, "failed to parse /v1/models response");
|
tracing::warn!(node = name, error = %e, "failed to parse /models response");
|
||||||
node.healthy = false;
|
node.healthy = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -81,24 +133,68 @@ async fn poll_node(fleet: &CortexState, name: &str, endpoint: &str) {
|
|||||||
tracing::warn!(
|
tracing::warn!(
|
||||||
node = name,
|
node = name,
|
||||||
status = %resp.status(),
|
status = %resp.status(),
|
||||||
"node returned non-success status"
|
"neuron returned non-success status"
|
||||||
);
|
);
|
||||||
node.healthy = false;
|
node.healthy = false;
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::warn!(node = name, error = %e, "failed to reach node");
|
tracing::warn!(node = name, error = %e, "failed to reach neuron");
|
||||||
node.healthy = false;
|
node.healthy = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Release the write lock before the next HTTP call.
|
||||||
|
drop(nodes);
|
||||||
|
|
||||||
|
// Poll /health for the activation snapshot. We don't want this to
|
||||||
|
// flip the node to unhealthy on its own — a neuron that's serving
|
||||||
|
// /models fine is still operational even if /health is briefly
|
||||||
|
// unavailable — so failures are debug-level and leave the existing
|
||||||
|
// activation reading in place.
|
||||||
|
poll_health(fleet, name, endpoint).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Fetch `/health` and stash the activation snapshot on NodeState.
|
||||||
|
/// Decoupled from the /models poll so a /health glitch doesn't mark
|
||||||
|
/// the neuron unhealthy or evict the model list.
|
||||||
|
async fn poll_health(fleet: &CortexState, name: &str, endpoint: &str) {
|
||||||
|
let url = format!("{endpoint}/health");
|
||||||
|
let resp = match fleet
|
||||||
|
.http_client
|
||||||
|
.get(&url)
|
||||||
|
.timeout(Duration::from_secs(5))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(r) if r.status().is_success() => r,
|
||||||
|
Ok(r) => {
|
||||||
|
tracing::debug!(node = name, status = %r.status(), "/health probe non-success");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::debug!(node = name, error = %e, "/health probe failed");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
match resp.json::<HealthResponse>().await {
|
||||||
|
Ok(h) => {
|
||||||
|
let mut nodes = fleet.nodes.write().await;
|
||||||
|
if let Some(node) = nodes.get_mut(name) {
|
||||||
|
node.activation = Some(h.activation);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::debug!(node = name, error = %e, "failed to parse /health response");
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_status(s: Option<&str>) -> ModelStatus {
|
fn parse_status(s: &str) -> ModelStatus {
|
||||||
match s {
|
match s {
|
||||||
Some("loaded") => ModelStatus::Loaded,
|
"loaded" => ModelStatus::Loaded,
|
||||||
Some("unloaded") => ModelStatus::Unloaded,
|
"unloaded" => ModelStatus::Unloaded,
|
||||||
Some("reloading") => ModelStatus::Reloading,
|
"reloading" => ModelStatus::Reloading,
|
||||||
// If the status field is absent, assume loaded (older mistral.rs versions
|
"loading" => ModelStatus::Loading,
|
||||||
// may not include it).
|
|
||||||
_ => ModelStatus::Loaded,
|
_ => ModelStatus::Loaded,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//! Streaming HTTP reverse proxy to mistral.rs backends.
|
//! Streaming HTTP reverse proxy to neuron backends.
|
||||||
//!
|
//!
|
||||||
//! For streaming requests, SSE chunks are forwarded as they arrive.
|
//! For streaming requests, SSE chunks are forwarded as they arrive.
|
||||||
//! The proxy captures timing information for metrics but does not
|
//! The proxy captures timing information for metrics but does not
|
||||||
@@ -12,6 +12,13 @@ use axum::response::{IntoResponse, Response};
|
|||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
|
|
||||||
/// Proxy a request body to the resolved backend node and stream the response.
|
/// Proxy a request body to the resolved backend node and stream the response.
|
||||||
|
///
|
||||||
|
/// Logging contract: every call emits exactly one structured event at
|
||||||
|
/// info / warn level for operator visibility, regardless of outcome.
|
||||||
|
/// Network-level failures and non-2xx upstream statuses are warn'd here
|
||||||
|
/// (closest to the wire); the user-facing response carries only the
|
||||||
|
/// status code and a generic message — implementation detail (body,
|
||||||
|
/// error chain) lives in the log, never in the API surface.
|
||||||
pub async fn forward_request(
|
pub async fn forward_request(
|
||||||
client: &Client,
|
client: &Client,
|
||||||
route: &RouteDecision,
|
route: &RouteDecision,
|
||||||
@@ -37,10 +44,33 @@ pub async fn forward_request(
|
|||||||
req_builder = req_builder.header(key, value);
|
req_builder = req_builder.header(key, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
let upstream_resp = req_builder.send().await.map_err(ProxyError::Upstream)?;
|
let upstream_resp = match req_builder.send().await {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(
|
||||||
|
node = %route.node_name,
|
||||||
|
url = %url,
|
||||||
|
error = %e,
|
||||||
|
"proxy: upstream request failed (network)"
|
||||||
|
);
|
||||||
|
return Err(ProxyError::Upstream(e));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let status =
|
let upstream_status = upstream_resp.status();
|
||||||
StatusCode::from_u16(upstream_resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
|
if !upstream_status.is_success() {
|
||||||
|
// Streaming body — can't snippet without breaking the stream
|
||||||
|
// pass-through. Log status + URL; the client still gets the
|
||||||
|
// upstream status, just without the leaked body.
|
||||||
|
tracing::warn!(
|
||||||
|
node = %route.node_name,
|
||||||
|
url = %url,
|
||||||
|
status = upstream_status.as_u16(),
|
||||||
|
"proxy: upstream returned non-2xx"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let status = StatusCode::from_u16(upstream_status.as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
|
||||||
|
|
||||||
let resp_headers = upstream_resp.headers().clone();
|
let resp_headers = upstream_resp.headers().clone();
|
||||||
let stream = upstream_resp.bytes_stream();
|
let stream = upstream_resp.bytes_stream();
|
||||||
@@ -52,28 +82,37 @@ pub async fn forward_request(
|
|||||||
response = response.header(key, value);
|
response = response.header(key, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
response
|
response.body(body).map_err(|e| {
|
||||||
.body(body)
|
tracing::warn!(
|
||||||
.map_err(|e| ProxyError::ResponseBuild(e.to_string()))
|
node = %route.node_name,
|
||||||
|
url = %url,
|
||||||
|
error = %e,
|
||||||
|
"proxy: failed to build response"
|
||||||
|
);
|
||||||
|
ProxyError::ResponseBuild(e.to_string())
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, thiserror::Error)]
|
#[derive(Debug, thiserror::Error)]
|
||||||
pub enum ProxyError {
|
pub enum ProxyError {
|
||||||
#[error("upstream request failed: {0}")]
|
#[error("upstream request failed")]
|
||||||
Upstream(reqwest::Error),
|
Upstream(reqwest::Error),
|
||||||
#[error("failed to build response: {0}")]
|
#[error("failed to build response")]
|
||||||
ResponseBuild(String),
|
ResponseBuild(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl IntoResponse for ProxyError {
|
impl IntoResponse for ProxyError {
|
||||||
fn into_response(self) -> Response {
|
fn into_response(self) -> Response {
|
||||||
let status = match &self {
|
let (status, message) = match &self {
|
||||||
ProxyError::Upstream(_) => StatusCode::BAD_GATEWAY,
|
ProxyError::Upstream(_) => (StatusCode::BAD_GATEWAY, "upstream request failed"),
|
||||||
ProxyError::ResponseBuild(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
ProxyError::ResponseBuild(_) => (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
"failed to build response",
|
||||||
|
),
|
||||||
};
|
};
|
||||||
let body = serde_json::json!({
|
let body = serde_json::json!({
|
||||||
"error": {
|
"error": {
|
||||||
"message": self.to_string(),
|
"message": message,
|
||||||
"type": "proxy_error",
|
"type": "proxy_error",
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -2,74 +2,408 @@
|
|||||||
//!
|
//!
|
||||||
//! Given a model ID from an inbound request, determine which node should
|
//! Given a model ID from an inbound request, determine which node should
|
||||||
//! handle it. Priority:
|
//! handle it. Priority:
|
||||||
//! 1. Node where the model is currently `Loaded`
|
//! 1. Node where the model is currently `Loaded` → use it.
|
||||||
//! 2. Node where the model is `Unloaded` (will lazy-load on request)
|
//! 2. Node where the model is `Unloaded` → use it; neuron's existing
|
||||||
//! 3. Error: model not found on any node
|
//! lazy-load behaviour will reload before serving the request.
|
||||||
|
//! 3. Model is in the catalogue → pick a feasible neuron, call
|
||||||
|
//! `POST /models/load`, wait for the load to complete, then
|
||||||
|
//! proxy. First-request cold-load latency is acceptable per the
|
||||||
|
//! unified-endpoint contract.
|
||||||
|
//! 4. Not in catalogue, not loaded anywhere → 404.
|
||||||
|
|
||||||
use crate::state::CortexState;
|
use crate::state::CortexState;
|
||||||
|
use cortex_core::catalogue::ModelProfile;
|
||||||
|
use cortex_core::harness::ModelSpec;
|
||||||
use cortex_core::node::ModelStatus;
|
use cortex_core::node::ModelStatus;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
/// The routing decision: which node endpoint to proxy the request to.
|
/// The routing decision: which node endpoint to proxy the request to.
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct RouteDecision {
|
pub struct RouteDecision {
|
||||||
pub node_name: String,
|
pub node_name: String,
|
||||||
|
/// The inference endpoint to proxy to (from neuron's /models/{id}/endpoint).
|
||||||
pub endpoint: String,
|
pub endpoint: String,
|
||||||
/// Whether the model will need to load (cold start).
|
/// Whether the model will need to load (cold start). Set to true
|
||||||
|
/// when we proxied to an `Unloaded` node (lazy load on neuron) or
|
||||||
|
/// when we just triggered an explicit cold-load via the catalogue
|
||||||
|
/// path.
|
||||||
pub cold_start: bool,
|
pub cold_start: bool,
|
||||||
|
/// The concrete model id we actually routed to. Equal to the
|
||||||
|
/// caller's requested id unless an alias was resolved (e.g. caller
|
||||||
|
/// asked for `helexa/small`, this carries `Qwen/Qwen3-1.7B`). The
|
||||||
|
/// handler uses this to rewrite the request body's `model` field
|
||||||
|
/// before proxying — neurons reject requests where the body's
|
||||||
|
/// model name doesn't match a loaded model.
|
||||||
|
pub resolved_model_id: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, thiserror::Error)]
|
#[derive(Debug, thiserror::Error)]
|
||||||
pub enum RouteError {
|
pub enum RouteError {
|
||||||
#[error("model '{0}' not found on any node")]
|
#[error("model '{0}' not found on any node and not in catalogue")]
|
||||||
ModelNotFound(String),
|
ModelNotFound(String),
|
||||||
#[error("no healthy nodes available")]
|
#[error("no healthy nodes available")]
|
||||||
NoHealthyNodes,
|
NoHealthyNodes,
|
||||||
|
#[error("failed to resolve inference endpoint for model '{0}' on node '{1}'")]
|
||||||
|
EndpointResolveFailed(String, String),
|
||||||
|
#[error(
|
||||||
|
"model '{model_id}' is in the catalogue but no healthy neuron's topology satisfies its constraints"
|
||||||
|
)]
|
||||||
|
NoFeasibleNeuron { model_id: String },
|
||||||
|
#[error("cold-load of '{model_id}' on '{node}' failed: {message}")]
|
||||||
|
ColdLoadFailed {
|
||||||
|
model_id: String,
|
||||||
|
node: String,
|
||||||
|
message: String,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Resolve which node should serve a request for the given model.
|
/// Resolve which node should serve a request for the given model.
|
||||||
|
/// Asks the neuron for the inference endpoint after selecting a node.
|
||||||
pub async fn resolve(
|
pub async fn resolve(
|
||||||
fleet: &Arc<CortexState>,
|
fleet: &Arc<CortexState>,
|
||||||
model_id: &str,
|
requested_model_id: &str,
|
||||||
) -> Result<RouteDecision, RouteError> {
|
) -> Result<RouteDecision, RouteError> {
|
||||||
|
// Alias resolution first — swap `helexa/small` (etc.) for the
|
||||||
|
// concrete id before any node lookups so the rest of routing,
|
||||||
|
// loading, and metrics deal in concrete ids only. `resolve_alias`
|
||||||
|
// returns the input verbatim when it isn't an alias.
|
||||||
|
let model_id = fleet.catalogue.resolve_alias(requested_model_id);
|
||||||
|
if model_id != requested_model_id {
|
||||||
|
tracing::debug!(
|
||||||
|
requested = requested_model_id,
|
||||||
|
resolved = model_id,
|
||||||
|
"alias resolved"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
// Snapshot loaded / unloaded state from the poller cache.
|
||||||
|
let (loaded_route, unloaded_route, any_healthy) = {
|
||||||
let nodes = fleet.nodes.read().await;
|
let nodes = fleet.nodes.read().await;
|
||||||
|
let mut loaded_route = None;
|
||||||
// Pass 1: find a node where the model is already loaded.
|
let mut unloaded_route = None;
|
||||||
let mut loaded_candidate = None;
|
let mut any_healthy = false;
|
||||||
let mut unloaded_candidate = None;
|
|
||||||
|
|
||||||
for node in nodes.values() {
|
for node in nodes.values() {
|
||||||
if !node.healthy {
|
if !node.healthy {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
any_healthy = true;
|
||||||
if let Some(entry) = node.models.get(model_id) {
|
if let Some(entry) = node.models.get(model_id) {
|
||||||
match entry.status {
|
match entry.status {
|
||||||
ModelStatus::Loaded | ModelStatus::Reloading => {
|
ModelStatus::Loaded | ModelStatus::Reloading => {
|
||||||
loaded_candidate = Some(RouteDecision {
|
loaded_route = Some((node.name.clone(), node.endpoint.clone(), false));
|
||||||
node_name: node.name.clone(),
|
break;
|
||||||
endpoint: node.endpoint.clone(),
|
|
||||||
cold_start: false,
|
|
||||||
});
|
|
||||||
break; // loaded is best, stop searching
|
|
||||||
}
|
}
|
||||||
ModelStatus::Unloaded => {
|
ModelStatus::Unloaded => {
|
||||||
if unloaded_candidate.is_none() {
|
if unloaded_route.is_none() {
|
||||||
unloaded_candidate = Some(RouteDecision {
|
unloaded_route = Some((node.name.clone(), node.endpoint.clone(), true));
|
||||||
node_name: node.name.clone(),
|
}
|
||||||
endpoint: node.endpoint.clone(),
|
}
|
||||||
cold_start: true,
|
// Loading is gateway-synthesised from neuron's
|
||||||
});
|
// activation snapshot; it never appears on the
|
||||||
}
|
// wire from neuron's `/models`. Skip — the model
|
||||||
|
// isn't actually servable yet. The pre-existing
|
||||||
|
// race (catalogue cold_load fires a parallel
|
||||||
|
// /models/load against the in-flight load) is no
|
||||||
|
// worse than before; fixing it needs neuron-side
|
||||||
|
// in-flight tracking on /models/load itself.
|
||||||
|
ModelStatus::Loading => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
(loaded_route, unloaded_route, any_healthy)
|
||||||
|
};
|
||||||
|
|
||||||
|
if !any_healthy {
|
||||||
|
return Err(RouteError::NoHealthyNodes);
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded_candidate.or(unloaded_candidate).ok_or_else(|| {
|
// Priority 1: already loaded.
|
||||||
if nodes.values().any(|n| n.healthy) {
|
if let Some((node_name, neuron_endpoint, cold_start)) = loaded_route {
|
||||||
RouteError::ModelNotFound(model_id.to_string())
|
return finish(fleet, &node_name, &neuron_endpoint, model_id, cold_start).await;
|
||||||
} else {
|
|
||||||
RouteError::NoHealthyNodes
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Priority 2: known to neuron but unloaded (neuron's lazy load).
|
||||||
|
if let Some((node_name, neuron_endpoint, cold_start)) = unloaded_route {
|
||||||
|
return finish(fleet, &node_name, &neuron_endpoint, model_id, cold_start).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Priority 3: catalogue × topology cold-load.
|
||||||
|
if let Some(profile) = fleet.catalogue.get(model_id) {
|
||||||
|
let (node_name, neuron_endpoint) = pick_feasible_neuron(fleet, profile).await?;
|
||||||
|
cold_load(fleet, &node_name, &neuron_endpoint, profile).await?;
|
||||||
|
return finish(fleet, &node_name, &neuron_endpoint, model_id, true).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(RouteError::ModelNotFound(model_id.to_string()))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Pick a healthy neuron whose discovered topology satisfies the
|
||||||
|
/// profile. Preference order:
|
||||||
|
/// 1. A neuron from `profile.pinned_on` that is healthy + feasible.
|
||||||
|
/// 2. Otherwise, any healthy + feasible neuron, stable by name.
|
||||||
|
async fn pick_feasible_neuron(
|
||||||
|
fleet: &Arc<CortexState>,
|
||||||
|
profile: &ModelProfile,
|
||||||
|
) -> Result<(String, String), RouteError> {
|
||||||
|
let nodes = fleet.nodes.read().await;
|
||||||
|
let mut candidates: Vec<(String, String, bool)> = Vec::new();
|
||||||
|
for node in nodes.values() {
|
||||||
|
if !node.healthy {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let Some(disc) = node.discovery.as_ref() else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
if !profile.is_feasible_on(&node.name, &disc.devices) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let pinned = profile.pinned_on.iter().any(|n| n == &node.name);
|
||||||
|
candidates.push((node.name.clone(), node.endpoint.clone(), pinned));
|
||||||
|
}
|
||||||
|
candidates.sort_by(|a, b| {
|
||||||
|
b.2.cmp(&a.2) // pinned first (true > false)
|
||||||
|
.then(a.0.cmp(&b.0))
|
||||||
|
});
|
||||||
|
let pick = candidates.into_iter().next();
|
||||||
|
pick.map(|(n, e, _)| (n, e))
|
||||||
|
.ok_or_else(|| RouteError::NoFeasibleNeuron {
|
||||||
|
model_id: profile.id.clone(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Issue `POST {endpoint}/models/load` for this profile on this neuron,
|
||||||
|
/// blocking until the load completes (neuron's load endpoint is
|
||||||
|
/// synchronous — it returns 200 once VRAM is materialised). On success
|
||||||
|
/// also inserts a `Loaded` entry into the local NodeState cache so the
|
||||||
|
/// caller's subsequent endpoint lookup sees the new model without
|
||||||
|
/// waiting for the next poll cycle.
|
||||||
|
async fn cold_load(
|
||||||
|
fleet: &Arc<CortexState>,
|
||||||
|
node_name: &str,
|
||||||
|
neuron_endpoint: &str,
|
||||||
|
profile: &ModelProfile,
|
||||||
|
) -> Result<(), RouteError> {
|
||||||
|
let spec = profile_to_spec(fleet, node_name, profile).await;
|
||||||
|
let url = format!("{neuron_endpoint}/models/load");
|
||||||
|
tracing::info!(model = %profile.id, node = node_name, "cold-loading via /models/load");
|
||||||
|
|
||||||
|
// Generous timeout: a fresh download + safetensors mmap + device
|
||||||
|
// copy for a 30B-class dense model can comfortably exceed 5 min on
|
||||||
|
// a slow link. The HTTP client's own default already covers most
|
||||||
|
// of this; pin a longer per-request bound just here.
|
||||||
|
let resp = match fleet
|
||||||
|
.http_client
|
||||||
|
.post(&url)
|
||||||
|
.timeout(Duration::from_secs(1800))
|
||||||
|
.json(&spec)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => {
|
||||||
|
return Err(RouteError::ColdLoadFailed {
|
||||||
|
model_id: profile.id.clone(),
|
||||||
|
node: node_name.to_string(),
|
||||||
|
message: format!("HTTP request failed: {e}"),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let status = resp.status();
|
||||||
|
if !status.is_success() {
|
||||||
|
let body = resp.text().await.unwrap_or_default();
|
||||||
|
// Neuron returns 400 "already loaded" when two concurrent
|
||||||
|
// requests race the same model. Treat that as success — both
|
||||||
|
// requests effectively achieved the same end state.
|
||||||
|
if body.contains("already loaded") {
|
||||||
|
tracing::info!(
|
||||||
|
model = %profile.id,
|
||||||
|
node = node_name,
|
||||||
|
"cold-load saw 'already loaded' — treating as success"
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
return Err(RouteError::ColdLoadFailed {
|
||||||
|
model_id: profile.id.clone(),
|
||||||
|
node: node_name.to_string(),
|
||||||
|
message: format!("HTTP {status}: {body}"),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
tracing::info!(model = %profile.id, node = node_name, "cold-load returned 200");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Warm the cache: insert a Loaded ModelEntry so the next
|
||||||
|
// resolve() finds the model without waiting for the poll loop.
|
||||||
|
{
|
||||||
|
let mut nodes = fleet.nodes.write().await;
|
||||||
|
if let Some(node) = nodes.get_mut(node_name) {
|
||||||
|
node.models.insert(
|
||||||
|
profile.id.clone(),
|
||||||
|
cortex_core::node::ModelEntry {
|
||||||
|
id: profile.id.clone(),
|
||||||
|
status: ModelStatus::Loaded,
|
||||||
|
last_accessed: Some(chrono::Utc::now()),
|
||||||
|
vram_estimate_mb: profile.vram_mb,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Translate a `ModelProfile` to a `ModelSpec` neuron's /models/load
|
||||||
|
/// accepts. Devices are picked from the neuron's discovered topology —
|
||||||
|
/// the first `min_devices` indices that meet `min_device_vram_mb`.
|
||||||
|
async fn profile_to_spec(
|
||||||
|
fleet: &Arc<CortexState>,
|
||||||
|
node_name: &str,
|
||||||
|
profile: &ModelProfile,
|
||||||
|
) -> ModelSpec {
|
||||||
|
let devices = {
|
||||||
|
let nodes = fleet.nodes.read().await;
|
||||||
|
let mut picked: Vec<u32> = Vec::new();
|
||||||
|
if let Some(node) = nodes.get(node_name)
|
||||||
|
&& let Some(disc) = &node.discovery
|
||||||
|
{
|
||||||
|
let min_vram = profile.min_device_vram_mb.unwrap_or(0);
|
||||||
|
for d in &disc.devices {
|
||||||
|
if d.vram_total_mb >= min_vram {
|
||||||
|
picked.push(d.index);
|
||||||
|
if picked.len() as u32 >= profile.min_devices {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if picked.is_empty() {
|
||||||
|
// Fall back to a 0..min_devices default; pick_feasible_neuron
|
||||||
|
// already verified the topology satisfies the constraints,
|
||||||
|
// so this only fires if discovery raced or was lost.
|
||||||
|
(0..profile.min_devices).collect()
|
||||||
|
} else {
|
||||||
|
picked
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let tensor_parallel = if profile.min_devices > 1 {
|
||||||
|
Some(profile.min_devices)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
ModelSpec {
|
||||||
|
model_id: profile.id.clone(),
|
||||||
|
harness: profile.harness.clone(),
|
||||||
|
quant: profile.quant.clone(),
|
||||||
|
tensor_parallel,
|
||||||
|
devices: Some(devices),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Resolve neuron's `/models/{id}/endpoint` to its inference URL and
|
||||||
|
/// build the final `RouteDecision`. Shared by all three priority
|
||||||
|
/// branches above.
|
||||||
|
async fn finish(
|
||||||
|
fleet: &Arc<CortexState>,
|
||||||
|
node_name: &str,
|
||||||
|
neuron_endpoint: &str,
|
||||||
|
model_id: &str,
|
||||||
|
cold_start: bool,
|
||||||
|
) -> Result<RouteDecision, RouteError> {
|
||||||
|
let endpoint_url = format!(
|
||||||
|
"{}/models/{}/endpoint",
|
||||||
|
neuron_endpoint,
|
||||||
|
urlencoding::encode(model_id)
|
||||||
|
);
|
||||||
|
|
||||||
|
let inference_endpoint = match fleet.http_client.get(&endpoint_url).send().await {
|
||||||
|
Ok(resp) if resp.status().is_success() => match resp.json::<serde_json::Value>().await {
|
||||||
|
Ok(body) => body
|
||||||
|
.get("url")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.map(|s| s.to_string()),
|
||||||
|
Err(_) => None,
|
||||||
|
},
|
||||||
|
_ => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let raw = inference_endpoint.ok_or_else(|| {
|
||||||
|
RouteError::EndpointResolveFailed(model_id.to_string(), node_name.to_string())
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// Rewrite loopback inference URLs to use the configured neuron host.
|
||||||
|
// Neuron's default bind_url is `http://localhost:13131` (it can't
|
||||||
|
// reliably know its own externally-resolvable name). Cortex sees a
|
||||||
|
// URL that's only meaningful from the neuron host's own perspective;
|
||||||
|
// proxying directly to localhost from a different cortex host would
|
||||||
|
// hit nothing. Keep neuron's port and path (a future harness could
|
||||||
|
// serve inference on a different port than the management API), but
|
||||||
|
// swap the host for the one in cortex.toml.
|
||||||
|
let endpoint = rewrite_loopback_host(&raw, neuron_endpoint).unwrap_or(raw);
|
||||||
|
|
||||||
|
Ok(RouteDecision {
|
||||||
|
node_name: node_name.to_string(),
|
||||||
|
endpoint,
|
||||||
|
cold_start,
|
||||||
|
resolved_model_id: model_id.to_string(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// If `inference_url`'s host is a loopback name (localhost / 127.0.0.1 /
|
||||||
|
/// 0.0.0.0 / ::1), return a copy with the host replaced by
|
||||||
|
/// `neuron_endpoint`'s host. Otherwise return None and the caller falls
|
||||||
|
/// back to the inference URL as-is.
|
||||||
|
fn rewrite_loopback_host(inference_url: &str, neuron_endpoint: &str) -> Option<String> {
|
||||||
|
let inf = url::Url::parse(inference_url).ok()?;
|
||||||
|
let inf_host = inf.host_str()?;
|
||||||
|
let is_loopback = matches!(inf_host, "localhost" | "127.0.0.1" | "0.0.0.0" | "::1");
|
||||||
|
if !is_loopback {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let neuron = url::Url::parse(neuron_endpoint).ok()?;
|
||||||
|
let new_host = neuron.host_str()?;
|
||||||
|
let mut out = inf.clone();
|
||||||
|
out.set_host(Some(new_host)).ok()?;
|
||||||
|
// url::Url::to_string normalises an empty path to "/", which then
|
||||||
|
// breaks downstream callers that do format!("{endpoint}/v1/...")
|
||||||
|
// and produce a double slash. The proxy URL is treated as a base
|
||||||
|
// string that the caller appends paths to, so strip the trailing
|
||||||
|
// slash here.
|
||||||
|
let s = out.to_string();
|
||||||
|
Some(s.trim_end_matches('/').to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::rewrite_loopback_host;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn rewrites_localhost_keeps_port_and_path() {
|
||||||
|
let out = rewrite_loopback_host(
|
||||||
|
"http://localhost:13131",
|
||||||
|
"http://beast.hanzalova.internal:13131",
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
out.as_deref(),
|
||||||
|
Some("http://beast.hanzalova.internal:13131")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn rewrites_loopback_with_distinct_inference_port() {
|
||||||
|
let out = rewrite_loopback_host("http://127.0.0.1:8080", "http://beast.lan:13131");
|
||||||
|
assert_eq!(out.as_deref(), Some("http://beast.lan:8080"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn leaves_non_loopback_alone() {
|
||||||
|
let out = rewrite_loopback_host("http://other.host:1234", "http://beast.lan:13131");
|
||||||
|
assert_eq!(out, None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn malformed_inference_url_returns_none() {
|
||||||
|
let out = rewrite_loopback_host("not a url", "http://beast.lan:13131");
|
||||||
|
assert_eq!(out, None);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
use cortex_core::config::{EvictionSettings, GatewayConfig, NodeConfig};
|
use cortex_core::catalogue::ModelCatalogue;
|
||||||
|
use cortex_core::config::{EvictionSettings, GatewayConfig, NeuronEndpoint};
|
||||||
use cortex_core::node::NodeState;
|
use cortex_core::node::NodeState;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
@@ -6,34 +7,38 @@ use tokio::sync::RwLock;
|
|||||||
/// Shared fleet state, protected by a RwLock for concurrent reader access.
|
/// Shared fleet state, protected by a RwLock for concurrent reader access.
|
||||||
pub struct CortexState {
|
pub struct CortexState {
|
||||||
pub nodes: RwLock<HashMap<String, NodeState>>,
|
pub nodes: RwLock<HashMap<String, NodeState>>,
|
||||||
pub node_configs: Vec<NodeConfig>,
|
pub neuron_configs: Vec<NeuronEndpoint>,
|
||||||
pub eviction: EvictionSettings,
|
pub eviction: EvictionSettings,
|
||||||
|
pub catalogue: ModelCatalogue,
|
||||||
pub http_client: reqwest::Client,
|
pub http_client: reqwest::Client,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CortexState {
|
impl CortexState {
|
||||||
pub fn from_config(config: &GatewayConfig) -> Self {
|
pub fn from_config(config: &GatewayConfig) -> Self {
|
||||||
let mut nodes = HashMap::new();
|
let mut nodes = HashMap::new();
|
||||||
for nc in &config.nodes {
|
for nc in &config.neurons {
|
||||||
nodes.insert(
|
nodes.insert(
|
||||||
nc.name.clone(),
|
nc.name.clone(),
|
||||||
NodeState {
|
NodeState {
|
||||||
name: nc.name.clone(),
|
name: nc.name.clone(),
|
||||||
endpoint: nc.endpoint.clone(),
|
endpoint: nc.endpoint.clone(),
|
||||||
vram_mb: nc.vram_mb,
|
healthy: false,
|
||||||
pinned: nc.pinned.clone(),
|
|
||||||
healthy: false, // will be set by first poll
|
|
||||||
models: HashMap::new(),
|
models: HashMap::new(),
|
||||||
lifecycle_cycles: 0,
|
lifecycle_cycles: 0,
|
||||||
last_poll: None,
|
last_poll: None,
|
||||||
|
discovery: None,
|
||||||
|
activation: None,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let catalogue = ModelCatalogue::load(&config.models_config);
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
nodes: RwLock::new(nodes),
|
nodes: RwLock::new(nodes),
|
||||||
node_configs: config.nodes.clone(),
|
neuron_configs: config.neurons.clone(),
|
||||||
eviction: config.eviction.clone(),
|
eviction: config.eviction.clone(),
|
||||||
|
catalogue,
|
||||||
http_client: reqwest::Client::builder()
|
http_client: reqwest::Client::builder()
|
||||||
.timeout(std::time::Duration::from_secs(300))
|
.timeout(std::time::Duration::from_secs(300))
|
||||||
.build()
|
.build()
|
||||||
|
|||||||
265
crates/cortex-gateway/tests/aliases.rs
Normal file
265
crates/cortex-gateway/tests/aliases.rs
Normal file
@@ -0,0 +1,265 @@
|
|||||||
|
//! Alias resolution: a client request with `model: "helexa/small"`
|
||||||
|
//! routes to the concrete model id (e.g. `Qwen/Qwen3-1.7B`), with the
|
||||||
|
//! proxied request body rewritten so the upstream neuron sees a model
|
||||||
|
//! name that matches its loaded handle.
|
||||||
|
|
||||||
|
mod common;
|
||||||
|
|
||||||
|
use cortex_core::config::{
|
||||||
|
EvictionSettings, EvictionStrategy, GatewayConfig, GatewaySettings, NeuronEndpoint,
|
||||||
|
};
|
||||||
|
use cortex_core::node::{ModelEntry, ModelStatus};
|
||||||
|
use cortex_gateway::state::CortexState;
|
||||||
|
use serde_json::json;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::net::TcpListener;
|
||||||
|
|
||||||
|
/// Write a `models.toml` with one alias to a unique temp path. Returns
|
||||||
|
/// the path; the file persists for the test process and gets reaped by
|
||||||
|
/// the OS at exit. Using $XDG_RUNTIME_DIR fallback for the temp dir
|
||||||
|
/// keeps the file off shared /tmp on CI without pulling in tempfile.
|
||||||
|
fn write_models_toml(alias: &str, target: &str) -> PathBuf {
|
||||||
|
let contents = format!(
|
||||||
|
r#"
|
||||||
|
[aliases]
|
||||||
|
"{alias}" = "{target}"
|
||||||
|
"#
|
||||||
|
);
|
||||||
|
let mut path = std::env::temp_dir();
|
||||||
|
let pid = std::process::id();
|
||||||
|
let now = std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap()
|
||||||
|
.as_nanos();
|
||||||
|
path.push(format!("cortex-test-models-{pid}-{now}.toml"));
|
||||||
|
std::fs::write(&path, contents).expect("write temp models.toml");
|
||||||
|
path
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_alias_resolves_in_chat_completions() {
|
||||||
|
let mock_url = common::spawn_mock_neuron().await;
|
||||||
|
let models_path = write_models_toml("helexa/small", "test-model");
|
||||||
|
|
||||||
|
let config = GatewayConfig {
|
||||||
|
gateway: GatewaySettings {
|
||||||
|
listen: "127.0.0.1:0".into(),
|
||||||
|
metrics_listen: "127.0.0.1:0".into(),
|
||||||
|
},
|
||||||
|
eviction: EvictionSettings {
|
||||||
|
strategy: EvictionStrategy::Lru,
|
||||||
|
defrag_after_cycles: 0,
|
||||||
|
},
|
||||||
|
neurons: vec![NeuronEndpoint {
|
||||||
|
name: "mock-node".into(),
|
||||||
|
endpoint: mock_url,
|
||||||
|
}],
|
||||||
|
models_config: models_path.to_string_lossy().to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let fleet = Arc::new(CortexState::from_config(&config));
|
||||||
|
|
||||||
|
// Seed the node as healthy with the concrete model loaded under
|
||||||
|
// the target id. The poller doesn't run in this test; we just
|
||||||
|
// populate state manually.
|
||||||
|
{
|
||||||
|
let mut nodes = fleet.nodes.write().await;
|
||||||
|
let node = nodes.get_mut("mock-node").expect("node must exist");
|
||||||
|
node.healthy = true;
|
||||||
|
node.models.insert(
|
||||||
|
"test-model".into(),
|
||||||
|
ModelEntry {
|
||||||
|
id: "test-model".into(),
|
||||||
|
status: ModelStatus::Loaded,
|
||||||
|
last_accessed: None,
|
||||||
|
vram_estimate_mb: None,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sanity: the catalogue actually picked up the alias.
|
||||||
|
assert_eq!(
|
||||||
|
fleet.catalogue.resolve_alias("helexa/small"),
|
||||||
|
"test-model",
|
||||||
|
"alias should resolve to target id"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Spawn the gateway against this fleet.
|
||||||
|
let app = cortex_gateway::build_app(Arc::clone(&fleet));
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let gateway_addr = listener.local_addr().unwrap();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
axum::serve(listener, app).await.unwrap();
|
||||||
|
});
|
||||||
|
let gateway_url = format!("http://{gateway_addr}");
|
||||||
|
|
||||||
|
// Send a chat completion against the alias. The mock backend
|
||||||
|
// echoes back the `model` field it received — so a body whose
|
||||||
|
// model wasn't rewritten would come back as "helexa/small", and a
|
||||||
|
// properly-rewritten one as "test-model".
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
let resp = client
|
||||||
|
.post(format!("{gateway_url}/v1/chat/completions"))
|
||||||
|
.json(&json!({
|
||||||
|
"model": "helexa/small",
|
||||||
|
"messages": [{"role": "user", "content": "hi"}],
|
||||||
|
}))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.expect("gateway should respond");
|
||||||
|
|
||||||
|
assert!(resp.status().is_success(), "gateway returned non-2xx");
|
||||||
|
let body: serde_json::Value = resp.json().await.expect("response is JSON");
|
||||||
|
assert_eq!(
|
||||||
|
body.get("model").and_then(|m| m.as_str()),
|
||||||
|
Some("test-model"),
|
||||||
|
"mock backend should have seen the resolved model id, not the alias"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_aliases_surface_in_v1_models() {
|
||||||
|
let mock_url = common::spawn_mock_neuron().await;
|
||||||
|
let models_path = write_models_toml("helexa/small", "test-model");
|
||||||
|
|
||||||
|
let config = GatewayConfig {
|
||||||
|
gateway: GatewaySettings {
|
||||||
|
listen: "127.0.0.1:0".into(),
|
||||||
|
metrics_listen: "127.0.0.1:0".into(),
|
||||||
|
},
|
||||||
|
eviction: EvictionSettings {
|
||||||
|
strategy: EvictionStrategy::Lru,
|
||||||
|
defrag_after_cycles: 0,
|
||||||
|
},
|
||||||
|
neurons: vec![NeuronEndpoint {
|
||||||
|
name: "mock-node".into(),
|
||||||
|
endpoint: mock_url,
|
||||||
|
}],
|
||||||
|
models_config: models_path.to_string_lossy().to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let fleet = Arc::new(CortexState::from_config(&config));
|
||||||
|
|
||||||
|
// Seed the target as loaded so the alias's mirrored entry shows
|
||||||
|
// loaded=true.
|
||||||
|
{
|
||||||
|
let mut nodes = fleet.nodes.write().await;
|
||||||
|
let node = nodes.get_mut("mock-node").expect("node must exist");
|
||||||
|
node.healthy = true;
|
||||||
|
node.models.insert(
|
||||||
|
"test-model".into(),
|
||||||
|
ModelEntry {
|
||||||
|
id: "test-model".into(),
|
||||||
|
status: ModelStatus::Loaded,
|
||||||
|
last_accessed: None,
|
||||||
|
vram_estimate_mb: Some(2000),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let app = cortex_gateway::build_app(Arc::clone(&fleet));
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let gateway_addr = listener.local_addr().unwrap();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
axum::serve(listener, app).await.unwrap();
|
||||||
|
});
|
||||||
|
let gateway_url = format!("http://{gateway_addr}");
|
||||||
|
|
||||||
|
let resp = reqwest::get(format!("{gateway_url}/v1/models"))
|
||||||
|
.await
|
||||||
|
.expect("gateway should respond");
|
||||||
|
let body: serde_json::Value = resp.json().await.unwrap();
|
||||||
|
let entries = body
|
||||||
|
.get("data")
|
||||||
|
.and_then(|d| d.as_array())
|
||||||
|
.expect("data array");
|
||||||
|
|
||||||
|
// Both the alias and the target should be present.
|
||||||
|
let ids: Vec<&str> = entries
|
||||||
|
.iter()
|
||||||
|
.filter_map(|e| e.get("id").and_then(|v| v.as_str()))
|
||||||
|
.collect();
|
||||||
|
assert!(ids.contains(&"test-model"), "target should be listed");
|
||||||
|
assert!(ids.contains(&"helexa/small"), "alias should be listed");
|
||||||
|
|
||||||
|
// The alias's `loaded` flag and locations should mirror the target.
|
||||||
|
let alias_entry = entries
|
||||||
|
.iter()
|
||||||
|
.find(|e| e.get("id").and_then(|v| v.as_str()) == Some("helexa/small"))
|
||||||
|
.expect("alias entry");
|
||||||
|
assert_eq!(alias_entry.get("loaded"), Some(&json!(true)));
|
||||||
|
let locations = alias_entry
|
||||||
|
.get("locations")
|
||||||
|
.and_then(|l| l.as_array())
|
||||||
|
.expect("locations array");
|
||||||
|
assert_eq!(locations.len(), 1);
|
||||||
|
assert_eq!(
|
||||||
|
locations[0].get("node").and_then(|n| n.as_str()),
|
||||||
|
Some("mock-node")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_alias_falls_through_for_unmapped_model() {
|
||||||
|
// Catalogue has an alias for some-other-thing but the request
|
||||||
|
// model "test-model" isn't an alias; resolution should be a no-op.
|
||||||
|
let mock_url = common::spawn_mock_neuron().await;
|
||||||
|
let models_path = write_models_toml("helexa/large", "definitely-not-loaded");
|
||||||
|
|
||||||
|
let config = GatewayConfig {
|
||||||
|
gateway: GatewaySettings {
|
||||||
|
listen: "127.0.0.1:0".into(),
|
||||||
|
metrics_listen: "127.0.0.1:0".into(),
|
||||||
|
},
|
||||||
|
eviction: EvictionSettings {
|
||||||
|
strategy: EvictionStrategy::Lru,
|
||||||
|
defrag_after_cycles: 0,
|
||||||
|
},
|
||||||
|
neurons: vec![NeuronEndpoint {
|
||||||
|
name: "mock-node".into(),
|
||||||
|
endpoint: mock_url,
|
||||||
|
}],
|
||||||
|
models_config: models_path.to_string_lossy().to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let fleet = Arc::new(CortexState::from_config(&config));
|
||||||
|
{
|
||||||
|
let mut nodes = fleet.nodes.write().await;
|
||||||
|
let node = nodes.get_mut("mock-node").expect("node must exist");
|
||||||
|
node.healthy = true;
|
||||||
|
node.models.insert(
|
||||||
|
"test-model".into(),
|
||||||
|
ModelEntry {
|
||||||
|
id: "test-model".into(),
|
||||||
|
status: ModelStatus::Loaded,
|
||||||
|
last_accessed: None,
|
||||||
|
vram_estimate_mb: None,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let app = cortex_gateway::build_app(Arc::clone(&fleet));
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let gateway_addr = listener.local_addr().unwrap();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
axum::serve(listener, app).await.unwrap();
|
||||||
|
});
|
||||||
|
let gateway_url = format!("http://{gateway_addr}");
|
||||||
|
|
||||||
|
let resp = reqwest::Client::new()
|
||||||
|
.post(format!("{gateway_url}/v1/chat/completions"))
|
||||||
|
.json(&json!({
|
||||||
|
"model": "test-model",
|
||||||
|
"messages": [{"role": "user", "content": "hi"}],
|
||||||
|
}))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(resp.status().is_success());
|
||||||
|
let body: serde_json::Value = resp.json().await.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
body.get("model").and_then(|m| m.as_str()),
|
||||||
|
Some("test-model")
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -4,7 +4,7 @@ use serde_json::json;
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_anthropic_to_openai_round_trip() {
|
async fn test_anthropic_to_openai_round_trip() {
|
||||||
let mock_url = common::spawn_mock_backend().await;
|
let mock_url = common::spawn_mock_neuron().await;
|
||||||
let gw_url = common::spawn_gateway(&mock_url).await;
|
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
@@ -14,9 +14,7 @@ async fn test_anthropic_to_openai_round_trip() {
|
|||||||
.json(&json!({
|
.json(&json!({
|
||||||
"model": "test-model",
|
"model": "test-model",
|
||||||
"max_tokens": 100,
|
"max_tokens": 100,
|
||||||
"messages": [
|
"messages": [{"role": "user", "content": "Hi"}]
|
||||||
{"role": "user", "content": "Hi"}
|
|
||||||
]
|
|
||||||
}))
|
}))
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
@@ -25,29 +23,22 @@ async fn test_anthropic_to_openai_round_trip() {
|
|||||||
assert_eq!(resp.status(), 200);
|
assert_eq!(resp.status(), 200);
|
||||||
|
|
||||||
let body: serde_json::Value = resp.json().await.expect("valid JSON");
|
let body: serde_json::Value = resp.json().await.expect("valid JSON");
|
||||||
|
|
||||||
// Response should be in Anthropic format.
|
|
||||||
assert_eq!(body["type"], "message");
|
assert_eq!(body["type"], "message");
|
||||||
assert_eq!(body["role"], "assistant");
|
assert_eq!(body["role"], "assistant");
|
||||||
assert_eq!(body["model"], "test-model");
|
assert_eq!(body["model"], "test-model");
|
||||||
|
|
||||||
// Content should be an array of content blocks.
|
|
||||||
let content = body["content"].as_array().expect("content array");
|
let content = body["content"].as_array().expect("content array");
|
||||||
assert_eq!(content.len(), 1);
|
assert_eq!(content.len(), 1);
|
||||||
assert_eq!(content[0]["type"], "text");
|
assert_eq!(content[0]["type"], "text");
|
||||||
assert_eq!(content[0]["text"], "Hello from mock backend");
|
assert_eq!(content[0]["text"], "Hello from mock backend");
|
||||||
|
|
||||||
// Stop reason should be translated from "stop" to "end_turn".
|
|
||||||
assert_eq!(body["stop_reason"], "end_turn");
|
assert_eq!(body["stop_reason"], "end_turn");
|
||||||
|
|
||||||
// Usage should have Anthropic field names.
|
|
||||||
assert_eq!(body["usage"]["input_tokens"], 10);
|
assert_eq!(body["usage"]["input_tokens"], 10);
|
||||||
assert_eq!(body["usage"]["output_tokens"], 5);
|
assert_eq!(body["usage"]["output_tokens"], 5);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_anthropic_with_system_prompt() {
|
async fn test_anthropic_with_system_prompt() {
|
||||||
let mock_url = common::spawn_mock_backend().await;
|
let mock_url = common::spawn_mock_neuron().await;
|
||||||
let gw_url = common::spawn_gateway(&mock_url).await;
|
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
@@ -58,24 +49,20 @@ async fn test_anthropic_with_system_prompt() {
|
|||||||
"model": "test-model",
|
"model": "test-model",
|
||||||
"max_tokens": 100,
|
"max_tokens": 100,
|
||||||
"system": "You are a helpful assistant.",
|
"system": "You are a helpful assistant.",
|
||||||
"messages": [
|
"messages": [{"role": "user", "content": "Hi"}]
|
||||||
{"role": "user", "content": "Hi"}
|
|
||||||
]
|
|
||||||
}))
|
}))
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
.expect("request should succeed");
|
.expect("request should succeed");
|
||||||
|
|
||||||
assert_eq!(resp.status(), 200);
|
assert_eq!(resp.status(), 200);
|
||||||
|
|
||||||
let body: serde_json::Value = resp.json().await.expect("valid JSON");
|
let body: serde_json::Value = resp.json().await.expect("valid JSON");
|
||||||
assert_eq!(body["type"], "message");
|
assert_eq!(body["type"], "message");
|
||||||
assert_eq!(body["content"][0]["text"], "Hello from mock backend");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_anthropic_with_content_blocks() {
|
async fn test_anthropic_with_content_blocks() {
|
||||||
let mock_url = common::spawn_mock_backend().await;
|
let mock_url = common::spawn_mock_neuron().await;
|
||||||
let gw_url = common::spawn_gateway(&mock_url).await;
|
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
@@ -85,29 +72,23 @@ async fn test_anthropic_with_content_blocks() {
|
|||||||
.json(&json!({
|
.json(&json!({
|
||||||
"model": "test-model",
|
"model": "test-model",
|
||||||
"max_tokens": 100,
|
"max_tokens": 100,
|
||||||
"messages": [
|
"messages": [{
|
||||||
{
|
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": [
|
"content": [{"type": "text", "text": "What is this?"}]
|
||||||
{"type": "text", "text": "What is this?"}
|
}]
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}))
|
}))
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
.expect("request should succeed");
|
.expect("request should succeed");
|
||||||
|
|
||||||
assert_eq!(resp.status(), 200);
|
assert_eq!(resp.status(), 200);
|
||||||
|
|
||||||
let body: serde_json::Value = resp.json().await.expect("valid JSON");
|
let body: serde_json::Value = resp.json().await.expect("valid JSON");
|
||||||
assert_eq!(body["type"], "message");
|
assert_eq!(body["type"], "message");
|
||||||
assert_eq!(body["content"][0]["text"], "Hello from mock backend");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_anthropic_model_not_found() {
|
async fn test_anthropic_model_not_found() {
|
||||||
let mock_url = common::spawn_mock_backend().await;
|
let mock_url = common::spawn_mock_neuron().await;
|
||||||
let gw_url = common::spawn_gateway(&mock_url).await;
|
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
@@ -117,9 +98,7 @@ async fn test_anthropic_model_not_found() {
|
|||||||
.json(&json!({
|
.json(&json!({
|
||||||
"model": "nonexistent",
|
"model": "nonexistent",
|
||||||
"max_tokens": 100,
|
"max_tokens": 100,
|
||||||
"messages": [
|
"messages": [{"role": "user", "content": "Hi"}]
|
||||||
{"role": "user", "content": "Hi"}
|
|
||||||
]
|
|
||||||
}))
|
}))
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
@@ -130,27 +109,17 @@ async fn test_anthropic_model_not_found() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_anthropic_invalid_request() {
|
async fn test_anthropic_invalid_request() {
|
||||||
let mock_url = common::spawn_mock_backend().await;
|
let mock_url = common::spawn_mock_neuron().await;
|
||||||
let gw_url = common::spawn_gateway(&mock_url).await;
|
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
let resp = client
|
let resp = client
|
||||||
.post(format!("{gw_url}/v1/messages"))
|
.post(format!("{gw_url}/v1/messages"))
|
||||||
.header("content-type", "application/json")
|
.header("content-type", "application/json")
|
||||||
.json(&json!({
|
.json(&json!({"not_a_valid": "request"}))
|
||||||
"not_a_valid": "request"
|
|
||||||
}))
|
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
.expect("request should succeed");
|
.expect("request should succeed");
|
||||||
|
|
||||||
assert_eq!(resp.status(), 400);
|
assert_eq!(resp.status(), 400);
|
||||||
|
|
||||||
let body: serde_json::Value = resp.json().await.unwrap();
|
|
||||||
assert!(
|
|
||||||
body["error"]["message"]
|
|
||||||
.as_str()
|
|
||||||
.unwrap()
|
|
||||||
.contains("invalid Anthropic request")
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
#![allow(dead_code)]
|
#![allow(dead_code)]
|
||||||
|
|
||||||
use axum::body::Body;
|
use axum::body::Body;
|
||||||
|
use axum::extract::Path;
|
||||||
use axum::http::header;
|
use axum::http::header;
|
||||||
use axum::response::Response;
|
use axum::response::Response;
|
||||||
use axum::routing::{get, post};
|
use axum::routing::{get, post};
|
||||||
use axum::{Json, Router};
|
use axum::{Json, Router};
|
||||||
use cortex_core::config::{
|
use cortex_core::config::{
|
||||||
EvictionSettings, EvictionStrategy, GatewayConfig, GatewaySettings, NodeConfig,
|
EvictionSettings, EvictionStrategy, GatewayConfig, GatewaySettings, NeuronEndpoint,
|
||||||
};
|
};
|
||||||
use cortex_core::node::{ModelEntry, ModelStatus};
|
use cortex_core::node::{ModelEntry, ModelStatus};
|
||||||
use cortex_gateway::state::CortexState;
|
use cortex_gateway::state::CortexState;
|
||||||
@@ -16,20 +17,54 @@ use std::sync::Arc;
|
|||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
|
|
||||||
/// Spawns a mock mistral.rs backend on a random port.
|
/// Spawns a mock neuron that serves:
|
||||||
/// Returns the base URL (e.g. "http://127.0.0.1:12345").
|
/// - GET /models (returns one loaded "test-model")
|
||||||
pub async fn spawn_mock_backend() -> String {
|
/// - GET /models/:id/endpoint (returns the inference URL)
|
||||||
let app = Router::new()
|
/// - POST /models/unload (accepts unload requests)
|
||||||
.route("/v1/chat/completions", post(mock_chat_completions))
|
/// - GET /v1/chat/completions + POST /v1/chat/completions (inference)
|
||||||
.route("/v1/models", get(mock_list_models));
|
///
|
||||||
|
/// Returns the neuron base URL.
|
||||||
|
pub async fn spawn_mock_neuron() -> String {
|
||||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
let addr = listener.local_addr().unwrap();
|
let addr = listener.local_addr().unwrap();
|
||||||
|
let base_url = format!("http://{addr}");
|
||||||
|
let inference_url = base_url.clone();
|
||||||
|
|
||||||
|
let app = Router::new()
|
||||||
|
.route("/models", get(mock_neuron_list_models))
|
||||||
|
.route(
|
||||||
|
"/models/{model_id}/endpoint",
|
||||||
|
get(move |Path(_model_id): Path<String>| {
|
||||||
|
let url = inference_url.clone();
|
||||||
|
async move { Json(json!({"url": url})) }
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.route(
|
||||||
|
"/models/unload",
|
||||||
|
post(|Json(_body): Json<Value>| async { Json(json!({"status": "unloaded"})) }),
|
||||||
|
)
|
||||||
|
.route("/v1/chat/completions", post(mock_chat_completions))
|
||||||
|
.route("/v1/responses", post(mock_responses))
|
||||||
|
.route("/v1/models", get(mock_v1_models));
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
axum::serve(listener, app).await.unwrap();
|
axum::serve(listener, app).await.unwrap();
|
||||||
});
|
});
|
||||||
|
|
||||||
format!("http://{addr}")
|
base_url
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn mock_neuron_list_models() -> Json<Value> {
|
||||||
|
Json(json!([
|
||||||
|
{"id": "test-model", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": 8000}
|
||||||
|
]))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn mock_v1_models() -> Json<Value> {
|
||||||
|
Json(json!({
|
||||||
|
"object": "list",
|
||||||
|
"data": [{"id": "test-model", "object": "model", "status": "loaded"}]
|
||||||
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn mock_chat_completions(Json(body): Json<Value>) -> Json<Value> {
|
async fn mock_chat_completions(Json(body): Json<Value>) -> Json<Value> {
|
||||||
@@ -59,21 +94,55 @@ async fn mock_chat_completions(Json(body): Json<Value>) -> Json<Value> {
|
|||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn mock_list_models() -> Json<Value> {
|
async fn mock_responses(Json(body): Json<Value>) -> Json<Value> {
|
||||||
|
let model = body
|
||||||
|
.get("model")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or("unknown");
|
||||||
|
// Echo the model field back and synthesise a tiny ResponsesResponse.
|
||||||
|
// Mirrors the shape neuron's /v1/responses handler emits so the
|
||||||
|
// gateway test only needs to assert the proxy round-tripped it.
|
||||||
Json(json!({
|
Json(json!({
|
||||||
"object": "list",
|
"id": "resp-test-001",
|
||||||
"data": [{
|
"object": "response",
|
||||||
"id": "test-model",
|
"created_at": 1700000000_u64,
|
||||||
"object": "model",
|
"status": "completed",
|
||||||
"status": "loaded"
|
"model": model,
|
||||||
}]
|
"output": [{
|
||||||
|
"type": "message",
|
||||||
|
"id": "msg-test-001",
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [{
|
||||||
|
"type": "output_text",
|
||||||
|
"text": "Hello from mock backend",
|
||||||
|
"annotations": []
|
||||||
|
}],
|
||||||
|
"status": "completed"
|
||||||
|
}],
|
||||||
|
"usage": {
|
||||||
|
"input_tokens": 5,
|
||||||
|
"output_tokens": 5,
|
||||||
|
"total_tokens": 10
|
||||||
|
}
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Spawns a mock mistral.rs backend that returns SSE streaming responses.
|
/// Spawns a mock neuron that returns SSE streaming responses for chat completions.
|
||||||
/// Each chunk is delayed by `chunk_delay` to prove the proxy streams incrementally.
|
pub async fn spawn_streaming_mock_neuron(chunk_count: usize, chunk_delay: Duration) -> String {
|
||||||
pub async fn spawn_streaming_mock_backend(chunk_count: usize, chunk_delay: Duration) -> String {
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
let base_url = format!("http://{addr}");
|
||||||
|
let inference_url = base_url.clone();
|
||||||
|
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
|
.route("/models", get(mock_neuron_list_models))
|
||||||
|
.route(
|
||||||
|
"/models/{model_id}/endpoint",
|
||||||
|
get(move |Path(_model_id): Path<String>| {
|
||||||
|
let url = inference_url.clone();
|
||||||
|
async move { Json(json!({"url": url})) }
|
||||||
|
}),
|
||||||
|
)
|
||||||
.route(
|
.route(
|
||||||
"/v1/chat/completions",
|
"/v1/chat/completions",
|
||||||
post(move |Json(body): Json<Value>| async move {
|
post(move |Json(body): Json<Value>| async move {
|
||||||
@@ -118,40 +187,85 @@ pub async fn spawn_streaming_mock_backend(chunk_count: usize, chunk_delay: Durat
|
|||||||
.body(Body::from_stream(stream))
|
.body(Body::from_stream(stream))
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}),
|
}),
|
||||||
)
|
);
|
||||||
.route("/v1/models", get(mock_list_models));
|
|
||||||
|
|
||||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
|
||||||
let addr = listener.local_addr().unwrap();
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
axum::serve(listener, app).await.unwrap();
|
axum::serve(listener, app).await.unwrap();
|
||||||
});
|
});
|
||||||
|
|
||||||
format!("http://{addr}")
|
base_url
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Spawns a mock backend with a custom `/v1/models` response.
|
/// Spawns a mock neuron with a custom models list.
|
||||||
pub async fn spawn_mock_backend_with_models(models_response: Value) -> String {
|
pub async fn spawn_mock_neuron_with_models(models_response: Value) -> String {
|
||||||
|
spawn_mock_neuron_with_models_and_health(models_response, default_health_response()).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Default `/health` response used by mocks that don't care about the
|
||||||
|
/// activation field — empty devices, no in-flight pre-warm, state=ready.
|
||||||
|
pub fn default_health_response() -> Value {
|
||||||
|
json!({
|
||||||
|
"uptime_secs": 0,
|
||||||
|
"devices": [],
|
||||||
|
"activation": {
|
||||||
|
"state": "ready",
|
||||||
|
"pending": [],
|
||||||
|
"in_progress": null,
|
||||||
|
"completed": [],
|
||||||
|
"failed": []
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Variant of `spawn_mock_neuron_with_models` that also serves a
|
||||||
|
/// `/health` body. Used by tests that drive the gateway's activation
|
||||||
|
/// surface (poller reading /health, /v1/models synthesising Loading
|
||||||
|
/// locations from in_progress / pending).
|
||||||
|
pub async fn spawn_mock_neuron_with_models_and_health(
|
||||||
|
models_response: Value,
|
||||||
|
health_response: Value,
|
||||||
|
) -> String {
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
let base_url = format!("http://{addr}");
|
||||||
|
let inference_url = base_url.clone();
|
||||||
|
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.route("/v1/chat/completions", post(mock_chat_completions))
|
|
||||||
.route(
|
.route(
|
||||||
"/v1/models",
|
"/models",
|
||||||
get(move || {
|
get(move || {
|
||||||
let resp = models_response.clone();
|
let resp = models_response.clone();
|
||||||
async move { Json(resp) }
|
async move { Json(resp) }
|
||||||
}),
|
}),
|
||||||
);
|
)
|
||||||
|
.route(
|
||||||
|
"/health",
|
||||||
|
get(move || {
|
||||||
|
let resp = health_response.clone();
|
||||||
|
async move { Json(resp) }
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.route(
|
||||||
|
"/models/{model_id}/endpoint",
|
||||||
|
get(move |Path(_model_id): Path<String>| {
|
||||||
|
let url = inference_url.clone();
|
||||||
|
async move { Json(json!({"url": url})) }
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.route(
|
||||||
|
"/models/unload",
|
||||||
|
post(|Json(_body): Json<Value>| async { Json(json!({"status": "unloaded"})) }),
|
||||||
|
)
|
||||||
|
.route("/v1/chat/completions", post(mock_chat_completions));
|
||||||
|
|
||||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
|
||||||
let addr = listener.local_addr().unwrap();
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
axum::serve(listener, app).await.unwrap();
|
axum::serve(listener, app).await.unwrap();
|
||||||
});
|
});
|
||||||
|
|
||||||
format!("http://{addr}")
|
base_url
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Spawns the cortex gateway with a single node pointing at `mock_url`.
|
/// Spawns the cortex gateway with a single neuron pointing at `mock_url`.
|
||||||
/// The node is pre-seeded as healthy with one loaded model ("test-model").
|
/// The node is pre-seeded as healthy with one loaded model ("test-model").
|
||||||
/// Returns the gateway's base URL.
|
/// Returns the gateway's base URL.
|
||||||
pub async fn spawn_gateway(mock_url: &str) -> String {
|
pub async fn spawn_gateway(mock_url: &str) -> String {
|
||||||
@@ -159,8 +273,7 @@ pub async fn spawn_gateway(mock_url: &str) -> String {
|
|||||||
url
|
url
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Like `spawn_gateway` but also returns the shared `CortexState` so tests
|
/// Like `spawn_gateway` but also returns the shared `CortexState`.
|
||||||
/// can call `poll_once` or inspect state directly.
|
|
||||||
pub async fn spawn_gateway_with_state(mock_url: &str) -> (Arc<CortexState>, String) {
|
pub async fn spawn_gateway_with_state(mock_url: &str) -> (Arc<CortexState>, String) {
|
||||||
let config = GatewayConfig {
|
let config = GatewayConfig {
|
||||||
gateway: GatewaySettings {
|
gateway: GatewaySettings {
|
||||||
@@ -171,18 +284,16 @@ pub async fn spawn_gateway_with_state(mock_url: &str) -> (Arc<CortexState>, Stri
|
|||||||
strategy: EvictionStrategy::Lru,
|
strategy: EvictionStrategy::Lru,
|
||||||
defrag_after_cycles: 0,
|
defrag_after_cycles: 0,
|
||||||
},
|
},
|
||||||
nodes: vec![NodeConfig {
|
neurons: vec![NeuronEndpoint {
|
||||||
name: "mock-node".into(),
|
name: "mock-node".into(),
|
||||||
endpoint: mock_url.to_string(),
|
endpoint: mock_url.to_string(),
|
||||||
vram_mb: 24000,
|
|
||||||
pinned: vec![],
|
|
||||||
}],
|
}],
|
||||||
|
models_config: "/dev/null".into(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let fleet = Arc::new(CortexState::from_config(&config));
|
let fleet = Arc::new(CortexState::from_config(&config));
|
||||||
|
|
||||||
// Seed the node as healthy with a loaded model.
|
// Seed the node as healthy with a loaded model.
|
||||||
// (Bypasses the poller, which is not running in tests.)
|
|
||||||
{
|
{
|
||||||
let mut nodes = fleet.nodes.write().await;
|
let mut nodes = fleet.nodes.write().await;
|
||||||
let node = nodes.get_mut("mock-node").expect("node must exist");
|
let node = nodes.get_mut("mock-node").expect("node must exist");
|
||||||
|
|||||||
@@ -2,15 +2,16 @@ mod common;
|
|||||||
|
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
use cortex_core::config::{
|
use cortex_core::config::{
|
||||||
EvictionSettings, EvictionStrategy, GatewayConfig, GatewaySettings, NodeConfig,
|
EvictionSettings, EvictionStrategy, GatewayConfig, GatewaySettings, NeuronEndpoint,
|
||||||
};
|
};
|
||||||
use cortex_core::node::{ModelEntry, ModelStatus};
|
use cortex_core::node::{ModelEntry, ModelStatus};
|
||||||
use cortex_gateway::state::CortexState;
|
use cortex_gateway::state::CortexState;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
/// Spawn a mock backend that accepts `/v1/models/unload` and records the call.
|
/// Spawn a mock neuron that accepts `/models/unload` and records unload calls.
|
||||||
async fn spawn_eviction_mock() -> (String, Arc<tokio::sync::Mutex<Vec<String>>>) {
|
async fn spawn_eviction_mock() -> (String, Arc<tokio::sync::Mutex<Vec<String>>>) {
|
||||||
|
use axum::extract::Path;
|
||||||
use axum::routing::{get, post};
|
use axum::routing::{get, post};
|
||||||
use axum::{Json, Router};
|
use axum::{Json, Router};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
@@ -18,9 +19,14 @@ async fn spawn_eviction_mock() -> (String, Arc<tokio::sync::Mutex<Vec<String>>>)
|
|||||||
let unloaded: Arc<tokio::sync::Mutex<Vec<String>>> = Arc::new(tokio::sync::Mutex::new(vec![]));
|
let unloaded: Arc<tokio::sync::Mutex<Vec<String>>> = Arc::new(tokio::sync::Mutex::new(vec![]));
|
||||||
let unloaded_clone = Arc::clone(&unloaded);
|
let unloaded_clone = Arc::clone(&unloaded);
|
||||||
|
|
||||||
|
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
let base_url = format!("http://{addr}");
|
||||||
|
let inference_url = base_url.clone();
|
||||||
|
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.route(
|
.route(
|
||||||
"/v1/models/unload",
|
"/models/unload",
|
||||||
post(move |Json(body): Json<Value>| {
|
post(move |Json(body): Json<Value>| {
|
||||||
let unloaded = Arc::clone(&unloaded_clone);
|
let unloaded = Arc::clone(&unloaded_clone);
|
||||||
async move {
|
async move {
|
||||||
@@ -30,30 +36,27 @@ async fn spawn_eviction_mock() -> (String, Arc<tokio::sync::Mutex<Vec<String>>>)
|
|||||||
.unwrap_or("")
|
.unwrap_or("")
|
||||||
.to_string();
|
.to_string();
|
||||||
unloaded.lock().await.push(model_id);
|
unloaded.lock().await.push(model_id);
|
||||||
Json(json!({"status": "ok"}))
|
Json(json!({"status": "unloaded"}))
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
.route("/models", get(|| async { Json(json!([])) }))
|
||||||
.route(
|
.route(
|
||||||
"/v1/models",
|
"/models/{model_id}/endpoint",
|
||||||
get(|| async {
|
get(move |Path(_model_id): Path<String>| {
|
||||||
Json(json!({
|
let url = inference_url.clone();
|
||||||
"object": "list",
|
async move { Json(json!({"url": url})) }
|
||||||
"data": []
|
|
||||||
}))
|
|
||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
|
|
||||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
|
||||||
let addr = listener.local_addr().unwrap();
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
axum::serve(listener, app).await.unwrap();
|
axum::serve(listener, app).await.unwrap();
|
||||||
});
|
});
|
||||||
|
|
||||||
(format!("http://{addr}"), unloaded)
|
(base_url, unloaded)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn make_fleet(endpoint: &str, pinned: Vec<String>, defrag_after: u32) -> Arc<CortexState> {
|
fn make_fleet(endpoint: &str, defrag_after: u32) -> Arc<CortexState> {
|
||||||
let config = GatewayConfig {
|
let config = GatewayConfig {
|
||||||
gateway: GatewaySettings {
|
gateway: GatewaySettings {
|
||||||
listen: "127.0.0.1:0".into(),
|
listen: "127.0.0.1:0".into(),
|
||||||
@@ -63,12 +66,11 @@ fn make_fleet(endpoint: &str, pinned: Vec<String>, defrag_after: u32) -> Arc<Cor
|
|||||||
strategy: EvictionStrategy::Lru,
|
strategy: EvictionStrategy::Lru,
|
||||||
defrag_after_cycles: defrag_after,
|
defrag_after_cycles: defrag_after,
|
||||||
},
|
},
|
||||||
nodes: vec![NodeConfig {
|
neurons: vec![NeuronEndpoint {
|
||||||
name: "gpu-node".into(),
|
name: "gpu-node".into(),
|
||||||
endpoint: endpoint.to_string(),
|
endpoint: endpoint.to_string(),
|
||||||
vram_mb: 24000,
|
|
||||||
pinned,
|
|
||||||
}],
|
}],
|
||||||
|
models_config: "/dev/null".into(),
|
||||||
};
|
};
|
||||||
Arc::new(CortexState::from_config(&config))
|
Arc::new(CortexState::from_config(&config))
|
||||||
}
|
}
|
||||||
@@ -76,9 +78,8 @@ fn make_fleet(endpoint: &str, pinned: Vec<String>, defrag_after: u32) -> Arc<Cor
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_evict_lru_model() {
|
async fn test_evict_lru_model() {
|
||||||
let (mock_url, unloaded) = spawn_eviction_mock().await;
|
let (mock_url, unloaded) = spawn_eviction_mock().await;
|
||||||
let fleet = make_fleet(&mock_url, vec![], 0);
|
let fleet = make_fleet(&mock_url, 0);
|
||||||
|
|
||||||
// Seed two loaded models. "old-model" was accessed earlier than "new-model".
|
|
||||||
{
|
{
|
||||||
let mut nodes = fleet.nodes.write().await;
|
let mut nodes = fleet.nodes.write().await;
|
||||||
let node = nodes.get_mut("gpu-node").unwrap();
|
let node = nodes.get_mut("gpu-node").unwrap();
|
||||||
@@ -107,15 +108,12 @@ async fn test_evict_lru_model() {
|
|||||||
.await
|
.await
|
||||||
.expect("eviction should succeed");
|
.expect("eviction should succeed");
|
||||||
|
|
||||||
// The older model should be evicted.
|
|
||||||
assert_eq!(evicted, Some("old-model".to_string()));
|
assert_eq!(evicted, Some("old-model".to_string()));
|
||||||
|
|
||||||
// Mock received the unload call.
|
|
||||||
let calls = unloaded.lock().await;
|
let calls = unloaded.lock().await;
|
||||||
assert_eq!(calls.len(), 1);
|
assert_eq!(calls.len(), 1);
|
||||||
assert_eq!(calls[0], "old-model");
|
assert_eq!(calls[0], "old-model");
|
||||||
|
|
||||||
// Local state updated.
|
|
||||||
let nodes = fleet.nodes.read().await;
|
let nodes = fleet.nodes.read().await;
|
||||||
let node = nodes.get("gpu-node").unwrap();
|
let node = nodes.get("gpu-node").unwrap();
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@@ -128,67 +126,15 @@ async fn test_evict_lru_model() {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_eviction_skips_pinned_models() {
|
|
||||||
let (mock_url, unloaded) = spawn_eviction_mock().await;
|
|
||||||
// Pin "old-model" so it can't be evicted.
|
|
||||||
let fleet = make_fleet(&mock_url, vec!["old-model".into()], 0);
|
|
||||||
|
|
||||||
{
|
|
||||||
let mut nodes = fleet.nodes.write().await;
|
|
||||||
let node = nodes.get_mut("gpu-node").unwrap();
|
|
||||||
node.healthy = true;
|
|
||||||
// old-model is pinned and older — normally it would be evicted.
|
|
||||||
node.models.insert(
|
|
||||||
"old-model".into(),
|
|
||||||
ModelEntry {
|
|
||||||
id: "old-model".into(),
|
|
||||||
status: ModelStatus::Loaded,
|
|
||||||
last_accessed: Some(Utc::now() - chrono::Duration::hours(2)),
|
|
||||||
vram_estimate_mb: Some(8000),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
node.models.insert(
|
|
||||||
"new-model".into(),
|
|
||||||
ModelEntry {
|
|
||||||
id: "new-model".into(),
|
|
||||||
status: ModelStatus::Loaded,
|
|
||||||
last_accessed: Some(Utc::now()),
|
|
||||||
vram_estimate_mb: Some(8000),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
let evicted = cortex_gateway::evictor::evict_lru_on_node(&fleet, "gpu-node")
|
|
||||||
.await
|
|
||||||
.expect("eviction should succeed");
|
|
||||||
|
|
||||||
// new-model is evicted instead because old-model is pinned.
|
|
||||||
assert_eq!(evicted, Some("new-model".to_string()));
|
|
||||||
|
|
||||||
let calls = unloaded.lock().await;
|
|
||||||
assert_eq!(calls[0], "new-model");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_eviction_nothing_to_evict() {
|
async fn test_eviction_nothing_to_evict() {
|
||||||
let (mock_url, unloaded) = spawn_eviction_mock().await;
|
let (mock_url, unloaded) = spawn_eviction_mock().await;
|
||||||
// Pin the only model.
|
let fleet = make_fleet(&mock_url, 0);
|
||||||
let fleet = make_fleet(&mock_url, vec!["only-model".into()], 0);
|
|
||||||
|
|
||||||
|
// No models at all.
|
||||||
{
|
{
|
||||||
let mut nodes = fleet.nodes.write().await;
|
let mut nodes = fleet.nodes.write().await;
|
||||||
let node = nodes.get_mut("gpu-node").unwrap();
|
nodes.get_mut("gpu-node").unwrap().healthy = true;
|
||||||
node.healthy = true;
|
|
||||||
node.models.insert(
|
|
||||||
"only-model".into(),
|
|
||||||
ModelEntry {
|
|
||||||
id: "only-model".into(),
|
|
||||||
status: ModelStatus::Loaded,
|
|
||||||
last_accessed: None,
|
|
||||||
vram_estimate_mb: Some(8000),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let evicted = cortex_gateway::evictor::evict_lru_on_node(&fleet, "gpu-node")
|
let evicted = cortex_gateway::evictor::evict_lru_on_node(&fleet, "gpu-node")
|
||||||
@@ -196,8 +142,6 @@ async fn test_eviction_nothing_to_evict() {
|
|||||||
.expect("eviction should succeed");
|
.expect("eviction should succeed");
|
||||||
|
|
||||||
assert_eq!(evicted, None);
|
assert_eq!(evicted, None);
|
||||||
|
|
||||||
// No unload call made.
|
|
||||||
let calls = unloaded.lock().await;
|
let calls = unloaded.lock().await;
|
||||||
assert!(calls.is_empty());
|
assert!(calls.is_empty());
|
||||||
}
|
}
|
||||||
@@ -205,7 +149,7 @@ async fn test_eviction_nothing_to_evict() {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_eviction_increments_lifecycle_cycles() {
|
async fn test_eviction_increments_lifecycle_cycles() {
|
||||||
let (mock_url, _) = spawn_eviction_mock().await;
|
let (mock_url, _) = spawn_eviction_mock().await;
|
||||||
let fleet = make_fleet(&mock_url, vec![], 0);
|
let fleet = make_fleet(&mock_url, 0);
|
||||||
|
|
||||||
{
|
{
|
||||||
let mut nodes = fleet.nodes.write().await;
|
let mut nodes = fleet.nodes.write().await;
|
||||||
@@ -233,10 +177,9 @@ async fn test_eviction_increments_lifecycle_cycles() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_last_accessed_updated_on_request() {
|
async fn test_last_accessed_updated_on_request() {
|
||||||
let mock_url = common::spawn_mock_backend().await;
|
let mock_url = common::spawn_mock_neuron().await;
|
||||||
let (fleet, gw_url) = common::spawn_gateway_with_state(&mock_url).await;
|
let (fleet, gw_url) = common::spawn_gateway_with_state(&mock_url).await;
|
||||||
|
|
||||||
// Verify last_accessed is None initially.
|
|
||||||
{
|
{
|
||||||
let nodes = fleet.nodes.read().await;
|
let nodes = fleet.nodes.read().await;
|
||||||
let node = nodes.get("mock-node").unwrap();
|
let node = nodes.get("mock-node").unwrap();
|
||||||
@@ -249,7 +192,6 @@ async fn test_last_accessed_updated_on_request() {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make a request.
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
client
|
client
|
||||||
.post(format!("{gw_url}/v1/chat/completions"))
|
.post(format!("{gw_url}/v1/chat/completions"))
|
||||||
@@ -262,7 +204,6 @@ async fn test_last_accessed_updated_on_request() {
|
|||||||
.await
|
.await
|
||||||
.expect("request should succeed");
|
.expect("request should succeed");
|
||||||
|
|
||||||
// Verify last_accessed is now set.
|
|
||||||
let nodes = fleet.nodes.read().await;
|
let nodes = fleet.nodes.read().await;
|
||||||
let node = nodes.get("mock-node").unwrap();
|
let node = nodes.get("mock-node").unwrap();
|
||||||
assert!(
|
assert!(
|
||||||
|
|||||||
@@ -4,21 +4,17 @@ use serde_json::json;
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_metrics_emitted_after_proxy() {
|
async fn test_metrics_emitted_after_proxy() {
|
||||||
// Install a test recorder (no HTTP listener, renders to string).
|
|
||||||
// This sets the global recorder, so only one test can do this.
|
|
||||||
let handle = cortex_gateway::metrics::install_test_recorder().expect("recorder should install");
|
let handle = cortex_gateway::metrics::install_test_recorder().expect("recorder should install");
|
||||||
|
|
||||||
let mock_url = common::spawn_mock_backend().await;
|
let mock_url = common::spawn_mock_neuron().await;
|
||||||
let gw_url = common::spawn_gateway(&mock_url).await;
|
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||||
|
|
||||||
// Verify no request metrics yet.
|
|
||||||
let before = handle.render();
|
let before = handle.render();
|
||||||
assert!(
|
assert!(
|
||||||
!before.contains("cortex_requests_total"),
|
!before.contains("cortex_requests_total"),
|
||||||
"no request metrics before any requests"
|
"no request metrics before any requests"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Make a successful request.
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
let resp = client
|
let resp = client
|
||||||
.post(format!("{gw_url}/v1/chat/completions"))
|
.post(format!("{gw_url}/v1/chat/completions"))
|
||||||
@@ -31,10 +27,8 @@ async fn test_metrics_emitted_after_proxy() {
|
|||||||
.await
|
.await
|
||||||
.expect("request should succeed");
|
.expect("request should succeed");
|
||||||
assert_eq!(resp.status(), 200);
|
assert_eq!(resp.status(), 200);
|
||||||
// Consume the response body to ensure the proxy completes.
|
|
||||||
let _body: serde_json::Value = resp.json().await.unwrap();
|
let _body: serde_json::Value = resp.json().await.unwrap();
|
||||||
|
|
||||||
// Check metrics were emitted.
|
|
||||||
let after = handle.render();
|
let after = handle.render();
|
||||||
|
|
||||||
assert!(
|
assert!(
|
||||||
@@ -45,7 +39,6 @@ async fn test_metrics_emitted_after_proxy() {
|
|||||||
after.contains("cortex_request_duration_seconds"),
|
after.contains("cortex_request_duration_seconds"),
|
||||||
"cortex_request_duration_seconds should be present.\nMetrics:\n{after}"
|
"cortex_request_duration_seconds should be present.\nMetrics:\n{after}"
|
||||||
);
|
);
|
||||||
// Should NOT have error or cold start counters for this request.
|
|
||||||
assert!(
|
assert!(
|
||||||
!after.contains("cortex_request_errors_total"),
|
!after.contains("cortex_request_errors_total"),
|
||||||
"no errors expected for a successful request"
|
"no errors expected for a successful request"
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
mod common;
|
mod common;
|
||||||
|
|
||||||
use cortex_core::config::{
|
use cortex_core::config::{
|
||||||
EvictionSettings, EvictionStrategy, GatewayConfig, GatewaySettings, NodeConfig,
|
EvictionSettings, EvictionStrategy, GatewayConfig, GatewaySettings, NeuronEndpoint,
|
||||||
};
|
};
|
||||||
use cortex_core::node::ModelStatus;
|
use cortex_core::node::ModelStatus;
|
||||||
use cortex_gateway::state::CortexState;
|
use cortex_gateway::state::CortexState;
|
||||||
@@ -10,14 +10,11 @@ use std::sync::Arc;
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_poller_discovers_models() {
|
async fn test_poller_discovers_models() {
|
||||||
// Mock backend reports 2 models: one loaded, one unloaded.
|
// Mock neuron reports 2 models via /models endpoint (neuron format).
|
||||||
let mock_url = common::spawn_mock_backend_with_models(json!({
|
let mock_url = common::spawn_mock_neuron_with_models(json!([
|
||||||
"object": "list",
|
{"id": "model-a", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": 8000},
|
||||||
"data": [
|
{"id": "model-b", "harness": "candle", "status": "unloaded", "devices": [], "vram_used_mb": null}
|
||||||
{ "id": "model-a", "object": "model", "status": "loaded" },
|
]))
|
||||||
{ "id": "model-b", "object": "model", "status": "unloaded" }
|
|
||||||
]
|
|
||||||
}))
|
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
let config = GatewayConfig {
|
let config = GatewayConfig {
|
||||||
@@ -29,17 +26,15 @@ async fn test_poller_discovers_models() {
|
|||||||
strategy: EvictionStrategy::Lru,
|
strategy: EvictionStrategy::Lru,
|
||||||
defrag_after_cycles: 0,
|
defrag_after_cycles: 0,
|
||||||
},
|
},
|
||||||
nodes: vec![NodeConfig {
|
neurons: vec![NeuronEndpoint {
|
||||||
name: "test-node".into(),
|
name: "test-node".into(),
|
||||||
endpoint: mock_url,
|
endpoint: mock_url,
|
||||||
vram_mb: 24000,
|
|
||||||
pinned: vec![],
|
|
||||||
}],
|
}],
|
||||||
|
models_config: "/dev/null".into(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let fleet = Arc::new(CortexState::from_config(&config));
|
let fleet = Arc::new(CortexState::from_config(&config));
|
||||||
|
|
||||||
// Before polling: node is unhealthy, no models.
|
|
||||||
{
|
{
|
||||||
let nodes = fleet.nodes.read().await;
|
let nodes = fleet.nodes.read().await;
|
||||||
let node = nodes.get("test-node").unwrap();
|
let node = nodes.get("test-node").unwrap();
|
||||||
@@ -47,10 +42,8 @@ async fn test_poller_discovers_models() {
|
|||||||
assert!(node.models.is_empty());
|
assert!(node.models.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Poll once.
|
|
||||||
cortex_gateway::poller::poll_once(&fleet).await;
|
cortex_gateway::poller::poll_once(&fleet).await;
|
||||||
|
|
||||||
// After polling: node is healthy, both models discovered with correct status.
|
|
||||||
{
|
{
|
||||||
let nodes = fleet.nodes.read().await;
|
let nodes = fleet.nodes.read().await;
|
||||||
let node = nodes.get("test-node").unwrap();
|
let node = nodes.get("test-node").unwrap();
|
||||||
@@ -69,14 +62,10 @@ async fn test_poller_discovers_models() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_poller_updates_gateway_models_endpoint() {
|
async fn test_poller_updates_gateway_models_endpoint() {
|
||||||
// Mock backend with 2 models.
|
let mock_url = common::spawn_mock_neuron_with_models(json!([
|
||||||
let mock_url = common::spawn_mock_backend_with_models(json!({
|
{"id": "model-x", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": null},
|
||||||
"object": "list",
|
{"id": "model-y", "harness": "candle", "status": "loaded", "devices": [1], "vram_used_mb": null}
|
||||||
"data": [
|
]))
|
||||||
{ "id": "model-x", "object": "model", "status": "loaded" },
|
|
||||||
{ "id": "model-y", "object": "model", "status": "loaded" }
|
|
||||||
]
|
|
||||||
}))
|
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
let config = GatewayConfig {
|
let config = GatewayConfig {
|
||||||
@@ -88,20 +77,16 @@ async fn test_poller_updates_gateway_models_endpoint() {
|
|||||||
strategy: EvictionStrategy::Lru,
|
strategy: EvictionStrategy::Lru,
|
||||||
defrag_after_cycles: 0,
|
defrag_after_cycles: 0,
|
||||||
},
|
},
|
||||||
nodes: vec![NodeConfig {
|
neurons: vec![NeuronEndpoint {
|
||||||
name: "poll-node".into(),
|
name: "poll-node".into(),
|
||||||
endpoint: mock_url,
|
endpoint: mock_url,
|
||||||
vram_mb: 24000,
|
|
||||||
pinned: vec![],
|
|
||||||
}],
|
}],
|
||||||
|
models_config: "/dev/null".into(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let fleet = Arc::new(CortexState::from_config(&config));
|
let fleet = Arc::new(CortexState::from_config(&config));
|
||||||
|
|
||||||
// Poll to discover models and mark node healthy.
|
|
||||||
cortex_gateway::poller::poll_once(&fleet).await;
|
cortex_gateway::poller::poll_once(&fleet).await;
|
||||||
|
|
||||||
// Start gateway with the polled state.
|
|
||||||
let app = cortex_gateway::build_app(Arc::clone(&fleet));
|
let app = cortex_gateway::build_app(Arc::clone(&fleet));
|
||||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
let addr = listener.local_addr().unwrap();
|
let addr = listener.local_addr().unwrap();
|
||||||
@@ -109,7 +94,6 @@ async fn test_poller_updates_gateway_models_endpoint() {
|
|||||||
axum::serve(listener, app).await.unwrap();
|
axum::serve(listener, app).await.unwrap();
|
||||||
});
|
});
|
||||||
|
|
||||||
// Query /v1/models on the gateway.
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
let resp = client
|
let resp = client
|
||||||
.get(format!("http://{addr}/v1/models"))
|
.get(format!("http://{addr}/v1/models"))
|
||||||
@@ -127,7 +111,6 @@ async fn test_poller_updates_gateway_models_endpoint() {
|
|||||||
assert!(ids.contains(&"model-x"));
|
assert!(ids.contains(&"model-x"));
|
||||||
assert!(ids.contains(&"model-y"));
|
assert!(ids.contains(&"model-y"));
|
||||||
|
|
||||||
// Verify node attribution in locations.
|
|
||||||
for model in data {
|
for model in data {
|
||||||
let locations = model["locations"].as_array().expect("locations array");
|
let locations = model["locations"].as_array().expect("locations array");
|
||||||
assert_eq!(locations.len(), 1);
|
assert_eq!(locations.len(), 1);
|
||||||
@@ -146,17 +129,15 @@ async fn test_poller_marks_unreachable_node_unhealthy() {
|
|||||||
strategy: EvictionStrategy::Lru,
|
strategy: EvictionStrategy::Lru,
|
||||||
defrag_after_cycles: 0,
|
defrag_after_cycles: 0,
|
||||||
},
|
},
|
||||||
nodes: vec![NodeConfig {
|
neurons: vec![NeuronEndpoint {
|
||||||
name: "dead-node".into(),
|
name: "dead-node".into(),
|
||||||
endpoint: "http://127.0.0.1:1".into(), // unreachable
|
endpoint: "http://127.0.0.1:1".into(),
|
||||||
vram_mb: 24000,
|
|
||||||
pinned: vec![],
|
|
||||||
}],
|
}],
|
||||||
|
models_config: "/dev/null".into(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let fleet = Arc::new(CortexState::from_config(&config));
|
let fleet = Arc::new(CortexState::from_config(&config));
|
||||||
|
|
||||||
// Manually mark healthy to verify poller flips it.
|
|
||||||
{
|
{
|
||||||
let mut nodes = fleet.nodes.write().await;
|
let mut nodes = fleet.nodes.write().await;
|
||||||
nodes.get_mut("dead-node").unwrap().healthy = true;
|
nodes.get_mut("dead-node").unwrap().healthy = true;
|
||||||
@@ -170,14 +151,10 @@ async fn test_poller_marks_unreachable_node_unhealthy() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_poller_removes_stale_models() {
|
async fn test_poller_removes_stale_models() {
|
||||||
// Start with a mock that reports 2 models.
|
let mock_url = common::spawn_mock_neuron_with_models(json!([
|
||||||
let mock_url = common::spawn_mock_backend_with_models(json!({
|
{"id": "keep-me", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": null},
|
||||||
"object": "list",
|
{"id": "drop-me", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": null}
|
||||||
"data": [
|
]))
|
||||||
{ "id": "keep-me", "object": "model", "status": "loaded" },
|
|
||||||
{ "id": "drop-me", "object": "model", "status": "loaded" }
|
|
||||||
]
|
|
||||||
}))
|
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
let config = GatewayConfig {
|
let config = GatewayConfig {
|
||||||
@@ -189,35 +166,27 @@ async fn test_poller_removes_stale_models() {
|
|||||||
strategy: EvictionStrategy::Lru,
|
strategy: EvictionStrategy::Lru,
|
||||||
defrag_after_cycles: 0,
|
defrag_after_cycles: 0,
|
||||||
},
|
},
|
||||||
nodes: vec![NodeConfig {
|
neurons: vec![NeuronEndpoint {
|
||||||
name: "test-node".into(),
|
name: "test-node".into(),
|
||||||
endpoint: mock_url,
|
endpoint: mock_url,
|
||||||
vram_mb: 24000,
|
|
||||||
pinned: vec![],
|
|
||||||
}],
|
}],
|
||||||
|
models_config: "/dev/null".into(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let fleet = Arc::new(CortexState::from_config(&config));
|
let fleet = Arc::new(CortexState::from_config(&config));
|
||||||
cortex_gateway::poller::poll_once(&fleet).await;
|
cortex_gateway::poller::poll_once(&fleet).await;
|
||||||
|
|
||||||
// Verify both models exist.
|
|
||||||
{
|
{
|
||||||
let nodes = fleet.nodes.read().await;
|
let nodes = fleet.nodes.read().await;
|
||||||
assert_eq!(nodes.get("test-node").unwrap().models.len(), 2);
|
assert_eq!(nodes.get("test-node").unwrap().models.len(), 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Now spin up a new mock that only reports one model, and re-point the node.
|
// New mock with only one model.
|
||||||
let new_mock_url = common::spawn_mock_backend_with_models(json!({
|
let new_mock_url = common::spawn_mock_neuron_with_models(json!([
|
||||||
"object": "list",
|
{"id": "keep-me", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": null}
|
||||||
"data": [
|
]))
|
||||||
{ "id": "keep-me", "object": "model", "status": "loaded" }
|
|
||||||
]
|
|
||||||
}))
|
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
// Update the node endpoint to point at the new mock.
|
|
||||||
// We can't change node_configs (they're immutable), so instead we'll
|
|
||||||
// create a new fleet with the updated endpoint and poll that.
|
|
||||||
let config2 = GatewayConfig {
|
let config2 = GatewayConfig {
|
||||||
gateway: GatewaySettings {
|
gateway: GatewaySettings {
|
||||||
listen: "127.0.0.1:0".into(),
|
listen: "127.0.0.1:0".into(),
|
||||||
@@ -227,17 +196,16 @@ async fn test_poller_removes_stale_models() {
|
|||||||
strategy: EvictionStrategy::Lru,
|
strategy: EvictionStrategy::Lru,
|
||||||
defrag_after_cycles: 0,
|
defrag_after_cycles: 0,
|
||||||
},
|
},
|
||||||
nodes: vec![NodeConfig {
|
neurons: vec![NeuronEndpoint {
|
||||||
name: "test-node".into(),
|
name: "test-node".into(),
|
||||||
endpoint: new_mock_url,
|
endpoint: new_mock_url,
|
||||||
vram_mb: 24000,
|
|
||||||
pinned: vec![],
|
|
||||||
}],
|
}],
|
||||||
|
models_config: "/dev/null".into(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let fleet2 = Arc::new(CortexState::from_config(&config2));
|
let fleet2 = Arc::new(CortexState::from_config(&config2));
|
||||||
|
|
||||||
// Seed the stale model so we can verify it gets removed.
|
// Seed stale model.
|
||||||
{
|
{
|
||||||
let mut nodes = fleet2.nodes.write().await;
|
let mut nodes = fleet2.nodes.write().await;
|
||||||
let node = nodes.get_mut("test-node").unwrap();
|
let node = nodes.get_mut("test-node").unwrap();
|
||||||
@@ -269,3 +237,58 @@ async fn test_poller_removes_stale_models() {
|
|||||||
assert!(node.models.contains_key("keep-me"));
|
assert!(node.models.contains_key("keep-me"));
|
||||||
assert!(!node.models.contains_key("drop-me"));
|
assert!(!node.models.contains_key("drop-me"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_poller_captures_activation_from_health() {
|
||||||
|
// Mock neuron is mid-prewarm: /models reports nothing (the loading
|
||||||
|
// model hasn't been inserted into the harness map yet), but
|
||||||
|
// /health's activation says model-x is in_progress and model-y is
|
||||||
|
// queued behind it.
|
||||||
|
let mock_url = common::spawn_mock_neuron_with_models_and_health(
|
||||||
|
json!([]),
|
||||||
|
json!({
|
||||||
|
"uptime_secs": 30,
|
||||||
|
"devices": [],
|
||||||
|
"activation": {
|
||||||
|
"state": "pre_warming",
|
||||||
|
"pending": ["Qwen/model-y"],
|
||||||
|
"in_progress": "Qwen/model-x",
|
||||||
|
"completed": [],
|
||||||
|
"failed": []
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let config = GatewayConfig {
|
||||||
|
gateway: GatewaySettings {
|
||||||
|
listen: "127.0.0.1:0".into(),
|
||||||
|
metrics_listen: "127.0.0.1:0".into(),
|
||||||
|
},
|
||||||
|
eviction: EvictionSettings {
|
||||||
|
strategy: EvictionStrategy::Lru,
|
||||||
|
defrag_after_cycles: 0,
|
||||||
|
},
|
||||||
|
neurons: vec![NeuronEndpoint {
|
||||||
|
name: "prewarm-node".into(),
|
||||||
|
endpoint: mock_url,
|
||||||
|
}],
|
||||||
|
models_config: "/dev/null".into(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let fleet = Arc::new(CortexState::from_config(&config));
|
||||||
|
cortex_gateway::poller::poll_once(&fleet).await;
|
||||||
|
|
||||||
|
let nodes = fleet.nodes.read().await;
|
||||||
|
let node = nodes.get("prewarm-node").unwrap();
|
||||||
|
assert!(node.healthy);
|
||||||
|
// /models was empty — no entries in the per-node model map.
|
||||||
|
assert!(node.models.is_empty());
|
||||||
|
// But /health's activation should be captured.
|
||||||
|
let activation = node
|
||||||
|
.activation
|
||||||
|
.as_ref()
|
||||||
|
.expect("activation should be populated after /health poll");
|
||||||
|
assert_eq!(activation.in_progress.as_deref(), Some("Qwen/model-x"));
|
||||||
|
assert_eq!(activation.pending, vec!["Qwen/model-y".to_string()]);
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ use serde_json::json;
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_chat_completion_proxy() {
|
async fn test_chat_completion_proxy() {
|
||||||
let mock_url = common::spawn_mock_backend().await;
|
let mock_url = common::spawn_mock_neuron().await;
|
||||||
let gw_url = common::spawn_gateway(&mock_url).await;
|
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
@@ -33,7 +33,7 @@ async fn test_chat_completion_proxy() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_health_endpoint() {
|
async fn test_health_endpoint() {
|
||||||
let mock_url = common::spawn_mock_backend().await;
|
let mock_url = common::spawn_mock_neuron().await;
|
||||||
let gw_url = common::spawn_gateway(&mock_url).await;
|
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
@@ -53,7 +53,7 @@ async fn test_health_endpoint() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_list_models() {
|
async fn test_list_models() {
|
||||||
let mock_url = common::spawn_mock_backend().await;
|
let mock_url = common::spawn_mock_neuron().await;
|
||||||
let gw_url = common::spawn_gateway(&mock_url).await;
|
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
@@ -75,7 +75,7 @@ async fn test_list_models() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_model_not_found() {
|
async fn test_model_not_found() {
|
||||||
let mock_url = common::spawn_mock_backend().await;
|
let mock_url = common::spawn_mock_neuron().await;
|
||||||
let gw_url = common::spawn_gateway(&mock_url).await;
|
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
@@ -112,12 +112,11 @@ async fn test_no_healthy_nodes() {
|
|||||||
strategy: cortex_core::config::EvictionStrategy::Lru,
|
strategy: cortex_core::config::EvictionStrategy::Lru,
|
||||||
defrag_after_cycles: 0,
|
defrag_after_cycles: 0,
|
||||||
},
|
},
|
||||||
nodes: vec![cortex_core::config::NodeConfig {
|
neurons: vec![cortex_core::config::NeuronEndpoint {
|
||||||
name: "dead-node".into(),
|
name: "dead-node".into(),
|
||||||
endpoint: "http://127.0.0.1:1".into(),
|
endpoint: "http://127.0.0.1:1".into(),
|
||||||
vram_mb: 24000,
|
|
||||||
pinned: vec![],
|
|
||||||
}],
|
}],
|
||||||
|
models_config: "/dev/null".into(),
|
||||||
};
|
};
|
||||||
let fleet = std::sync::Arc::new(cortex_gateway::state::CortexState::from_config(&config));
|
let fleet = std::sync::Arc::new(cortex_gateway::state::CortexState::from_config(&config));
|
||||||
|
|
||||||
@@ -153,7 +152,7 @@ async fn test_no_healthy_nodes() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_missing_model_field() {
|
async fn test_missing_model_field() {
|
||||||
let mock_url = common::spawn_mock_backend().await;
|
let mock_url = common::spawn_mock_neuron().await;
|
||||||
let gw_url = common::spawn_gateway(&mock_url).await;
|
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
|
|||||||
91
crates/cortex-gateway/tests/responses.rs
Normal file
91
crates/cortex-gateway/tests/responses.rs
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
//! Integration tests for the `/v1/responses` proxy route.
|
||||||
|
//!
|
||||||
|
//! The gateway forwards the request body to whichever neuron has the
|
||||||
|
//! model loaded. These tests exercise the routing decision (200 on a
|
||||||
|
//! known model, 404 on an unknown model, 400 on a missing model
|
||||||
|
//! field) and confirm the response body round-trips verbatim.
|
||||||
|
|
||||||
|
mod common;
|
||||||
|
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
/// Happy path: gateway routes a `/v1/responses` request to the neuron
|
||||||
|
/// that has the model loaded, and the neuron's response body
|
||||||
|
/// arrives at the client unchanged.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_responses_proxy() {
|
||||||
|
let mock_url = common::spawn_mock_neuron().await;
|
||||||
|
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||||
|
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
let resp = client
|
||||||
|
.post(format!("{gw_url}/v1/responses"))
|
||||||
|
.header("content-type", "application/json")
|
||||||
|
.json(&json!({
|
||||||
|
"model": "test-model",
|
||||||
|
"input": "Hi"
|
||||||
|
}))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.expect("request should succeed");
|
||||||
|
|
||||||
|
assert_eq!(resp.status(), 200);
|
||||||
|
|
||||||
|
let body: serde_json::Value = resp.json().await.expect("valid JSON response");
|
||||||
|
assert_eq!(body["id"], "resp-test-001");
|
||||||
|
assert_eq!(body["object"], "response");
|
||||||
|
assert_eq!(body["model"], "test-model");
|
||||||
|
assert_eq!(body["status"], "completed");
|
||||||
|
assert_eq!(
|
||||||
|
body["output"][0]["content"][0]["text"],
|
||||||
|
"Hello from mock backend"
|
||||||
|
);
|
||||||
|
// Usage shape is the Responses-specific (input/output_tokens),
|
||||||
|
// not the chat-completions one (prompt/completion_tokens). Asserts
|
||||||
|
// the proxy didn't accidentally route through the wrong handler.
|
||||||
|
assert_eq!(body["usage"]["total_tokens"], 10);
|
||||||
|
assert!(body["usage"].get("input_tokens").is_some());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A request that targets a model not present in the catalogue gets
|
||||||
|
/// 404 from the router. This matches the chat-completions handler's
|
||||||
|
/// behaviour — same error path, same status code, so a client can
|
||||||
|
/// share retry logic across the two routes.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_responses_model_not_found() {
|
||||||
|
let mock_url = common::spawn_mock_neuron().await;
|
||||||
|
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||||
|
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
let resp = client
|
||||||
|
.post(format!("{gw_url}/v1/responses"))
|
||||||
|
.json(&json!({
|
||||||
|
"model": "not-in-catalogue",
|
||||||
|
"input": "Hi"
|
||||||
|
}))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(resp.status(), 404);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A request body without a `model` field can't be routed; the
|
||||||
|
/// gateway returns 400 before reaching a backend. Same as the
|
||||||
|
/// chat-completions handler — extracted via the same `extract_model`
|
||||||
|
/// helper.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_responses_missing_model_field() {
|
||||||
|
let mock_url = common::spawn_mock_neuron().await;
|
||||||
|
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||||
|
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
let resp = client
|
||||||
|
.post(format!("{gw_url}/v1/responses"))
|
||||||
|
.json(&json!({
|
||||||
|
"input": "Hi"
|
||||||
|
}))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(resp.status(), 400);
|
||||||
|
}
|
||||||
@@ -8,7 +8,7 @@ use std::time::{Duration, Instant};
|
|||||||
async fn test_streaming_sse_passthrough() {
|
async fn test_streaming_sse_passthrough() {
|
||||||
let chunk_count = 5;
|
let chunk_count = 5;
|
||||||
let chunk_delay = Duration::from_millis(50);
|
let chunk_delay = Duration::from_millis(50);
|
||||||
let mock_url = common::spawn_streaming_mock_backend(chunk_count, chunk_delay).await;
|
let mock_url = common::spawn_streaming_mock_neuron(chunk_count, chunk_delay).await;
|
||||||
let gw_url = common::spawn_gateway(&mock_url).await;
|
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
@@ -33,7 +33,6 @@ async fn test_streaming_sse_passthrough() {
|
|||||||
"text/event-stream"
|
"text/event-stream"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Collect SSE chunks as they arrive, recording arrival times.
|
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
let mut chunk_times = Vec::new();
|
let mut chunk_times = Vec::new();
|
||||||
let mut chunks = Vec::new();
|
let mut chunks = Vec::new();
|
||||||
@@ -51,32 +50,25 @@ async fn test_streaming_sse_passthrough() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify we got all content chunks plus [DONE].
|
|
||||||
assert!(
|
assert!(
|
||||||
chunks.len() >= chunk_count + 1,
|
chunks.len() > chunk_count,
|
||||||
"expected at least {} chunks (got {}): {:?}",
|
"expected more than {} chunks (got {}): {:?}",
|
||||||
chunk_count + 1,
|
chunk_count,
|
||||||
chunks.len(),
|
chunks.len(),
|
||||||
chunks,
|
chunks,
|
||||||
);
|
);
|
||||||
|
|
||||||
// The last chunk should be [DONE].
|
|
||||||
assert_eq!(chunks.last().unwrap(), "[DONE]");
|
assert_eq!(chunks.last().unwrap(), "[DONE]");
|
||||||
|
|
||||||
// Verify the content chunks contain expected tokens.
|
for (i, chunk) in chunks.iter().enumerate().take(chunk_count) {
|
||||||
for i in 0..chunk_count {
|
|
||||||
let chunk_json: serde_json::Value =
|
let chunk_json: serde_json::Value =
|
||||||
serde_json::from_str(&chunks[i]).expect("chunk should be valid JSON");
|
serde_json::from_str(chunk).expect("chunk should be valid JSON");
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
chunk_json["choices"][0]["delta"]["content"],
|
chunk_json["choices"][0]["delta"]["content"],
|
||||||
format!("token{i}")
|
format!("token{i}")
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify streaming behavior: total time should reflect incremental delivery,
|
|
||||||
// not a single batch. With 5 chunks at 50ms each + [DONE], we expect ~300ms total.
|
|
||||||
// If buffered, all chunks would arrive at once after ~300ms with no spread.
|
|
||||||
// We verify that the last chunk arrived noticeably after the first.
|
|
||||||
let first = chunk_times.first().unwrap();
|
let first = chunk_times.first().unwrap();
|
||||||
let last = chunk_times.last().unwrap();
|
let last = chunk_times.last().unwrap();
|
||||||
let spread = *last - *first;
|
let spread = *last - *first;
|
||||||
@@ -88,7 +80,7 @@ async fn test_streaming_sse_passthrough() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_streaming_done_terminator() {
|
async fn test_streaming_done_terminator() {
|
||||||
let mock_url = common::spawn_streaming_mock_backend(2, Duration::from_millis(10)).await;
|
let mock_url = common::spawn_streaming_mock_neuron(2, Duration::from_millis(10)).await;
|
||||||
let gw_url = common::spawn_gateway(&mock_url).await;
|
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
|
|||||||
48
crates/helexa-acp/Cargo.toml
Normal file
48
crates/helexa-acp/Cargo.toml
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
[package]
|
||||||
|
name = "helexa-acp"
|
||||||
|
version = "0.1.16"
|
||||||
|
edition = "2024"
|
||||||
|
license = "Apache-2.0"
|
||||||
|
repository = "https://git.lair.cafe/helexa/cortex"
|
||||||
|
description = """
|
||||||
|
Agent Client Protocol bridge for the helexa self-hosted LLM stack.
|
||||||
|
Speaks ACP to ACP-compatible editor clients (Zed, etc.) and forwards
|
||||||
|
the conversation to any OpenAI-compatible HTTP endpoint — defaulting
|
||||||
|
to cortex (helexa's reverse-proxy / fleet gateway).
|
||||||
|
"""
|
||||||
|
|
||||||
|
# This crate is intentionally self-contained — no dependencies on other
|
||||||
|
# workspace crates (cortex-core, cortex-gateway, neuron). The goal is
|
||||||
|
# a painless migration to a dedicated GitHub repo in the future if the
|
||||||
|
# project grows beyond helexa's needs. All deps are crates.io.
|
||||||
|
[dependencies]
|
||||||
|
# `unstable_session_model` flips on the SessionModelState type and the
|
||||||
|
# session/set_model RPC the model-picker dropdown in Zed needs. The
|
||||||
|
# feature is upstream-marked unstable; we accept that risk because the
|
||||||
|
# model picker is core UX and the alternative (rolling our own
|
||||||
|
# extension method) drifts further from spec each time it moves.
|
||||||
|
agent-client-protocol = { version = "0.12", features = ["unstable_session_model"] }
|
||||||
|
tokio = { version = "1", features = ["rt-multi-thread", "macros", "sync", "io-util", "process", "signal"] }
|
||||||
|
reqwest = { version = "0.12", features = ["json", "stream", "rustls-tls"], default-features = false }
|
||||||
|
serde = { version = "1", features = ["derive"] }
|
||||||
|
serde_json = "1"
|
||||||
|
toml = "0.8"
|
||||||
|
tracing = "0.1"
|
||||||
|
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||||
|
anyhow = "1"
|
||||||
|
thiserror = "2"
|
||||||
|
async-trait = "0.1"
|
||||||
|
futures = "0.3"
|
||||||
|
tokio-stream = "0.1"
|
||||||
|
tokio-util = { version = "0.7", features = ["rt"] }
|
||||||
|
eventsource-stream = "0.2"
|
||||||
|
async-stream = "0.3"
|
||||||
|
url = { version = "2", features = ["serde"] }
|
||||||
|
# Already transitively pulled via the ACP SDK; declared directly so we
|
||||||
|
# can format ISO 8601 timestamps for `SessionInfo.updated_at` in the
|
||||||
|
# session/list response.
|
||||||
|
chrono = { version = "0.4", default-features = false, features = ["std"] }
|
||||||
|
|
||||||
|
[[bin]]
|
||||||
|
name = "helexa-acp"
|
||||||
|
path = "src/main.rs"
|
||||||
546
crates/helexa-acp/README.md
Normal file
546
crates/helexa-acp/README.md
Normal file
@@ -0,0 +1,546 @@
|
|||||||
|
# helexa-acp
|
||||||
|
|
||||||
|
ACP (Agent Client Protocol) bridge for editors like
|
||||||
|
[Zed](https://zed.dev). Lets you point your editor's agent panel at
|
||||||
|
**any combination** of OpenAI-compatible, OpenAI Responses, and
|
||||||
|
Anthropic Messages endpoints — public APIs, private LAN deployments,
|
||||||
|
local Ollama / LM Studio — and switch between them per session via a
|
||||||
|
model dropdown.
|
||||||
|
|
||||||
|
The "missing ACP binary" for users who don't want to be locked into
|
||||||
|
one vendor's agent client.
|
||||||
|
|
||||||
|
```
|
||||||
|
┌───────────────────────────────────┐
|
||||||
|
│ Zed (or any ACP editor client) │
|
||||||
|
└────────────┬──────────────────────┘
|
||||||
|
│ stdio JSON-RPC (ACP)
|
||||||
|
▼
|
||||||
|
┌─────────────────┐
|
||||||
|
│ helexa-acp │ ← one binary, multi-endpoint
|
||||||
|
└─────┬───────────┘
|
||||||
|
│ HTTP / SSE
|
||||||
|
┌────────┼─────────────┬──────────────┬──────────────┐
|
||||||
|
▼ ▼ ▼ ▼ ▼
|
||||||
|
cortex/ OpenAI Anthropic OpenRouter LM Studio
|
||||||
|
neuron Responses Messages
|
||||||
|
(self- (gpt-5,…) (Claude)
|
||||||
|
hosted)
|
||||||
|
```
|
||||||
|
|
||||||
|
## What it does
|
||||||
|
|
||||||
|
- **Speaks ACP** over stdio to editor clients (Zed today; any future
|
||||||
|
ACP client tomorrow).
|
||||||
|
- **Multi-endpoint** — one config file lists every LLM endpoint
|
||||||
|
you want available; pick one per session via the model dropdown
|
||||||
|
(`endpoint:model` selector).
|
||||||
|
- **Three wire formats**: `openai-chat` (the broadly compatible
|
||||||
|
default), `openai-responses` (newer OpenAI surface), and
|
||||||
|
`anthropic-messages` (Claude). Each is a separate provider impl
|
||||||
|
in `src/provider/`; adding a fourth (Gemini, Ollama native, …) is
|
||||||
|
one file plus a `WireApi` enum variant.
|
||||||
|
- **Built-in tools**: `read_file`, `write_file`, `edit_file`,
|
||||||
|
`list_dir`, `bash`. Permission-gated by default; the editor user
|
||||||
|
approves writes/shell per-call.
|
||||||
|
- **Three session modes**: Default (gated), Bypass Permissions
|
||||||
|
(auto-allow), and Plan (write-only-to-plan-dir, no shell).
|
||||||
|
- **Vision** — drag-drop images into the agent panel against any
|
||||||
|
vision-capable model.
|
||||||
|
- **Session resume** — multi-day conversations survive editor
|
||||||
|
restarts via on-disk transcript persistence.
|
||||||
|
- **Context compaction** — rolling history stays inside the model's
|
||||||
|
context window automatically so long sessions on small-context
|
||||||
|
local models don't fall over.
|
||||||
|
|
||||||
|
## Install
|
||||||
|
|
||||||
|
### From source
|
||||||
|
|
||||||
|
```sh
|
||||||
|
git clone https://git.lair.cafe/helexa/cortex.git
|
||||||
|
cd cortex
|
||||||
|
cargo install --path crates/helexa-acp
|
||||||
|
# Binary lands at ~/.cargo/bin/helexa-acp
|
||||||
|
```
|
||||||
|
|
||||||
|
### Pre-built RPM (Fedora 43)
|
||||||
|
|
||||||
|
```sh
|
||||||
|
dnf copr enable helexa/helexa
|
||||||
|
dnf install helexa-acp
|
||||||
|
```
|
||||||
|
|
||||||
|
The COPR project bundles helexa-acp alongside the cortex gateway
|
||||||
|
and helexa-neuron flavours; install only the package(s) you need.
|
||||||
|
|
||||||
|
## Quick start
|
||||||
|
|
||||||
|
The fastest path: env-var single-endpoint config.
|
||||||
|
|
||||||
|
```sh
|
||||||
|
export HELEXA_ACP_BASE_URL=http://hanzalova.internal:31313/v1
|
||||||
|
export HELEXA_ACP_MODEL=Qwen/Qwen3.6-27B
|
||||||
|
helexa-acp # speaks ACP over stdin/stdout; not interactive
|
||||||
|
```
|
||||||
|
|
||||||
|
Then in Zed (`~/.config/zed/settings.json`):
|
||||||
|
|
||||||
|
```jsonc
|
||||||
|
{
|
||||||
|
"agent_servers": {
|
||||||
|
"helexa": {
|
||||||
|
"command": "helexa-acp",
|
||||||
|
"args": []
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Restart Zed → open the agent panel → pick "helexa" → start
|
||||||
|
chatting. Tool calls (file reads, writes, bash) prompt for
|
||||||
|
permission per-call in Default mode.
|
||||||
|
|
||||||
|
That's the minimum. The full config story below is what unlocks
|
||||||
|
the multi-endpoint dropdown.
|
||||||
|
|
||||||
|
## Multi-endpoint config
|
||||||
|
|
||||||
|
Copy `helexa-acp.example.toml` from this repo to
|
||||||
|
`$XDG_CONFIG_HOME/helexa-acp/config.toml` (typically
|
||||||
|
`~/.config/helexa-acp/config.toml`) and edit:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
default_endpoint = "helexa"
|
||||||
|
|
||||||
|
[[endpoints]]
|
||||||
|
name = "helexa"
|
||||||
|
base_url = "http://hanzalova.internal:31313/v1"
|
||||||
|
wire_api = "openai-chat"
|
||||||
|
default_model = "Qwen/Qwen3.6-27B"
|
||||||
|
max_tokens = 8192
|
||||||
|
context_window = 32768
|
||||||
|
|
||||||
|
[[endpoints]]
|
||||||
|
name = "openrouter"
|
||||||
|
base_url = "https://openrouter.ai/api/v1"
|
||||||
|
wire_api = "openai-chat"
|
||||||
|
api_key_env = "OPENROUTER_API_KEY"
|
||||||
|
default_model = "anthropic/claude-opus-4"
|
||||||
|
|
||||||
|
[[endpoints]]
|
||||||
|
name = "anthropic"
|
||||||
|
base_url = "https://api.anthropic.com/v1"
|
||||||
|
wire_api = "anthropic-messages"
|
||||||
|
api_key_env = "ANTHROPIC_API_KEY"
|
||||||
|
default_model = "claude-opus-4"
|
||||||
|
```
|
||||||
|
|
||||||
|
Restart Zed. The model dropdown lists every model from every
|
||||||
|
configured endpoint with the `endpoint:model` selector
|
||||||
|
(`helexa:Qwen/Qwen3.6-27B`, `openrouter:anthropic/claude-opus-4`,
|
||||||
|
…). Switch mid-session; the next prompt routes to the new endpoint.
|
||||||
|
|
||||||
|
When only one endpoint is configured the prefix is dropped (model
|
||||||
|
ids appear bare).
|
||||||
|
|
||||||
|
### Selector syntax
|
||||||
|
|
||||||
|
The `model` field on every internal request is parsed as
|
||||||
|
`<endpoint>:<model>`:
|
||||||
|
|
||||||
|
- `openrouter:gpt-4o` → routes to the `openrouter` endpoint,
|
||||||
|
model `gpt-4o`.
|
||||||
|
- `helexa/large` → no colon → falls through to whichever endpoint
|
||||||
|
is named in `default_endpoint`, model `helexa/large`.
|
||||||
|
- `:gpt-5` → leading colon → also falls through to default.
|
||||||
|
|
||||||
|
## Endpoint cookbook
|
||||||
|
|
||||||
|
Copy-pasteable blocks. Mix and match.
|
||||||
|
|
||||||
|
### cortex / neuron (self-hosted)
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[[endpoints]]
|
||||||
|
name = "helexa"
|
||||||
|
base_url = "http://hanzalova.internal:31313/v1"
|
||||||
|
wire_api = "openai-chat"
|
||||||
|
default_model = "Qwen/Qwen3.6-27B"
|
||||||
|
max_tokens = 8192
|
||||||
|
context_window = 32768
|
||||||
|
```
|
||||||
|
|
||||||
|
Use `openai-responses` instead of `openai-chat` once cortex 0.1.16+
|
||||||
|
is deployed and you want the Responses API surface (vision item
|
||||||
|
shape, structured reasoning items, etc.).
|
||||||
|
|
||||||
|
### OpenAI directly
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[[endpoints]]
|
||||||
|
name = "openai"
|
||||||
|
base_url = "https://api.openai.com/v1"
|
||||||
|
wire_api = "openai-responses"
|
||||||
|
api_key_env = "OPENAI_API_KEY"
|
||||||
|
default_model = "gpt-5"
|
||||||
|
```
|
||||||
|
|
||||||
|
`openai-responses` is the right choice for current OpenAI models;
|
||||||
|
`openai-chat` works against legacy GPT-3.5/4 deployments and
|
||||||
|
anything labelled "chat completions".
|
||||||
|
|
||||||
|
### Anthropic directly
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[[endpoints]]
|
||||||
|
name = "anthropic"
|
||||||
|
base_url = "https://api.anthropic.com/v1"
|
||||||
|
wire_api = "anthropic-messages"
|
||||||
|
api_key_env = "ANTHROPIC_API_KEY"
|
||||||
|
default_model = "claude-opus-4"
|
||||||
|
```
|
||||||
|
|
||||||
|
helexa-acp sends `x-api-key` + `anthropic-version: 2023-06-01`
|
||||||
|
automatically. The `api_key_env` indirection keeps your key out of
|
||||||
|
the config file.
|
||||||
|
|
||||||
|
### OpenRouter (multi-vendor proxy)
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[[endpoints]]
|
||||||
|
name = "openrouter"
|
||||||
|
base_url = "https://openrouter.ai/api/v1"
|
||||||
|
wire_api = "openai-chat"
|
||||||
|
api_key_env = "OPENROUTER_API_KEY"
|
||||||
|
default_model = "anthropic/claude-opus-4"
|
||||||
|
```
|
||||||
|
|
||||||
|
OpenRouter speaks OpenAI-compat for every model it fronts, so
|
||||||
|
`openai-chat` is the right wire format regardless of the
|
||||||
|
underlying vendor.
|
||||||
|
|
||||||
|
### LM Studio (local)
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[[endpoints]]
|
||||||
|
name = "lmstudio"
|
||||||
|
base_url = "http://localhost:1234/v1"
|
||||||
|
wire_api = "openai-chat"
|
||||||
|
default_model = "auto"
|
||||||
|
```
|
||||||
|
|
||||||
|
LM Studio's "auto" model id picks whatever's loaded. Same shape
|
||||||
|
works for Ollama in compat mode (`http://localhost:11434/v1`) and
|
||||||
|
vLLM.
|
||||||
|
|
||||||
|
### Multiple cortex deployments
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[[endpoints]]
|
||||||
|
name = "lan"
|
||||||
|
base_url = "http://hanzalova.internal:31313/v1"
|
||||||
|
wire_api = "openai-chat"
|
||||||
|
default_model = "Qwen/Qwen3.6-27B"
|
||||||
|
|
||||||
|
[[endpoints]]
|
||||||
|
name = "cloud"
|
||||||
|
base_url = "https://cortex.example.com/v1"
|
||||||
|
wire_api = "openai-chat"
|
||||||
|
api_key_env = "CLOUD_CORTEX_KEY"
|
||||||
|
default_model = "Qwen/Qwen3-VL-8B"
|
||||||
|
```
|
||||||
|
|
||||||
|
Use the `endpoint:model` selector to switch between them mid-session.
|
||||||
|
|
||||||
|
## Zed setup
|
||||||
|
|
||||||
|
`~/.config/zed/settings.json`:
|
||||||
|
|
||||||
|
```jsonc
|
||||||
|
{
|
||||||
|
"agent_servers": {
|
||||||
|
"helexa": {
|
||||||
|
"command": "helexa-acp"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Optional environment overrides for the binary:
|
||||||
|
|
||||||
|
```jsonc
|
||||||
|
{
|
||||||
|
"agent_servers": {
|
||||||
|
"helexa": {
|
||||||
|
"command": "helexa-acp",
|
||||||
|
"env": {
|
||||||
|
"HELEXA_ACP_LOG_FILE": "/tmp/helexa-acp.log",
|
||||||
|
"RUST_LOG": "helexa_acp=debug"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
`HELEXA_ACP_LOG_FILE` is the one you actually want — Zed doesn't
|
||||||
|
surface the agent's stderr, so without that env var debug output is
|
||||||
|
invisible. Point it at a file you can `tail -f`.
|
||||||
|
|
||||||
|
After restarting Zed: ⌘+? (or wherever your "Open Agent Panel"
|
||||||
|
binding is) → select "helexa" → the model dropdown populates from
|
||||||
|
your config → start prompting.
|
||||||
|
|
||||||
|
## Modes
|
||||||
|
|
||||||
|
Three session modes ship; the user picks via Zed's mode dropdown
|
||||||
|
on the agent panel.
|
||||||
|
|
||||||
|
| Mode | Reads | Writes | Bash | Permission prompts |
|
||||||
|
|------|-------|--------|------|--------------------|
|
||||||
|
| **Default** | ✓ | with prompt | with prompt | per call |
|
||||||
|
| **Bypass Permissions** | ✓ | ✓ | ✓ | never |
|
||||||
|
| **Plan** | ✓ | only into plan dir | disabled | never (plan-dir writes auto-allow) |
|
||||||
|
|
||||||
|
### Default
|
||||||
|
|
||||||
|
Reads are always allowed (`read_file`, `list_dir` are
|
||||||
|
unrestricted). Writes and shell commands prompt the user before
|
||||||
|
running. The intended baseline for any session where the agent
|
||||||
|
might do something you'd rather review first.
|
||||||
|
|
||||||
|
### Bypass Permissions
|
||||||
|
|
||||||
|
Auto-allow every tool call. Use for agentic loops you trust — bulk
|
||||||
|
edits across many files, scripted workflows, prepared session
|
||||||
|
templates. Never for code the agent hasn't seen before.
|
||||||
|
|
||||||
|
### Plan
|
||||||
|
|
||||||
|
The "draft an implementation plan before you write code" mode.
|
||||||
|
Available tools:
|
||||||
|
|
||||||
|
- `read_file`, `list_dir`: unrestricted (read the codebase).
|
||||||
|
- `write_file`, `edit_file`: allowed *only* under
|
||||||
|
`$XDG_DATA_HOME/helexa-acp/plans/<project-id>/`. Any path
|
||||||
|
outside that returns "plan mode: writes are restricted to …"
|
||||||
|
back to the model so it self-corrects.
|
||||||
|
- `bash`: disabled outright. Returns "plan mode: shell execution
|
||||||
|
is disabled" if attempted.
|
||||||
|
|
||||||
|
When the plan is complete, the model presents a 3-option menu:
|
||||||
|
|
||||||
|
1. **Bypass Permissions** — implement the plan now, no prompts.
|
||||||
|
2. **Default** — implement now with per-tool prompts.
|
||||||
|
3. **Plan** (stay here) — refine the plan with more guidance.
|
||||||
|
|
||||||
|
Switch the mode dropdown to your preference and reply to proceed.
|
||||||
|
|
||||||
|
## Tools
|
||||||
|
|
||||||
|
Five tools, defined in `src/tools.rs`:
|
||||||
|
|
||||||
|
| Tool | Args | Gated in Default? |
|
||||||
|
|------|------|-------------------|
|
||||||
|
| `read_file` | `path`, `line?`, `limit?` | no |
|
||||||
|
| `list_dir` | `path` | no |
|
||||||
|
| `write_file` | `path`, `content` | yes |
|
||||||
|
| `edit_file` | `path`, `old_text`, `new_text` | yes |
|
||||||
|
| `bash` | `command`, `cwd?` | yes |
|
||||||
|
|
||||||
|
### Path handling
|
||||||
|
|
||||||
|
`~`, `~/`, `$HOME`, and `$HOME/` are expanded server-side before
|
||||||
|
the path reaches ACP or local fs. Lets the model emit
|
||||||
|
`~/git/repo/file.rs` and have it Just Work.
|
||||||
|
|
||||||
|
`read_file` first tries the editor's filesystem (ACP's
|
||||||
|
`fs/read_text_file` — respects open buffers, workspace overlays,
|
||||||
|
etc.). If that fails — typically because the path is outside Zed's
|
||||||
|
workspace boundary — it falls back to `std::fs::read_to_string`.
|
||||||
|
This lets the agent pull in shared material like
|
||||||
|
`~/git/architecture/generic.md` from a different project's
|
||||||
|
session.
|
||||||
|
|
||||||
|
The fallback is logged at warn level so you can see when it kicks
|
||||||
|
in.
|
||||||
|
|
||||||
|
### Tool dispatch
|
||||||
|
|
||||||
|
Tool descriptions reach the model through a Qwen3 Hermes-format
|
||||||
|
`# Tools` block injected into the system prompt — cortex/neuron
|
||||||
|
pass the OpenAI `tools` request field through to the encoder
|
||||||
|
unread, so we work the model into emitting `<tool_call>{json}</tool_call>`
|
||||||
|
markers it then parses out of the content stream. This applies to
|
||||||
|
the helexa wire format; OpenAI / Anthropic endpoints with native
|
||||||
|
tool support would use their own paths once they're wired in.
|
||||||
|
|
||||||
|
The parser is tolerant: malformed JSON (trailing braces, missing
|
||||||
|
`name`, name nested in `arguments`) gets a repair pass; if that
|
||||||
|
fails the call surfaces as a "Malformed tool call" card in Zed and
|
||||||
|
the model gets a synthetic error result so it can self-correct.
|
||||||
|
|
||||||
|
## Session resume
|
||||||
|
|
||||||
|
helexa-acp persists every session to
|
||||||
|
`$XDG_DATA_HOME/helexa-acp/sessions/<id>.json`. Zed's `session/list`
|
||||||
|
RPC asks helexa-acp to enumerate them on workspace open;
|
||||||
|
`session/load` rehydrates and replays the transcript as
|
||||||
|
`session/update` notifications so the agent panel renders the
|
||||||
|
prior conversation.
|
||||||
|
|
||||||
|
Behaviour:
|
||||||
|
|
||||||
|
- Persisted per-round, so a mid-turn agent stall (long bash, wedged
|
||||||
|
ACP roundtrip) doesn't lose earlier rounds.
|
||||||
|
- Survives editor restart and the helexa-acp binary upgrading
|
||||||
|
between versions.
|
||||||
|
- Project-scoped: only sessions whose `cwd` matches the workspace
|
||||||
|
are listed.
|
||||||
|
|
||||||
|
To wipe history: `rm -rf $XDG_DATA_HOME/helexa-acp/sessions/`.
|
||||||
|
|
||||||
|
## Context compaction
|
||||||
|
|
||||||
|
When an endpoint sets `context_window`, helexa-acp projects the
|
||||||
|
rolling history into a token budget before each request — old
|
||||||
|
`ToolResult` content (read_file payloads are the worst offenders)
|
||||||
|
gets elided to one-line markers, preserving `tool_call_id` pairing
|
||||||
|
so the wire schema stays valid.
|
||||||
|
|
||||||
|
System prompts, user turns, and the most recent ~4 messages are
|
||||||
|
never elided. The full history stays on disk; compaction is a
|
||||||
|
per-request projection, not a destructive edit.
|
||||||
|
|
||||||
|
Set `context_window = 32768` for a 32 K Qwen3, `131072` for a
|
||||||
|
modern Claude, etc. With `max_tokens` also set, the budget is
|
||||||
|
`context_window - max_tokens - 512_safety`.
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### "default endpoint 'helexa' has no usable provider — check config"
|
||||||
|
|
||||||
|
The named default endpoint failed to construct. Usually:
|
||||||
|
|
||||||
|
- `api_key_env` references a variable that isn't set in the env
|
||||||
|
Zed launched helexa-acp with.
|
||||||
|
- The TOML's `wire_api` is misspelled (only `openai-chat`,
|
||||||
|
`openai-responses`, `anthropic-messages` are accepted).
|
||||||
|
|
||||||
|
Test by running `helexa-acp` directly from a shell — startup
|
||||||
|
errors land on stderr.
|
||||||
|
|
||||||
|
### Model dropdown is empty
|
||||||
|
|
||||||
|
Each provider's `list_models` failed at startup. Look at
|
||||||
|
`HELEXA_ACP_LOG_FILE` for "list_models failed; this endpoint's
|
||||||
|
models won't appear in the picker". Likely the endpoint URL is
|
||||||
|
wrong, the API key is invalid, or the upstream `/v1/models`
|
||||||
|
endpoint isn't responding.
|
||||||
|
|
||||||
|
The agent still works against `default_model` even when the
|
||||||
|
dropdown is empty — list-models is for picking, not routing.
|
||||||
|
|
||||||
|
### "prompt_too_long" / agent stalls mid-conversation
|
||||||
|
|
||||||
|
You hit the model's context window. Set `context_window` on the
|
||||||
|
endpoint and helexa-acp will compact before sending. The log line
|
||||||
|
`context compaction applied` confirms it's running; if it fires
|
||||||
|
but the upstream still rejects, the compaction heuristic
|
||||||
|
under-counted and the budget needs tuning down.
|
||||||
|
|
||||||
|
### Reading files outside the workspace returns "not found"
|
||||||
|
|
||||||
|
Zed's `fs/read_text_file` is workspace-scoped. helexa-acp falls
|
||||||
|
back to local `std::fs` automatically when that fails — look for
|
||||||
|
`fs/read_text_file failed; falling back to local std::fs` in the
|
||||||
|
log. If even local read fails, the file genuinely doesn't exist
|
||||||
|
or the user process lacks permissions.
|
||||||
|
|
||||||
|
### Tool calls render as text instead of structured cards
|
||||||
|
|
||||||
|
The model is emitting `<tool_call>` markers that the parser can't
|
||||||
|
decode. Two common causes:
|
||||||
|
|
||||||
|
1. The system prompt isn't reaching the model (cortex/neuron's
|
||||||
|
tool-block injection didn't fire). Confirm with
|
||||||
|
`RUST_LOG=helexa_acp=debug` and look at the outgoing
|
||||||
|
`POST /chat/completions` body.
|
||||||
|
2. The model itself is too small / undertrained to follow the
|
||||||
|
Hermes format reliably. helexa-acp has shape-based name
|
||||||
|
inference and JSON repair, but there's a floor below which
|
||||||
|
nothing helps.
|
||||||
|
|
||||||
|
### Plan-mode writes refused even inside the plan dir
|
||||||
|
|
||||||
|
The path comparison is byte-for-byte. If the model emits a path
|
||||||
|
with `~` and the plan_dir has the expanded form, expansion runs
|
||||||
|
*before* the comparison — but resolved-vs-symlinked-path
|
||||||
|
mismatches can still bite. The error message names the attempted
|
||||||
|
path and the expected prefix so you can compare directly.
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
Source layout under `crates/helexa-acp/src/`:
|
||||||
|
|
||||||
|
| File | Responsibility |
|
||||||
|
|------|----------------|
|
||||||
|
| `main.rs` | tokio + Stdio transport. Builds providers, hands off to `agent::Agent` |
|
||||||
|
| `config.rs` | TOML + env-fallback config, endpoint resolver |
|
||||||
|
| `agent.rs` | ACP handlers (initialize, session/new, session/prompt, session/cancel, session/set_mode, session/set_model, session/load, session/list), prompt loop with tool-call recursion |
|
||||||
|
| `session.rs` | Per-session state map (Arc<RwLock<HashMap<…>>>) |
|
||||||
|
| `store.rs` | On-disk session persistence, plan-dir resolution |
|
||||||
|
| `prompt.rs` | System-prompt assembly, plan-mode addendum |
|
||||||
|
| `tools.rs` | Tool schemas + shape-based name inference |
|
||||||
|
| `tool_runner.rs` | Dispatch a single tool call through ACP client RPCs; permission gate |
|
||||||
|
| `qwen3.rs` | Qwen3 Hermes tool-format parser (`<tool_call>` / `<think>` markers) |
|
||||||
|
| `compaction.rs` | Token-budget compaction for the rolling history |
|
||||||
|
| `path_util.rs` | `~` / `$HOME` expansion shared across every path-taking tool |
|
||||||
|
| `provider/openai_chat.rs` | OpenAI chat completions provider |
|
||||||
|
| `provider/openai_responses.rs` | OpenAI Responses API provider |
|
||||||
|
| `provider/anthropic_messages.rs` | Anthropic Messages API provider |
|
||||||
|
|
||||||
|
### Adding a new wire format
|
||||||
|
|
||||||
|
1. New file under `src/provider/` implementing the `Provider`
|
||||||
|
trait (encoder + SSE decoder).
|
||||||
|
2. Add a `WireApi` variant in `config.rs`.
|
||||||
|
3. Wire it into `build_provider` in `main.rs`.
|
||||||
|
4. Done — every other module is wire-format-agnostic.
|
||||||
|
|
||||||
|
### Concurrency
|
||||||
|
|
||||||
|
- `Arc<RwLock<HashMap<SessionId, Arc<Mutex<SessionState>>>>>` —
|
||||||
|
per-session mutex so concurrent requests across sessions don't
|
||||||
|
contend; the map's RwLock is read-mostly.
|
||||||
|
- Every tool call dispatched serially within a session (parallel
|
||||||
|
dispatch would require Zed to handle interleaved permission
|
||||||
|
prompts).
|
||||||
|
- Provider streams are back-pressured by the consumer (bounded
|
||||||
|
mpsc channels).
|
||||||
|
|
||||||
|
### Self-contained
|
||||||
|
|
||||||
|
The crate has no workspace-internal dependencies (no
|
||||||
|
`cortex-core`, no `cortex-gateway`). Migration to a dedicated
|
||||||
|
GitHub repo for cross-platform CI / cargo-dist binaries is
|
||||||
|
Cargo.toml-only.
|
||||||
|
|
||||||
|
## Status
|
||||||
|
|
||||||
|
- Stages 1–6 shipped: scaffold, agent loop, tools, modes, session
|
||||||
|
resume, image input, model picker, three wire formats.
|
||||||
|
- Stage 8 (RPM + multi-platform CI) tracked in the canonical plan;
|
||||||
|
Linux x86_64 RPM ships today via the cortex monorepo's Gitea
|
||||||
|
Actions.
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
Repository: https://git.lair.cafe/helexa/cortex (`crates/helexa-acp/`).
|
||||||
|
Issues / PRs welcome. The canonical staged plan is in
|
||||||
|
`~/.claude/plans/plan-the-per-device-worker-abstract-micali.md` on
|
||||||
|
the maintainer's machine; the substages 3a–3e and 6a/6b that the
|
||||||
|
canonical plan didn't anticipate are documented in commit messages.
|
||||||
|
|
||||||
|
CI: `cargo fmt --check --all`, `cargo clippy --workspace -- -D
|
||||||
|
warnings`, `cargo test --workspace` must all pass before merge.
|
||||||
1820
crates/helexa-acp/src/agent.rs
Normal file
1820
crates/helexa-acp/src/agent.rs
Normal file
File diff suppressed because it is too large
Load Diff
425
crates/helexa-acp/src/compaction.rs
Normal file
425
crates/helexa-acp/src/compaction.rs
Normal file
@@ -0,0 +1,425 @@
|
|||||||
|
//! Rolling-conversation compaction for small-context local models.
|
||||||
|
//!
|
||||||
|
//! The tool-call loop in [`crate::agent`] grows the message vec it
|
||||||
|
//! sends upstream every round. On a frontier model that's fine; on a
|
||||||
|
//! 32 K Qwen3 the first few `read_file` results can push the prompt
|
||||||
|
//! past the model's context window, at which point cortex/neuron
|
||||||
|
//! refuses with `prompt_too_long` and the whole turn dies. Long-form
|
||||||
|
//! local agents are unusable without something here.
|
||||||
|
//!
|
||||||
|
//! Strategy (intentionally simple — no LLM-summarization round-trip,
|
||||||
|
//! no tokenizer dependency):
|
||||||
|
//!
|
||||||
|
//! 1. **Protect** the things the model cannot reason without:
|
||||||
|
//! - The system prompt (idx 0).
|
||||||
|
//! - Every `Role::User` turn (the user's intent — irreplaceable).
|
||||||
|
//! - The last [`KEEP_TAIL`] messages (most recent rounds stay
|
||||||
|
//! verbatim so the model can keep working on what it just
|
||||||
|
//! observed).
|
||||||
|
//! 2. **Elide** older `Role::Assistant` prose and older `Role::Tool`
|
||||||
|
//! result content. The structure stays — `tool_call_id`s, tool
|
||||||
|
//! names, and argument JSON survive intact — so OpenAI's strict
|
||||||
|
//! `tool_calls` ↔ `tool` pairing schema remains satisfied. Only
|
||||||
|
//! the *payload* shrinks to a one-line marker.
|
||||||
|
//! 3. Walk oldest→newest, recomputing the budget after each elision.
|
||||||
|
//! Stop as soon as we fit; we don't compact more than necessary.
|
||||||
|
//! 4. If we still exceed budget after eliding everything we're
|
||||||
|
//! allowed to, return what we have. The upstream will surface a
|
||||||
|
//! `prompt_too_long` error and the user can intervene; that's
|
||||||
|
//! better than silently dropping content the model needs.
|
||||||
|
//!
|
||||||
|
//! Token estimation uses a `chars / 3.5` heuristic — conservative
|
||||||
|
//! (over-estimates tokens slightly) so we compact a touch early
|
||||||
|
//! rather than a touch late.
|
||||||
|
|
||||||
|
use crate::provider::{Message, MessageContent, MessagePart, Role};
|
||||||
|
|
||||||
|
/// Most-recent N messages that are never elided. Roughly "the
|
||||||
|
/// current tool round in flight" — assistant turn that called the
|
||||||
|
/// tools + each tool result + a bit of slack.
|
||||||
|
const KEEP_TAIL: usize = 4;
|
||||||
|
|
||||||
|
/// Below this content size we don't bother eliding — the savings
|
||||||
|
/// don't outweigh the loss of detail. Roughly 60–80 tokens.
|
||||||
|
const ELIDE_MIN_CHARS: usize = 256;
|
||||||
|
|
||||||
|
/// Roughly tokens-per-character for English + code mixed in. The
|
||||||
|
/// actual per-tokenizer ratio varies (GPT-4o ≈ 4 chars/token on
|
||||||
|
/// English prose, ≈ 3 chars/token on code-heavy text). We pick a
|
||||||
|
/// value on the conservative end so the budget check fires *before*
|
||||||
|
/// the upstream tokenizer says no.
|
||||||
|
const CHARS_PER_TOKEN: f32 = 3.5;
|
||||||
|
|
||||||
|
/// Per-message envelope overhead (role + JSON framing). Comes out
|
||||||
|
/// to a few tokens; tiny but it adds up across long histories.
|
||||||
|
const ENVELOPE_TOKENS: usize = 8;
|
||||||
|
|
||||||
|
/// Rough per-image token cost used by the budget estimator. Real
|
||||||
|
/// vision tokenizers vary widely (256–1024 tokens for typical
|
||||||
|
/// resolutions on Qwen3-VL, OpenAI's `low`/`high` detail toggles
|
||||||
|
/// pick between ~85 and ~1000+). 512 is a defensible middle that
|
||||||
|
/// keeps compaction from treating images as free.
|
||||||
|
const IMAGE_TOKENS_APPROX: usize = 512;
|
||||||
|
|
||||||
|
/// Stats reported back from [`compact_to_budget`] for the caller to
|
||||||
|
/// log. The numbers are estimates (see [`estimate_tokens`]), so
|
||||||
|
/// don't compare them to upstream-reported token counts as if they
|
||||||
|
/// were exact.
|
||||||
|
#[derive(Debug, Clone, Default, PartialEq, Eq)]
|
||||||
|
pub struct CompactionStats {
|
||||||
|
/// Estimated tokens in the input messages.
|
||||||
|
pub original_tokens: usize,
|
||||||
|
/// Estimated tokens after compaction. Equal to `original_tokens`
|
||||||
|
/// when no compaction was needed.
|
||||||
|
pub final_tokens: usize,
|
||||||
|
/// Number of messages whose content was elided. Zero is the
|
||||||
|
/// hot path (nothing to do).
|
||||||
|
pub elided_messages: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CompactionStats {
|
||||||
|
fn unchanged(tokens: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
original_tokens: tokens,
|
||||||
|
final_tokens: tokens,
|
||||||
|
elided_messages: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Approximate token count for one message. Sums the textual
|
||||||
|
/// payload's chars, divides by [`CHARS_PER_TOKEN`], and adds an
|
||||||
|
/// envelope constant. Cheap (no allocation) so safe to call once per
|
||||||
|
/// message per round.
|
||||||
|
pub fn estimate_tokens(msg: &Message) -> usize {
|
||||||
|
let chars = match &msg.content {
|
||||||
|
MessageContent::Text { text } => text.len(),
|
||||||
|
MessageContent::MultiPart { parts } => parts
|
||||||
|
.iter()
|
||||||
|
.map(|p| match p {
|
||||||
|
MessagePart::Text { text } => text.len(),
|
||||||
|
// Each image is one block in the context window; the
|
||||||
|
// upstream tokenizer handles the real cost (and it
|
||||||
|
// varies wildly by model — Qwen3-VL uses ~256-1024
|
||||||
|
// tokens per image depending on size). Take a
|
||||||
|
// middle estimate so the budget tracker doesn't
|
||||||
|
// pretend images are free.
|
||||||
|
MessagePart::Image(_) => IMAGE_TOKENS_APPROX * CHARS_PER_TOKEN as usize,
|
||||||
|
})
|
||||||
|
.sum(),
|
||||||
|
MessageContent::ToolCalls { text, calls } => {
|
||||||
|
let txt = text.as_deref().map(|s| s.len()).unwrap_or(0);
|
||||||
|
let calls_size: usize = calls
|
||||||
|
.iter()
|
||||||
|
.map(|c| c.name.len() + c.arguments.len() + c.id.len())
|
||||||
|
.sum();
|
||||||
|
txt + calls_size
|
||||||
|
}
|
||||||
|
MessageContent::ToolResult {
|
||||||
|
tool_call_id,
|
||||||
|
content,
|
||||||
|
} => tool_call_id.len() + content.len(),
|
||||||
|
};
|
||||||
|
((chars as f32 / CHARS_PER_TOKEN) as usize) + ENVELOPE_TOKENS
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sum of [`estimate_tokens`] across all messages.
|
||||||
|
pub fn total_tokens(messages: &[Message]) -> usize {
|
||||||
|
messages.iter().map(estimate_tokens).sum()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Project `messages` into a vec whose estimated token count fits in
|
||||||
|
/// `budget` tokens. Returns the projection plus stats about what
|
||||||
|
/// was done. When the input already fits, the projection is a clone
|
||||||
|
/// of the input and stats report zero elisions.
|
||||||
|
///
|
||||||
|
/// See module docs for the strategy and protected set.
|
||||||
|
pub fn compact_to_budget(messages: &[Message], budget: usize) -> (Vec<Message>, CompactionStats) {
|
||||||
|
let original = total_tokens(messages);
|
||||||
|
if original <= budget {
|
||||||
|
return (messages.to_vec(), CompactionStats::unchanged(original));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut out = messages.to_vec();
|
||||||
|
let len = out.len();
|
||||||
|
let tail_start = len.saturating_sub(KEEP_TAIL);
|
||||||
|
let mut elided = 0usize;
|
||||||
|
|
||||||
|
// Two passes. First pass: ToolResult contents (largest savings
|
||||||
|
// per elision — read_file payloads land here). Second pass: long
|
||||||
|
// Assistant prose. We don't interleave because eliding a long
|
||||||
|
// assistant turn before a really old read_file would do less
|
||||||
|
// good per elision; oldest-first ordering is enforced *within*
|
||||||
|
// each pass instead.
|
||||||
|
for pass in 0..2 {
|
||||||
|
for i in 1..tail_start {
|
||||||
|
if matches!(out[i].role, Role::User) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let target_pass_2 = matches!(
|
||||||
|
&out[i].content,
|
||||||
|
MessageContent::Text { .. } | MessageContent::ToolCalls { .. }
|
||||||
|
);
|
||||||
|
let target_pass_1 = matches!(&out[i].content, MessageContent::ToolResult { .. });
|
||||||
|
let in_pass = (pass == 0 && target_pass_1) || (pass == 1 && target_pass_2);
|
||||||
|
if !in_pass {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if elide_in_place(&mut out[i]) {
|
||||||
|
elided += 1;
|
||||||
|
if total_tokens(&out) <= budget {
|
||||||
|
let final_tokens = total_tokens(&out);
|
||||||
|
return (
|
||||||
|
out,
|
||||||
|
CompactionStats {
|
||||||
|
original_tokens: original,
|
||||||
|
final_tokens,
|
||||||
|
elided_messages: elided,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let final_tokens = total_tokens(&out);
|
||||||
|
(
|
||||||
|
out,
|
||||||
|
CompactionStats {
|
||||||
|
original_tokens: original,
|
||||||
|
final_tokens,
|
||||||
|
elided_messages: elided,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Shrink one message's payload while keeping its structural role
|
||||||
|
/// (so tool_call_id pairing survives). Returns `true` when the
|
||||||
|
/// message changed.
|
||||||
|
///
|
||||||
|
/// - `ToolResult.content` → `(elided: N bytes of tool result)`
|
||||||
|
/// - `ToolCalls.text` → `(elided: N bytes of assistant prose)`
|
||||||
|
/// - `Text` (assistant) → `(elided: N bytes of assistant prose)`
|
||||||
|
///
|
||||||
|
/// Already-tiny payloads are skipped — eliding a 50-byte string
|
||||||
|
/// would *grow* it once the marker is in place.
|
||||||
|
fn elide_in_place(msg: &mut Message) -> bool {
|
||||||
|
match &mut msg.content {
|
||||||
|
MessageContent::ToolResult { content, .. } => {
|
||||||
|
if content.len() < ELIDE_MIN_CHARS {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
*content = format!("(elided: {} bytes of tool result)", content.len());
|
||||||
|
true
|
||||||
|
}
|
||||||
|
MessageContent::ToolCalls { text, .. } => match text {
|
||||||
|
Some(t) if t.len() >= ELIDE_MIN_CHARS => {
|
||||||
|
*text = Some(format!("(elided: {} bytes of assistant prose)", t.len()));
|
||||||
|
true
|
||||||
|
}
|
||||||
|
_ => false,
|
||||||
|
},
|
||||||
|
MessageContent::Text { text } => {
|
||||||
|
if text.len() < ELIDE_MIN_CHARS {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
*text = format!("(elided: {} bytes of assistant prose)", text.len());
|
||||||
|
true
|
||||||
|
}
|
||||||
|
MessageContent::MultiPart { .. } => {
|
||||||
|
// MultiPart messages today only exist as User turns,
|
||||||
|
// and User turns are protected by the role check in
|
||||||
|
// `compact_to_budget` — so this branch is unreachable
|
||||||
|
// for current call sites. Returning false keeps the
|
||||||
|
// unreachable path benign if a future stage starts
|
||||||
|
// emitting MultiPart on other roles.
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::provider::ToolCall;
|
||||||
|
|
||||||
|
fn sys(text: &str) -> Message {
|
||||||
|
Message {
|
||||||
|
role: Role::System,
|
||||||
|
content: MessageContent::Text { text: text.into() },
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn user(text: &str) -> Message {
|
||||||
|
Message {
|
||||||
|
role: Role::User,
|
||||||
|
content: MessageContent::Text { text: text.into() },
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn assistant_text(text: &str) -> Message {
|
||||||
|
Message {
|
||||||
|
role: Role::Assistant,
|
||||||
|
content: MessageContent::Text { text: text.into() },
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn assistant_calls(text: Option<&str>, name: &str, args: &str, id: &str) -> Message {
|
||||||
|
Message {
|
||||||
|
role: Role::Assistant,
|
||||||
|
content: MessageContent::ToolCalls {
|
||||||
|
text: text.map(|s| s.to_string()),
|
||||||
|
calls: vec![ToolCall {
|
||||||
|
id: id.into(),
|
||||||
|
name: name.into(),
|
||||||
|
arguments: args.into(),
|
||||||
|
}],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn tool_result(id: &str, body: &str) -> Message {
|
||||||
|
Message {
|
||||||
|
role: Role::Tool,
|
||||||
|
content: MessageContent::ToolResult {
|
||||||
|
tool_call_id: id.into(),
|
||||||
|
content: body.into(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn under_budget_is_a_no_op_clone() {
|
||||||
|
let msgs = vec![sys("you are an agent"), user("hi"), assistant_text("hello")];
|
||||||
|
let (out, stats) = compact_to_budget(&msgs, 10_000);
|
||||||
|
assert_eq!(stats.elided_messages, 0);
|
||||||
|
assert_eq!(stats.original_tokens, stats.final_tokens);
|
||||||
|
assert_eq!(out.len(), msgs.len());
|
||||||
|
// Strings unchanged.
|
||||||
|
match &out[2].content {
|
||||||
|
MessageContent::Text { text } => assert_eq!(text, "hello"),
|
||||||
|
other => panic!("expected Text, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn elides_old_tool_result_before_old_assistant_prose() {
|
||||||
|
// History: sys, user, assistant_calls, big_tool_result,
|
||||||
|
// assistant_with_big_text, user, assistant_calls,
|
||||||
|
// small_tool_result.
|
||||||
|
// KEEP_TAIL=4 protects the last four; the big tool result
|
||||||
|
// sits in the prunable range and should go first because
|
||||||
|
// pass 0 (tool results) runs before pass 1 (prose).
|
||||||
|
let big_result = "X".repeat(4096);
|
||||||
|
let big_prose = "Y".repeat(2048);
|
||||||
|
let msgs = vec![
|
||||||
|
sys("preamble"),
|
||||||
|
user("first ask"),
|
||||||
|
assistant_calls(None, "read_file", r#"{"path":"/a"}"#, "c0"),
|
||||||
|
tool_result("c0", &big_result),
|
||||||
|
assistant_text(&big_prose),
|
||||||
|
user("follow up"),
|
||||||
|
assistant_calls(None, "read_file", r#"{"path":"/b"}"#, "c1"),
|
||||||
|
tool_result("c1", "short result body"),
|
||||||
|
];
|
||||||
|
let before = total_tokens(&msgs);
|
||||||
|
// Force compaction by setting budget well below current.
|
||||||
|
let budget = before / 2;
|
||||||
|
let (out, stats) = compact_to_budget(&msgs, budget);
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
stats.elided_messages >= 1,
|
||||||
|
"expected at least one elision, got {stats:?}"
|
||||||
|
);
|
||||||
|
// The big tool result must be elided (oldest fat target).
|
||||||
|
match &out[3].content {
|
||||||
|
MessageContent::ToolResult { content, .. } => {
|
||||||
|
assert!(
|
||||||
|
content.starts_with("(elided:"),
|
||||||
|
"tool result not elided: {content:?}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
other => panic!("expected ToolResult, got {other:?}"),
|
||||||
|
}
|
||||||
|
// Last four messages must be untouched.
|
||||||
|
assert!(matches!(
|
||||||
|
&out[out.len() - 1].content,
|
||||||
|
MessageContent::ToolResult { content, .. } if content == "short result body"
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn never_elides_system_or_user_turns() {
|
||||||
|
let big_user = "U".repeat(8192);
|
||||||
|
let msgs = vec![sys("preamble"), user(&big_user), assistant_text("ok")];
|
||||||
|
let budget = 10; // way below — forces all possible elision
|
||||||
|
let (out, _stats) = compact_to_budget(&msgs, budget);
|
||||||
|
// System unchanged.
|
||||||
|
match &out[0].content {
|
||||||
|
MessageContent::Text { text } => assert_eq!(text, "preamble"),
|
||||||
|
other => panic!("expected Text, got {other:?}"),
|
||||||
|
}
|
||||||
|
// User unchanged even though it's huge.
|
||||||
|
match &out[1].content {
|
||||||
|
MessageContent::Text { text } => assert_eq!(text.len(), big_user.len()),
|
||||||
|
other => panic!("expected Text, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn preserves_tool_call_id_pairing_after_elision() {
|
||||||
|
// OpenAI strict mode rejects a tool-result whose tool_call_id
|
||||||
|
// doesn't match a preceding assistant tool_call. Elision
|
||||||
|
// must not break that linkage.
|
||||||
|
let big = "Z".repeat(4096);
|
||||||
|
let msgs = vec![
|
||||||
|
sys("preamble"),
|
||||||
|
user("first"),
|
||||||
|
assistant_calls(None, "read_file", r#"{"path":"/a"}"#, "call_42"),
|
||||||
|
tool_result("call_42", &big),
|
||||||
|
// Tail messages.
|
||||||
|
user("next"),
|
||||||
|
assistant_calls(None, "read_file", r#"{"path":"/b"}"#, "call_43"),
|
||||||
|
tool_result("call_43", "ok"),
|
||||||
|
assistant_text("done"),
|
||||||
|
];
|
||||||
|
let budget = total_tokens(&msgs) / 3;
|
||||||
|
let (out, _stats) = compact_to_budget(&msgs, budget);
|
||||||
|
// The assistant call and its result both carry call_42.
|
||||||
|
let call_id = match &out[2].content {
|
||||||
|
MessageContent::ToolCalls { calls, .. } => calls[0].id.clone(),
|
||||||
|
other => panic!("expected ToolCalls, got {other:?}"),
|
||||||
|
};
|
||||||
|
match &out[3].content {
|
||||||
|
MessageContent::ToolResult { tool_call_id, .. } => {
|
||||||
|
assert_eq!(tool_call_id, &call_id, "pairing broken");
|
||||||
|
}
|
||||||
|
other => panic!("expected ToolResult, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn estimate_tokens_grows_with_content() {
|
||||||
|
let small = sys("hi");
|
||||||
|
let large = sys(&"x".repeat(10_000));
|
||||||
|
assert!(estimate_tokens(&large) > estimate_tokens(&small) * 100);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn elide_in_place_skips_short_content() {
|
||||||
|
let mut m = tool_result("c0", "tiny");
|
||||||
|
assert!(!elide_in_place(&mut m));
|
||||||
|
match m.content {
|
||||||
|
MessageContent::ToolResult { content, .. } => assert_eq!(content, "tiny"),
|
||||||
|
other => panic!("expected ToolResult, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn returns_best_effort_when_budget_unmeetable() {
|
||||||
|
// Single huge user message that cannot be elided. Budget 10.
|
||||||
|
// We don't error — we return what we have and let upstream
|
||||||
|
// refuse the prompt with its own error.
|
||||||
|
let big_user = "U".repeat(100_000);
|
||||||
|
let msgs = vec![sys("preamble"), user(&big_user)];
|
||||||
|
let (out, stats) = compact_to_budget(&msgs, 10);
|
||||||
|
assert_eq!(out.len(), msgs.len());
|
||||||
|
assert!(stats.final_tokens > 10, "still over budget by design");
|
||||||
|
}
|
||||||
|
}
|
||||||
424
crates/helexa-acp/src/config.rs
Normal file
424
crates/helexa-acp/src/config.rs
Normal file
@@ -0,0 +1,424 @@
|
|||||||
|
//! Configuration for the helexa-acp bridge.
|
||||||
|
//!
|
||||||
|
//! Loaded from `$XDG_CONFIG_HOME/helexa-acp/config.toml` (or
|
||||||
|
//! `~/.config/helexa-acp/config.toml` as a fallback). If no config file
|
||||||
|
//! exists, falls back to building a single anonymous endpoint from env
|
||||||
|
//! vars — that keeps "just point at one cortex" frictionless without
|
||||||
|
//! requiring a config file on disk.
|
||||||
|
//!
|
||||||
|
//! The design goal is "the missing ACP binary for users with multiple
|
||||||
|
//! API endpoints (possibly on a private LAN, possibly mixing wire
|
||||||
|
//! types)". Hence: every endpoint is named, has its own wire API, and
|
||||||
|
//! has its own default model. The agent's selected model id can be
|
||||||
|
//! prefixed `endpoint:model` to route across endpoints; a bare
|
||||||
|
//! `model` falls through to the configured `default_endpoint`.
|
||||||
|
//!
|
||||||
|
//! ### Example TOML
|
||||||
|
//!
|
||||||
|
//! ```toml
|
||||||
|
//! default_endpoint = "helexa"
|
||||||
|
//!
|
||||||
|
//! [[endpoints]]
|
||||||
|
//! name = "helexa"
|
||||||
|
//! base_url = "http://hanzalova.internal:31313/v1"
|
||||||
|
//! wire_api = "openai-chat"
|
||||||
|
//! default_model = "helexa/large"
|
||||||
|
//!
|
||||||
|
//! [[endpoints]]
|
||||||
|
//! name = "openrouter"
|
||||||
|
//! base_url = "https://openrouter.ai/api/v1"
|
||||||
|
//! wire_api = "openai-chat"
|
||||||
|
//! api_key_env = "OPENROUTER_API_KEY"
|
||||||
|
//! default_model = "anthropic/claude-opus-4"
|
||||||
|
//!
|
||||||
|
//! [[endpoints]]
|
||||||
|
//! name = "lmstudio"
|
||||||
|
//! base_url = "http://localhost:1234/v1"
|
||||||
|
//! wire_api = "openai-chat"
|
||||||
|
//! default_model = "auto"
|
||||||
|
//! ```
|
||||||
|
|
||||||
|
use anyhow::{Context, anyhow};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
use url::Url;
|
||||||
|
|
||||||
|
const DEFAULT_BASE_URL: &str = "http://hanzalova.internal:31313/v1";
|
||||||
|
const DEFAULT_MODEL: &str = "helexa/large";
|
||||||
|
const DEFAULT_ENDPOINT_NAME: &str = "default";
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct Config {
|
||||||
|
/// Name of the endpoint used when a request doesn't pick one
|
||||||
|
/// explicitly. Must reference an entry in `endpoints`. Defaults to
|
||||||
|
/// the first endpoint declared if unset.
|
||||||
|
#[serde(default)]
|
||||||
|
pub default_endpoint: Option<String>,
|
||||||
|
/// Per-endpoint configuration. At least one entry is required.
|
||||||
|
#[serde(default)]
|
||||||
|
pub endpoints: Vec<EndpointConfig>,
|
||||||
|
/// Optional path to a system-prompt file. When unset, the built-in
|
||||||
|
/// default prompt from `prompt.rs` is used.
|
||||||
|
#[serde(default)]
|
||||||
|
pub system_prompt_path: Option<PathBuf>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct EndpointConfig {
|
||||||
|
/// Short identifier used in `endpoint:model` routing and in logs.
|
||||||
|
pub name: String,
|
||||||
|
/// Base URL of the OpenAI-compatible API. Must include the `/v1`
|
||||||
|
/// (or equivalent) suffix — paths like `chat/completions` and
|
||||||
|
/// `models` are joined onto this.
|
||||||
|
pub base_url: Url,
|
||||||
|
/// Wire protocol the endpoint speaks. Phase 1 supports
|
||||||
|
/// [`WireApi::OpenAiChat`] only; `openai-responses` and
|
||||||
|
/// `anthropic-messages` land later behind their own providers.
|
||||||
|
#[serde(default)]
|
||||||
|
pub wire_api: WireApi,
|
||||||
|
/// Model to use when the client hasn't picked one via
|
||||||
|
/// `session/set_model`.
|
||||||
|
#[serde(default)]
|
||||||
|
pub default_model: Option<String>,
|
||||||
|
/// Static API key to send as `Authorization: Bearer …`. Prefer
|
||||||
|
/// `api_key_env` for anything sensitive — keys in plain TOML are a
|
||||||
|
/// liability.
|
||||||
|
#[serde(default)]
|
||||||
|
pub api_key: Option<String>,
|
||||||
|
/// Env var name to read for the API key. Resolved at startup so a
|
||||||
|
/// missing env var yields a clear error rather than silent
|
||||||
|
/// unauthenticated calls.
|
||||||
|
#[serde(default)]
|
||||||
|
pub api_key_env: Option<String>,
|
||||||
|
/// Cap on the model's output tokens per turn. `None` lets the
|
||||||
|
/// upstream pick its own default (cortex/neuron's default is
|
||||||
|
/// often small enough to trip Zed's "Output Limit Reached" on
|
||||||
|
/// long responses). Set to e.g. `32768` to let the model
|
||||||
|
/// produce longer turns. Goes into the OpenAI `max_tokens`
|
||||||
|
/// request field.
|
||||||
|
#[serde(default)]
|
||||||
|
pub max_tokens: Option<u64>,
|
||||||
|
/// Model context window in tokens (prompt + response). When set,
|
||||||
|
/// the agent compacts conversation history before each completion
|
||||||
|
/// so the prompt fits within `context_window - max_tokens - safety`
|
||||||
|
/// tokens — long sessions on small-context local models (Qwen3 at
|
||||||
|
/// 32 K) survive past the first few tool-call rounds rather than
|
||||||
|
/// dying with `prompt_too_long`. `None` disables compaction.
|
||||||
|
#[serde(default)]
|
||||||
|
pub context_window: Option<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
|
||||||
|
pub enum WireApi {
|
||||||
|
/// `POST {base}/chat/completions` returning OpenAI-format SSE.
|
||||||
|
/// Compatible with cortex, LM Studio, Ollama (compat mode),
|
||||||
|
/// OpenRouter, OpenAI itself.
|
||||||
|
#[default]
|
||||||
|
#[serde(rename = "openai-chat")]
|
||||||
|
OpenAiChat,
|
||||||
|
/// `POST {base}/responses` — OpenAI's newer Responses API. Not
|
||||||
|
/// implemented yet; the variant is reserved so endpoint configs
|
||||||
|
/// can be authored ahead of provider support.
|
||||||
|
#[serde(rename = "openai-responses")]
|
||||||
|
OpenAiResponses,
|
||||||
|
/// `POST {base}/messages` — Anthropic format. Reserved.
|
||||||
|
#[serde(rename = "anthropic-messages")]
|
||||||
|
AnthropicMessages,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EndpointConfig {
|
||||||
|
/// Resolve the API key from `api_key` (literal) or `api_key_env`
|
||||||
|
/// (env-var lookup). Returns `Ok(None)` when neither is set;
|
||||||
|
/// `Err` when `api_key_env` references a missing variable.
|
||||||
|
pub fn resolve_api_key(&self) -> anyhow::Result<Option<String>> {
|
||||||
|
if let Some(literal) = &self.api_key {
|
||||||
|
return Ok(Some(literal.clone()));
|
||||||
|
}
|
||||||
|
if let Some(var) = &self.api_key_env {
|
||||||
|
return Ok(Some(std::env::var(var).with_context(|| {
|
||||||
|
format!(
|
||||||
|
"endpoint '{}' references missing env var {}",
|
||||||
|
self.name, var
|
||||||
|
)
|
||||||
|
})?));
|
||||||
|
}
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `{base_url}/chat/completions`.
|
||||||
|
pub fn chat_completions_url(&self) -> Url {
|
||||||
|
join_segments(&self.base_url, &["chat", "completions"])
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `{base_url}/responses` — OpenAI Responses API endpoint.
|
||||||
|
pub fn responses_url(&self) -> Url {
|
||||||
|
join_segments(&self.base_url, &["responses"])
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `{base_url}/models`. Called from `Provider::list_models`, which
|
||||||
|
/// Stage 4 wires into the model-picker dropdown; until then it's
|
||||||
|
/// reachable code with no in-tree callers.
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub fn models_url(&self) -> Url {
|
||||||
|
join_segments(&self.base_url, &["models"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Config {
|
||||||
|
/// Load from TOML at the standard config path, or build from env
|
||||||
|
/// vars if no file exists. Env-fallback yields a single endpoint
|
||||||
|
/// named `"default"`.
|
||||||
|
pub fn load() -> anyhow::Result<Self> {
|
||||||
|
let path = config_path();
|
||||||
|
if let Some(path) = &path
|
||||||
|
&& path.exists()
|
||||||
|
{
|
||||||
|
return Self::from_file(path);
|
||||||
|
}
|
||||||
|
Self::from_env()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Single-endpoint config constructed from `HELEXA_ACP_BASE_URL`,
|
||||||
|
/// `HELEXA_ACP_MODEL`, `HELEXA_ACP_API_KEY`,
|
||||||
|
/// `HELEXA_ACP_SYSTEM_PROMPT_PATH`, `HELEXA_ACP_MAX_TOKENS`.
|
||||||
|
pub fn from_env() -> anyhow::Result<Self> {
|
||||||
|
let base_url = std::env::var("HELEXA_ACP_BASE_URL")
|
||||||
|
.ok()
|
||||||
|
.unwrap_or_else(|| DEFAULT_BASE_URL.into());
|
||||||
|
let base_url = Url::parse(&base_url)
|
||||||
|
.with_context(|| format!("HELEXA_ACP_BASE_URL is not a valid URL ({base_url})"))?;
|
||||||
|
let default_model = std::env::var("HELEXA_ACP_MODEL")
|
||||||
|
.ok()
|
||||||
|
.unwrap_or_else(|| DEFAULT_MODEL.into());
|
||||||
|
let api_key = std::env::var("HELEXA_ACP_API_KEY")
|
||||||
|
.ok()
|
||||||
|
.filter(|s| !s.is_empty());
|
||||||
|
let system_prompt_path = std::env::var("HELEXA_ACP_SYSTEM_PROMPT_PATH")
|
||||||
|
.ok()
|
||||||
|
.filter(|s| !s.is_empty())
|
||||||
|
.map(PathBuf::from);
|
||||||
|
let max_tokens = std::env::var("HELEXA_ACP_MAX_TOKENS")
|
||||||
|
.ok()
|
||||||
|
.filter(|s| !s.is_empty())
|
||||||
|
.map(|s| {
|
||||||
|
s.parse::<u64>().with_context(|| {
|
||||||
|
format!("HELEXA_ACP_MAX_TOKENS is not a positive integer ({s})")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.transpose()?;
|
||||||
|
let context_window = std::env::var("HELEXA_ACP_CONTEXT_WINDOW")
|
||||||
|
.ok()
|
||||||
|
.filter(|s| !s.is_empty())
|
||||||
|
.map(|s| {
|
||||||
|
s.parse::<usize>().with_context(|| {
|
||||||
|
format!("HELEXA_ACP_CONTEXT_WINDOW is not a positive integer ({s})")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.transpose()?;
|
||||||
|
Ok(Self {
|
||||||
|
default_endpoint: Some(DEFAULT_ENDPOINT_NAME.into()),
|
||||||
|
endpoints: vec![EndpointConfig {
|
||||||
|
name: DEFAULT_ENDPOINT_NAME.into(),
|
||||||
|
base_url,
|
||||||
|
wire_api: WireApi::OpenAiChat,
|
||||||
|
default_model: Some(default_model),
|
||||||
|
api_key,
|
||||||
|
api_key_env: None,
|
||||||
|
max_tokens,
|
||||||
|
context_window,
|
||||||
|
}],
|
||||||
|
system_prompt_path,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_file(path: &Path) -> anyhow::Result<Self> {
|
||||||
|
let text = std::fs::read_to_string(path)
|
||||||
|
.with_context(|| format!("read config {}", path.display()))?;
|
||||||
|
let mut cfg: Self =
|
||||||
|
toml::from_str(&text).with_context(|| format!("parse config {}", path.display()))?;
|
||||||
|
cfg.validate()?;
|
||||||
|
Ok(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn validate(&mut self) -> anyhow::Result<()> {
|
||||||
|
if self.endpoints.is_empty() {
|
||||||
|
return Err(anyhow!("config has no [[endpoints]] entries"));
|
||||||
|
}
|
||||||
|
for (i, ep) in self.endpoints.iter().enumerate() {
|
||||||
|
if ep.name.is_empty() {
|
||||||
|
return Err(anyhow!("endpoints[{i}] has empty name"));
|
||||||
|
}
|
||||||
|
if ep.name.contains(':') {
|
||||||
|
return Err(anyhow!(
|
||||||
|
"endpoints[{i}].name '{}' contains ':' which would clash \
|
||||||
|
with the endpoint:model selector syntax",
|
||||||
|
ep.name
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Pick a default endpoint if none was named.
|
||||||
|
if self.default_endpoint.is_none() {
|
||||||
|
self.default_endpoint = Some(self.endpoints[0].name.clone());
|
||||||
|
}
|
||||||
|
let default_name = self.default_endpoint.as_deref().unwrap();
|
||||||
|
if !self.endpoints.iter().any(|e| e.name == default_name) {
|
||||||
|
return Err(anyhow!(
|
||||||
|
"default_endpoint '{default_name}' is not declared in [[endpoints]]"
|
||||||
|
));
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Look up an endpoint by name. Returns `None` if not configured.
|
||||||
|
pub fn endpoint(&self, name: &str) -> Option<&EndpointConfig> {
|
||||||
|
self.endpoints.iter().find(|e| e.name == name)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The default endpoint (guaranteed to exist after `validate`).
|
||||||
|
pub fn default_endpoint(&self) -> &EndpointConfig {
|
||||||
|
let name = self
|
||||||
|
.default_endpoint
|
||||||
|
.as_deref()
|
||||||
|
.expect("default_endpoint set by validate");
|
||||||
|
self.endpoint(name)
|
||||||
|
.expect("default_endpoint resolves after validate")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse an ACP-side `model` field into (endpoint name, raw model id).
|
||||||
|
///
|
||||||
|
/// `helexa:helexa/large` → (`Some("helexa")`, `"helexa/large"`).
|
||||||
|
/// `helexa/large` → (`None`, `"helexa/large"`).
|
||||||
|
///
|
||||||
|
/// The split happens at the FIRST colon. Model ids commonly contain
|
||||||
|
/// `/` (HuggingFace style) but rarely `:`; if a model id ever does, the
|
||||||
|
/// user can quote-prefix with the default endpoint name.
|
||||||
|
pub fn parse_model_selector(input: &str) -> (Option<&str>, &str) {
|
||||||
|
match input.split_once(':') {
|
||||||
|
Some((endpoint, model)) if !endpoint.is_empty() && !model.is_empty() => {
|
||||||
|
(Some(endpoint), model)
|
||||||
|
}
|
||||||
|
_ => (None, input),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn config_path() -> Option<PathBuf> {
|
||||||
|
if let Ok(override_path) = std::env::var("HELEXA_ACP_CONFIG_PATH") {
|
||||||
|
return Some(PathBuf::from(override_path));
|
||||||
|
}
|
||||||
|
let xdg = std::env::var("XDG_CONFIG_HOME")
|
||||||
|
.ok()
|
||||||
|
.filter(|s| !s.is_empty());
|
||||||
|
let base = xdg.map(PathBuf::from).or_else(|| {
|
||||||
|
std::env::var("HOME")
|
||||||
|
.ok()
|
||||||
|
.map(|h| PathBuf::from(h).join(".config"))
|
||||||
|
})?;
|
||||||
|
Some(base.join("helexa-acp").join("config.toml"))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn join_segments(base: &Url, segments: &[&str]) -> Url {
|
||||||
|
let mut out = base.clone();
|
||||||
|
if let Ok(mut path) = out.path_segments_mut() {
|
||||||
|
path.pop_if_empty().extend(segments.iter().copied());
|
||||||
|
}
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn url_join_handles_trailing_slash() {
|
||||||
|
let ep = EndpointConfig {
|
||||||
|
name: "x".into(),
|
||||||
|
base_url: Url::parse("http://h.internal:31313/v1").unwrap(),
|
||||||
|
wire_api: WireApi::OpenAiChat,
|
||||||
|
default_model: None,
|
||||||
|
api_key: None,
|
||||||
|
api_key_env: None,
|
||||||
|
max_tokens: None,
|
||||||
|
context_window: None,
|
||||||
|
};
|
||||||
|
assert_eq!(
|
||||||
|
ep.chat_completions_url().as_str(),
|
||||||
|
"http://h.internal:31313/v1/chat/completions"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
ep.models_url().as_str(),
|
||||||
|
"http://h.internal:31313/v1/models"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parses_model_selector() {
|
||||||
|
assert_eq!(
|
||||||
|
parse_model_selector("helexa:helexa/large"),
|
||||||
|
(Some("helexa"), "helexa/large")
|
||||||
|
);
|
||||||
|
assert_eq!(parse_model_selector("helexa/large"), (None, "helexa/large"));
|
||||||
|
assert_eq!(parse_model_selector("gpt-5"), (None, "gpt-5"));
|
||||||
|
// Edge case: a leading colon → no endpoint.
|
||||||
|
assert_eq!(parse_model_selector(":gpt-5"), (None, ":gpt-5"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn env_fallback_builds_single_endpoint() {
|
||||||
|
// Don't actually set env vars (would race with other tests);
|
||||||
|
// just confirm the default path constructs cleanly.
|
||||||
|
unsafe {
|
||||||
|
std::env::remove_var("HELEXA_ACP_BASE_URL");
|
||||||
|
std::env::remove_var("HELEXA_ACP_MODEL");
|
||||||
|
std::env::remove_var("HELEXA_ACP_API_KEY");
|
||||||
|
}
|
||||||
|
let cfg = Config::from_env().unwrap();
|
||||||
|
assert_eq!(cfg.endpoints.len(), 1);
|
||||||
|
assert_eq!(cfg.endpoints[0].name, "default");
|
||||||
|
assert_eq!(cfg.endpoints[0].base_url.as_str(), DEFAULT_BASE_URL);
|
||||||
|
assert_eq!(
|
||||||
|
cfg.endpoints[0].default_model.as_deref(),
|
||||||
|
Some(DEFAULT_MODEL)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn toml_parses_multi_endpoint() {
|
||||||
|
let toml_text = r#"
|
||||||
|
default_endpoint = "helexa"
|
||||||
|
|
||||||
|
[[endpoints]]
|
||||||
|
name = "helexa"
|
||||||
|
base_url = "http://hanzalova.internal:31313/v1"
|
||||||
|
default_model = "helexa/large"
|
||||||
|
|
||||||
|
[[endpoints]]
|
||||||
|
name = "openrouter"
|
||||||
|
base_url = "https://openrouter.ai/api/v1"
|
||||||
|
wire_api = "openai-chat"
|
||||||
|
api_key_env = "OPENROUTER_API_KEY"
|
||||||
|
default_model = "anthropic/claude-opus-4"
|
||||||
|
"#;
|
||||||
|
let mut cfg: Config = toml::from_str(toml_text).unwrap();
|
||||||
|
cfg.validate().unwrap();
|
||||||
|
assert_eq!(cfg.endpoints.len(), 2);
|
||||||
|
assert_eq!(cfg.default_endpoint().name, "helexa");
|
||||||
|
assert_eq!(cfg.endpoints[0].wire_api, WireApi::OpenAiChat);
|
||||||
|
assert_eq!(
|
||||||
|
cfg.endpoints[1].api_key_env.as_deref(),
|
||||||
|
Some("OPENROUTER_API_KEY")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn validate_rejects_colon_in_endpoint_name() {
|
||||||
|
let toml_text = r#"
|
||||||
|
[[endpoints]]
|
||||||
|
name = "bad:name"
|
||||||
|
base_url = "http://x/v1"
|
||||||
|
"#;
|
||||||
|
let mut cfg: Config = toml::from_str(toml_text).unwrap();
|
||||||
|
let err = cfg.validate().unwrap_err();
|
||||||
|
assert!(format!("{err}").contains("clash"));
|
||||||
|
}
|
||||||
|
}
|
||||||
145
crates/helexa-acp/src/main.rs
Normal file
145
crates/helexa-acp/src/main.rs
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
//! helexa-acp — Agent Client Protocol bridge for multi-endpoint LLM
|
||||||
|
//! setups (helexa, LM Studio, Ollama, OpenRouter, OpenAI, Anthropic,
|
||||||
|
//! …) with a clean per-endpoint wire-format selector.
|
||||||
|
//!
|
||||||
|
//! Speaks ACP over stdio to an editor client (Zed today). Every
|
||||||
|
//! configured endpoint produces a wire-format-specific
|
||||||
|
//! [`provider::Provider`] implementation; the agent loop in
|
||||||
|
//! [`agent::Agent`] is provider-agnostic, so adding e.g. an Anthropic
|
||||||
|
//! /v1/messages provider doesn't touch `agent.rs`.
|
||||||
|
//!
|
||||||
|
//! Config: `$XDG_CONFIG_HOME/helexa-acp/config.toml` for the multi-
|
||||||
|
//! endpoint case; env vars (`HELEXA_ACP_BASE_URL`, etc.) for the
|
||||||
|
//! single-endpoint case when no config file exists.
|
||||||
|
|
||||||
|
use agent_client_protocol::{Result, Stdio};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
mod agent;
|
||||||
|
mod compaction;
|
||||||
|
mod config;
|
||||||
|
mod path_util;
|
||||||
|
mod prompt;
|
||||||
|
mod provider;
|
||||||
|
mod qwen3;
|
||||||
|
mod session;
|
||||||
|
mod store;
|
||||||
|
mod tool_runner;
|
||||||
|
mod tools;
|
||||||
|
|
||||||
|
use agent::Agent;
|
||||||
|
use config::{Config, EndpointConfig, WireApi};
|
||||||
|
use provider::{
|
||||||
|
Provider, anthropic_messages::AnthropicMessagesProvider, openai_chat::OpenAIChatProvider,
|
||||||
|
openai_responses::OpenAIResponsesProvider,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Set up tracing. Logs go to stderr by default — stdout is
|
||||||
|
/// reserved for the JSON-RPC stream. Setting `HELEXA_ACP_LOG_FILE`
|
||||||
|
/// to an absolute path appends logs to that file instead, which is
|
||||||
|
/// the practical way to capture debug output when the agent runs
|
||||||
|
/// under an editor (Zed, etc.) that doesn't surface stderr.
|
||||||
|
///
|
||||||
|
/// `RUST_LOG` still controls levels (e.g. `helexa_acp=debug`).
|
||||||
|
/// ANSI colours are auto-stripped when writing to a file so the log
|
||||||
|
/// is plain text.
|
||||||
|
fn init_tracing() {
|
||||||
|
let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
|
||||||
|
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info"));
|
||||||
|
|
||||||
|
let log_file = std::env::var("HELEXA_ACP_LOG_FILE")
|
||||||
|
.ok()
|
||||||
|
.filter(|s| !s.is_empty());
|
||||||
|
|
||||||
|
match log_file {
|
||||||
|
Some(path) => match std::fs::OpenOptions::new()
|
||||||
|
.create(true)
|
||||||
|
.append(true)
|
||||||
|
.open(&path)
|
||||||
|
{
|
||||||
|
Ok(file) => {
|
||||||
|
tracing_subscriber::fmt()
|
||||||
|
.with_writer(std::sync::Mutex::new(file))
|
||||||
|
.with_env_filter(env_filter)
|
||||||
|
.with_ansi(false)
|
||||||
|
.init();
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
// Fall back to stderr and shout. We don't want a
|
||||||
|
// typo'd log path to silence the agent entirely.
|
||||||
|
tracing_subscriber::fmt()
|
||||||
|
.with_writer(std::io::stderr)
|
||||||
|
.with_env_filter(env_filter)
|
||||||
|
.init();
|
||||||
|
tracing::warn!(
|
||||||
|
path = %path,
|
||||||
|
error = %e,
|
||||||
|
"HELEXA_ACP_LOG_FILE could not be opened; using stderr"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
None => {
|
||||||
|
tracing_subscriber::fmt()
|
||||||
|
.with_writer(std::io::stderr)
|
||||||
|
.with_env_filter(env_filter)
|
||||||
|
.init();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build a provider for `endpoint` according to its declared
|
||||||
|
/// `wire_api`. Future wire types (OpenAI Responses, Anthropic
|
||||||
|
/// /v1/messages, Ollama native) slot in here without changing the
|
||||||
|
/// caller.
|
||||||
|
fn build_provider(endpoint: EndpointConfig) -> anyhow::Result<Arc<dyn Provider>> {
|
||||||
|
match endpoint.wire_api {
|
||||||
|
WireApi::OpenAiChat => Ok(Arc::new(OpenAIChatProvider::new(endpoint)?)),
|
||||||
|
WireApi::OpenAiResponses => Ok(Arc::new(OpenAIResponsesProvider::new(endpoint)?)),
|
||||||
|
WireApi::AnthropicMessages => Ok(Arc::new(AnthropicMessagesProvider::new(endpoint)?)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<()> {
|
||||||
|
init_tracing();
|
||||||
|
|
||||||
|
let cfg = Config::load()
|
||||||
|
.map_err(|e| agent_client_protocol::util::internal_error(format!("config: {e:#}")))?;
|
||||||
|
tracing::info!(
|
||||||
|
endpoints = cfg.endpoints.len(),
|
||||||
|
default_endpoint = %cfg.default_endpoint().name,
|
||||||
|
default_model = ?cfg.default_endpoint().default_model,
|
||||||
|
"helexa-acp starting"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Build a provider for each configured endpoint up-front. Cheap —
|
||||||
|
// just sets up a reqwest::Client and resolves the API key — and
|
||||||
|
// surfaces config mistakes (missing API key env var, unsupported
|
||||||
|
// wire_api) before the editor even sends an initialize request.
|
||||||
|
let mut providers: Vec<Arc<dyn Provider>> = Vec::with_capacity(cfg.endpoints.len());
|
||||||
|
for endpoint in &cfg.endpoints {
|
||||||
|
match build_provider(endpoint.clone()) {
|
||||||
|
Ok(p) => {
|
||||||
|
tracing::info!(
|
||||||
|
endpoint = %endpoint.name,
|
||||||
|
base_url = %endpoint.base_url,
|
||||||
|
wire_api = ?endpoint.wire_api,
|
||||||
|
"registered provider"
|
||||||
|
);
|
||||||
|
providers.push(p);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(
|
||||||
|
endpoint = %endpoint.name,
|
||||||
|
error = %format!("{e:#}"),
|
||||||
|
"skipping endpoint with invalid config"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let agent = Agent::new(&cfg, providers)
|
||||||
|
.await
|
||||||
|
.map_err(|e| agent_client_protocol::util::internal_error(format!("agent: {e:#}")))?;
|
||||||
|
agent.serve(Stdio::new()).await
|
||||||
|
}
|
||||||
192
crates/helexa-acp/src/path_util.rs
Normal file
192
crates/helexa-acp/src/path_util.rs
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
//! Path expansion shared across every tool that takes a path.
|
||||||
|
//!
|
||||||
|
//! Models often emit shell-style paths like `~/git/repo/file.rs` or
|
||||||
|
//! `$HOME/notes.md`. ACP's `fs/read_text_file` and friends — and our
|
||||||
|
//! own local `std::fs` reads — both want a real absolute path; the
|
||||||
|
//! `~` / `$HOME` forms reach them as literal strings and the open
|
||||||
|
//! fails. The tool schemas already document "absolute path" but in
|
||||||
|
//! practice the model slips up often enough that handling it
|
||||||
|
//! server-side is the difference between "works" and "the agent is
|
||||||
|
//! brittle".
|
||||||
|
//!
|
||||||
|
//! Scope is deliberately small:
|
||||||
|
//!
|
||||||
|
//! - `~` and `~/` (current user only — `~user` lookups would require
|
||||||
|
//! pulling in passwd parsing).
|
||||||
|
//! - `$HOME` and `$HOME/`.
|
||||||
|
//!
|
||||||
|
//! Any other shell variable (`$PWD`, `${HOME}`, …) passes through
|
||||||
|
//! unchanged. The shell already expands them inside `bash` tool
|
||||||
|
//! commands; for the file-tool argument fields, we deliberately
|
||||||
|
//! limit the set so the behaviour is predictable.
|
||||||
|
//!
|
||||||
|
//! Falls back to the input path verbatim when `HOME` is unset
|
||||||
|
//! (stripped-down container env). That preserves the "no surprise
|
||||||
|
//! mutations" rule — never invent a path the caller didn't ask for.
|
||||||
|
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
|
/// Process-global lock for tests that mutate `HOME`. Anyone in the
|
||||||
|
/// crate touching `HOME` must hold this for the duration of the
|
||||||
|
/// read-modify-restore window — otherwise concurrent `cargo test`
|
||||||
|
/// workers race and flake.
|
||||||
|
///
|
||||||
|
/// Only built into the test binaries. Production code never mutates
|
||||||
|
/// env vars.
|
||||||
|
#[cfg(test)]
|
||||||
|
pub(crate) static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
|
||||||
|
|
||||||
|
/// Expand `~`, `~/`, `$HOME`, and `$HOME/` prefixes against the
|
||||||
|
/// current user's home directory. All other inputs pass through
|
||||||
|
/// unchanged.
|
||||||
|
///
|
||||||
|
/// Returns the input verbatim if `HOME` isn't set in the env.
|
||||||
|
pub fn expand_path(input: &Path) -> PathBuf {
|
||||||
|
let Some(s) = input.to_str() else {
|
||||||
|
return input.to_path_buf();
|
||||||
|
};
|
||||||
|
let Ok(home) = std::env::var("HOME") else {
|
||||||
|
return input.to_path_buf();
|
||||||
|
};
|
||||||
|
let home = PathBuf::from(home);
|
||||||
|
if s == "~" || s == "$HOME" {
|
||||||
|
return home;
|
||||||
|
}
|
||||||
|
if let Some(rest) = s.strip_prefix("~/") {
|
||||||
|
return home.join(rest);
|
||||||
|
}
|
||||||
|
if let Some(rest) = s.strip_prefix("$HOME/") {
|
||||||
|
return home.join(rest);
|
||||||
|
}
|
||||||
|
input.to_path_buf()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
/// Set HOME for the duration of the test. Tests using this run
|
||||||
|
/// serially under the crate-wide [`ENV_LOCK`] because env
|
||||||
|
/// mutation isn't thread-safe — `cargo test` parallel workers
|
||||||
|
/// would race without it.
|
||||||
|
fn with_home<F: FnOnce()>(home: &str, body: F) {
|
||||||
|
let _g = ENV_LOCK.lock().unwrap();
|
||||||
|
let prior = std::env::var("HOME").ok();
|
||||||
|
// SAFETY: tests touch process-global env. The mutex
|
||||||
|
// serialises access; sub-threads in other test modules
|
||||||
|
// touching HOME aren't expected (none in this crate).
|
||||||
|
unsafe {
|
||||||
|
std::env::set_var("HOME", home);
|
||||||
|
}
|
||||||
|
body();
|
||||||
|
unsafe {
|
||||||
|
match prior {
|
||||||
|
Some(p) => std::env::set_var("HOME", p),
|
||||||
|
None => std::env::remove_var("HOME"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn expands_tilde_slash() {
|
||||||
|
with_home("/home/me", || {
|
||||||
|
assert_eq!(
|
||||||
|
expand_path(Path::new("~/git/repo/file.rs")),
|
||||||
|
PathBuf::from("/home/me/git/repo/file.rs")
|
||||||
|
);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn expands_bare_tilde() {
|
||||||
|
with_home("/home/me", || {
|
||||||
|
assert_eq!(expand_path(Path::new("~")), PathBuf::from("/home/me"));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn expands_dollar_home_slash() {
|
||||||
|
with_home("/home/me", || {
|
||||||
|
assert_eq!(
|
||||||
|
expand_path(Path::new("$HOME/notes.md")),
|
||||||
|
PathBuf::from("/home/me/notes.md")
|
||||||
|
);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn expands_bare_dollar_home() {
|
||||||
|
with_home("/home/me", || {
|
||||||
|
assert_eq!(expand_path(Path::new("$HOME")), PathBuf::from("/home/me"));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn absolute_path_passes_through() {
|
||||||
|
with_home("/home/me", || {
|
||||||
|
assert_eq!(
|
||||||
|
expand_path(Path::new("/etc/hostname")),
|
||||||
|
PathBuf::from("/etc/hostname")
|
||||||
|
);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn relative_path_passes_through() {
|
||||||
|
with_home("/home/me", || {
|
||||||
|
assert_eq!(
|
||||||
|
expand_path(Path::new("src/main.rs")),
|
||||||
|
PathBuf::from("src/main.rs")
|
||||||
|
);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn tilde_user_form_not_expanded() {
|
||||||
|
// ~other is shell sugar for /home/other and would require
|
||||||
|
// passwd parsing to resolve. Out of scope — pass it
|
||||||
|
// through and let the open fail with a clear error.
|
||||||
|
with_home("/home/me", || {
|
||||||
|
assert_eq!(
|
||||||
|
expand_path(Path::new("~other/x")),
|
||||||
|
PathBuf::from("~other/x")
|
||||||
|
);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn no_home_env_passes_through() {
|
||||||
|
// Share the same crate-wide lock as `with_home` — otherwise
|
||||||
|
// a parallel test setting HOME races this clear-and-assert
|
||||||
|
// window.
|
||||||
|
let _g = ENV_LOCK.lock().unwrap();
|
||||||
|
let prior = std::env::var("HOME").ok();
|
||||||
|
// SAFETY: serialised by LOCK above.
|
||||||
|
unsafe {
|
||||||
|
std::env::remove_var("HOME");
|
||||||
|
}
|
||||||
|
assert_eq!(
|
||||||
|
expand_path(Path::new("~/git/repo")),
|
||||||
|
PathBuf::from("~/git/repo")
|
||||||
|
);
|
||||||
|
unsafe {
|
||||||
|
if let Some(p) = prior {
|
||||||
|
std::env::set_var("HOME", p);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn dollar_other_var_not_expanded() {
|
||||||
|
with_home("/home/me", || {
|
||||||
|
assert_eq!(
|
||||||
|
expand_path(Path::new("$PWD/file")),
|
||||||
|
PathBuf::from("$PWD/file")
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
expand_path(Path::new("${HOME}/file")),
|
||||||
|
PathBuf::from("${HOME}/file")
|
||||||
|
);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
274
crates/helexa-acp/src/prompt.rs
Normal file
274
crates/helexa-acp/src/prompt.rs
Normal file
@@ -0,0 +1,274 @@
|
|||||||
|
//! System prompt assembly.
|
||||||
|
//!
|
||||||
|
//! The system message has two parts:
|
||||||
|
//!
|
||||||
|
//! 1. A short human-readable preamble (working directory, style
|
||||||
|
//! instructions). Either the built-in [`DEFAULT_PROMPT`] or a
|
||||||
|
//! user-supplied file at `HELEXA_ACP_SYSTEM_PROMPT_PATH` /
|
||||||
|
//! `system_prompt_path`. `{cwd}` is substituted in both.
|
||||||
|
//! 2. A `# Tools` block in Qwen3 Hermes format (see [`crate::qwen3`])
|
||||||
|
//! describing the available functions. This is what makes the
|
||||||
|
//! model actually call them — neuron/cortex don't honour the
|
||||||
|
//! OpenAI `tools` API field, so the tool list has to live in the
|
||||||
|
//! prompt itself.
|
||||||
|
|
||||||
|
use agent_client_protocol::schema::SessionModeId;
|
||||||
|
use anyhow::Context;
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
use crate::provider::ToolSpec;
|
||||||
|
use crate::qwen3;
|
||||||
|
use crate::session::MODE_PLAN;
|
||||||
|
|
||||||
|
const DEFAULT_PROMPT: &str = "\
|
||||||
|
You are helexa-acp, a coding assistant working inside an editor.
|
||||||
|
|
||||||
|
Working directory: {cwd}
|
||||||
|
|
||||||
|
Use the tools described below whenever the user's request involves
|
||||||
|
looking at or modifying files, or running commands. Do not ask the
|
||||||
|
user to paste file contents you could read yourself. All file paths
|
||||||
|
must be absolute. Writes and shell commands may prompt the user for
|
||||||
|
permission depending on the session mode.
|
||||||
|
|
||||||
|
Be concise; the user is reading your output in an editor pane.";
|
||||||
|
|
||||||
|
/// Build the system prompt for a session.
|
||||||
|
///
|
||||||
|
/// - `cwd`: session working directory (substituted for `{cwd}` in
|
||||||
|
/// the preamble — both the default and any user-supplied template).
|
||||||
|
/// - `override_path`: path to a user-supplied template, already
|
||||||
|
/// resolved by [`crate::config::Config`]. The `# Tools` block is
|
||||||
|
/// appended *after* the user's template so a custom preamble
|
||||||
|
/// still gets the tool descriptions the model needs.
|
||||||
|
/// - `tools`: the tools to advertise. Empty list → no `# Tools`
|
||||||
|
/// block is appended at all.
|
||||||
|
/// - `mode`: current session mode. When the mode is [`MODE_PLAN`]
|
||||||
|
/// a plan-mode addendum describing the restrictions and the
|
||||||
|
/// completion menu is appended *after* the `# Tools` block so it
|
||||||
|
/// is the last thing the model reads before user input.
|
||||||
|
/// - `plan_dir`: resolved plan directory for the cwd. Only consulted
|
||||||
|
/// when `mode == MODE_PLAN`. `None` means the plan directory could
|
||||||
|
/// not be resolved (no `HOME` / `XDG_DATA_HOME`) — the addendum
|
||||||
|
/// still renders but with a placeholder so the model knows to
|
||||||
|
/// surface the error to the user rather than guess a path.
|
||||||
|
pub fn build_system_prompt(
|
||||||
|
cwd: &Path,
|
||||||
|
override_path: Option<&Path>,
|
||||||
|
tools: &[ToolSpec],
|
||||||
|
mode: &SessionModeId,
|
||||||
|
plan_dir: Option<&Path>,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
let template = match override_path {
|
||||||
|
Some(path) => std::fs::read_to_string(path)
|
||||||
|
.with_context(|| format!("read system prompt from {}", path.display()))?,
|
||||||
|
None => DEFAULT_PROMPT.to_string(),
|
||||||
|
};
|
||||||
|
let mut prompt = template.replace("{cwd}", &cwd.display().to_string());
|
||||||
|
prompt.push_str(&qwen3::render_tool_block(tools));
|
||||||
|
if mode.0.as_ref() == MODE_PLAN {
|
||||||
|
prompt.push_str(&render_plan_mode_block(plan_dir));
|
||||||
|
}
|
||||||
|
Ok(prompt)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Plan-mode instruction block. Tells the model:
|
||||||
|
///
|
||||||
|
/// 1. Where it may write — only inside `plan_dir`.
|
||||||
|
/// 2. What it may *not* do — bash is disabled; writes outside
|
||||||
|
/// `plan_dir` are refused by the runtime.
|
||||||
|
/// 3. How to finish — emit the 3-option menu so the user can
|
||||||
|
/// switch modes and either kick off implementation (with or
|
||||||
|
/// without permission prompts) or keep iterating on the plan.
|
||||||
|
fn render_plan_mode_block(plan_dir: Option<&Path>) -> String {
|
||||||
|
let plan_path = plan_dir
|
||||||
|
.map(|p| p.display().to_string())
|
||||||
|
.unwrap_or_else(|| "<plan directory could not be resolved — tell the user>".to_string());
|
||||||
|
format!(
|
||||||
|
"\n\n# Plan mode\n\
|
||||||
|
\n\
|
||||||
|
You are in **plan mode**. Your task is to draft a written\n\
|
||||||
|
implementation plan for the user; you must NOT modify any\n\
|
||||||
|
project files or run shell commands.\n\
|
||||||
|
\n\
|
||||||
|
Rules in plan mode:\n\
|
||||||
|
\n\
|
||||||
|
- `read_file` and `list_dir` are unrestricted — use them to\n\
|
||||||
|
explore the codebase as needed.\n\
|
||||||
|
- `write_file` and `edit_file` are allowed ONLY under the\n\
|
||||||
|
plan directory: `{plan_path}`. The runtime will refuse any\n\
|
||||||
|
write outside it.\n\
|
||||||
|
- `bash` is disabled. Do not call it.\n\
|
||||||
|
\n\
|
||||||
|
Write the plan as one or more Markdown files under\n\
|
||||||
|
`{plan_path}`. Use descriptive filenames\n\
|
||||||
|
(`01-overview.md`, `02-data-model.md`, etc.). It is fine to\n\
|
||||||
|
iterate — overwrite the file when you refine a section.\n\
|
||||||
|
\n\
|
||||||
|
When the plan is complete, do NOT begin implementation.\n\
|
||||||
|
Instead, end your turn with this menu, verbatim, so the\n\
|
||||||
|
user can choose how to proceed:\n\
|
||||||
|
\n\
|
||||||
|
---\n\
|
||||||
|
**Plan complete.** To proceed, switch the session mode in\n\
|
||||||
|
the agent dropdown and send a follow-up message:\n\
|
||||||
|
\n\
|
||||||
|
1. **Bypass Permissions** — implement the plan now, skipping\n\
|
||||||
|
per-tool permission prompts.\n\
|
||||||
|
2. **Default** — implement the plan now, prompting before\n\
|
||||||
|
each write or shell command.\n\
|
||||||
|
3. **Plan** (stay here) — refine the plan; reply with the\n\
|
||||||
|
change you want and I will revise it.\n\
|
||||||
|
---\n"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::session::{MODE_DEFAULT, MODE_PLAN};
|
||||||
|
use std::io::Write;
|
||||||
|
|
||||||
|
fn default_mode() -> SessionModeId {
|
||||||
|
SessionModeId::new(MODE_DEFAULT)
|
||||||
|
}
|
||||||
|
fn plan_mode() -> SessionModeId {
|
||||||
|
SessionModeId::new(MODE_PLAN)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn default_prompt_substitutes_cwd() {
|
||||||
|
let prompt =
|
||||||
|
build_system_prompt(Path::new("/home/me/proj"), None, &[], &default_mode(), None)
|
||||||
|
.unwrap();
|
||||||
|
assert!(
|
||||||
|
prompt.contains("/home/me/proj"),
|
||||||
|
"cwd not interpolated: {prompt}"
|
||||||
|
);
|
||||||
|
assert!(prompt.contains("helexa-acp"));
|
||||||
|
assert!(
|
||||||
|
!prompt.contains("{cwd}"),
|
||||||
|
"left-over placeholder in default prompt"
|
||||||
|
);
|
||||||
|
// With no tools, the # Tools block is absent.
|
||||||
|
assert!(!prompt.contains("# Tools"));
|
||||||
|
// Default mode does not get the plan-mode addendum.
|
||||||
|
assert!(!prompt.contains("# Plan mode"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn tools_are_appended_in_hermes_format() {
|
||||||
|
let spec = ToolSpec {
|
||||||
|
name: "read_file".into(),
|
||||||
|
description: "Read a file.".into(),
|
||||||
|
parameters: serde_json::json!({"type":"object","properties":{}, "required":[]}),
|
||||||
|
};
|
||||||
|
let prompt =
|
||||||
|
build_system_prompt(Path::new("/x"), None, &[spec], &default_mode(), None).unwrap();
|
||||||
|
assert!(prompt.contains("# Tools"));
|
||||||
|
assert!(prompt.contains("<tools>"));
|
||||||
|
assert!(prompt.contains("\"name\":\"read_file\""));
|
||||||
|
assert!(prompt.contains("<tool_call>"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn override_path_is_read_and_templated() {
|
||||||
|
let mut tmp = tempfile_in_target("prompt.txt");
|
||||||
|
tmp.write_all(b"custom prompt for {cwd} only").unwrap();
|
||||||
|
tmp.flush().unwrap();
|
||||||
|
|
||||||
|
let path = tmp.path().to_path_buf();
|
||||||
|
drop(tmp);
|
||||||
|
|
||||||
|
let prompt = build_system_prompt(
|
||||||
|
Path::new("/etc"),
|
||||||
|
Some(path.as_path()),
|
||||||
|
&[],
|
||||||
|
&default_mode(),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.expect("read override");
|
||||||
|
assert_eq!(prompt, "custom prompt for /etc only");
|
||||||
|
|
||||||
|
let _ = std::fs::remove_file(&path);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn missing_override_path_errors() {
|
||||||
|
let err = build_system_prompt(
|
||||||
|
Path::new("/tmp"),
|
||||||
|
Some(Path::new("/definitely/not/a/real/path")),
|
||||||
|
&[],
|
||||||
|
&default_mode(),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.unwrap_err();
|
||||||
|
assert!(format!("{err:#}").contains("read system prompt"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn plan_mode_addendum_includes_plan_dir_and_menu() {
|
||||||
|
let plan_dir = Path::new("/home/me/.local/share/helexa-acp/plans/proj-deadbeef");
|
||||||
|
let prompt = build_system_prompt(
|
||||||
|
Path::new("/home/me/proj"),
|
||||||
|
None,
|
||||||
|
&[],
|
||||||
|
&plan_mode(),
|
||||||
|
Some(plan_dir),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
assert!(prompt.contains("# Plan mode"));
|
||||||
|
assert!(
|
||||||
|
prompt.contains(plan_dir.to_str().unwrap()),
|
||||||
|
"plan dir not interpolated: {prompt}"
|
||||||
|
);
|
||||||
|
// The 3-option menu must be present so the model emits it verbatim.
|
||||||
|
assert!(prompt.contains("Bypass Permissions"));
|
||||||
|
assert!(prompt.contains("**Default**"));
|
||||||
|
assert!(prompt.contains("3. **Plan**"));
|
||||||
|
// Bash disabled instruction must be present.
|
||||||
|
assert!(prompt.contains("`bash` is disabled"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn plan_mode_addendum_handles_unresolved_plan_dir() {
|
||||||
|
let prompt =
|
||||||
|
build_system_prompt(Path::new("/home/me/proj"), None, &[], &plan_mode(), None).unwrap();
|
||||||
|
assert!(prompt.contains("# Plan mode"));
|
||||||
|
assert!(prompt.contains("could not be resolved"));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Tiny temp-file helper that doesn't pull in the `tempfile` crate.
|
||||||
|
/// Writes under `target/` so it's cleaned up by `cargo clean`.
|
||||||
|
fn tempfile_in_target(name: &str) -> TempHandle {
|
||||||
|
let base = std::env::var("CARGO_TARGET_TMPDIR")
|
||||||
|
.ok()
|
||||||
|
.map(std::path::PathBuf::from)
|
||||||
|
.unwrap_or_else(std::env::temp_dir);
|
||||||
|
let _ = std::fs::create_dir_all(&base);
|
||||||
|
let pid = std::process::id();
|
||||||
|
let path = base.join(format!("helexa-acp-{pid}-{name}"));
|
||||||
|
let file = std::fs::File::create(&path).expect("create temp file");
|
||||||
|
TempHandle { file, path }
|
||||||
|
}
|
||||||
|
|
||||||
|
struct TempHandle {
|
||||||
|
file: std::fs::File,
|
||||||
|
path: std::path::PathBuf,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TempHandle {
|
||||||
|
fn path(&self) -> &Path {
|
||||||
|
&self.path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Write for TempHandle {
|
||||||
|
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
|
||||||
|
self.file.write(buf)
|
||||||
|
}
|
||||||
|
fn flush(&mut self) -> std::io::Result<()> {
|
||||||
|
self.file.flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
1200
crates/helexa-acp/src/provider/anthropic_messages.rs
Normal file
1200
crates/helexa-acp/src/provider/anthropic_messages.rs
Normal file
File diff suppressed because it is too large
Load Diff
230
crates/helexa-acp/src/provider/mod.rs
Normal file
230
crates/helexa-acp/src/provider/mod.rs
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
//! Provider trait — the seam between the ACP-side agent loop and
|
||||||
|
//! whatever wire protocol an endpoint actually speaks.
|
||||||
|
//!
|
||||||
|
//! Every concrete provider (OpenAI chat completions, OpenAI Responses,
|
||||||
|
//! Anthropic /v1/messages, Ollama native, …) implements
|
||||||
|
//! [`Provider`]. The agent constructs a [`CompletionRequest`] using
|
||||||
|
//! provider-agnostic types and consumes a stream of
|
||||||
|
//! [`CompletionEvent`]s — neither end knows which wire format is on
|
||||||
|
//! the other side of the trait.
|
||||||
|
//!
|
||||||
|
//! Day-1 provider: [`openai_chat::OpenAIChatProvider`]. Day-N
|
||||||
|
//! providers slot in without touching `agent.rs`.
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use futures::stream::BoxStream;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::Value;
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
|
|
||||||
|
pub mod anthropic_messages;
|
||||||
|
pub mod openai_chat;
|
||||||
|
pub mod openai_responses;
|
||||||
|
|
||||||
|
/// Provider-agnostic LLM endpoint. Implementations translate between
|
||||||
|
/// [`CompletionRequest`] / [`CompletionEvent`] and whatever wire
|
||||||
|
/// format their endpoint speaks.
|
||||||
|
#[async_trait]
|
||||||
|
pub trait Provider: Send + Sync {
|
||||||
|
/// Endpoint name as configured by the user (e.g. `"helexa"`,
|
||||||
|
/// `"openrouter"`). Used in logs and in the `endpoint:model`
|
||||||
|
/// selector.
|
||||||
|
fn name(&self) -> &str;
|
||||||
|
|
||||||
|
/// List models available at this endpoint. Used to build the
|
||||||
|
/// model-picker dropdown in editor clients (Stage 4). Should
|
||||||
|
/// return quickly (cache if necessary).
|
||||||
|
#[allow(dead_code)]
|
||||||
|
async fn list_models(&self) -> anyhow::Result<Vec<ModelInfo>>;
|
||||||
|
|
||||||
|
/// Run a chat completion. Returns a stream of provider-agnostic
|
||||||
|
/// events. The stream stops when the upstream finishes, when
|
||||||
|
/// `cancel` is fired, or when the stream is dropped.
|
||||||
|
async fn complete(
|
||||||
|
&self,
|
||||||
|
request: CompletionRequest,
|
||||||
|
cancel: CancellationToken,
|
||||||
|
) -> anyhow::Result<BoxStream<'static, anyhow::Result<CompletionEvent>>>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// One model exposed by a provider. Constructed by `list_models` —
|
||||||
|
/// Stage 4 is when the agent loop starts consuming it for the
|
||||||
|
/// model-picker dropdown.
|
||||||
|
#[allow(dead_code)]
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ModelInfo {
|
||||||
|
pub id: String,
|
||||||
|
/// Human-friendly name, if the endpoint exposes one. Otherwise
|
||||||
|
/// `id` is used as the display name.
|
||||||
|
#[serde(default)]
|
||||||
|
pub display_name: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Inputs to a completion. Provider-agnostic — concrete providers
|
||||||
|
/// translate this into their wire format.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct CompletionRequest {
|
||||||
|
/// Endpoint-local model id (without the `endpoint:` prefix).
|
||||||
|
pub model: String,
|
||||||
|
pub messages: Vec<Message>,
|
||||||
|
/// Tools the model is allowed to call. Empty list means no tool
|
||||||
|
/// support advertised.
|
||||||
|
pub tools: Vec<ToolSpec>,
|
||||||
|
pub temperature: Option<f64>,
|
||||||
|
pub top_p: Option<f64>,
|
||||||
|
pub max_tokens: Option<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct Message {
|
||||||
|
pub role: Role,
|
||||||
|
pub content: MessageContent,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum Role {
|
||||||
|
System,
|
||||||
|
User,
|
||||||
|
Assistant,
|
||||||
|
/// Tool result message. Provider impls turn this into whatever
|
||||||
|
/// shape the upstream wire format wants (OpenAI uses
|
||||||
|
/// `role: "tool"` + `tool_call_id`; Anthropic uses content blocks).
|
||||||
|
/// Stage 3 (tools) constructs this; Stage 2 never does.
|
||||||
|
Tool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
|
pub enum MessageContent {
|
||||||
|
/// Plain text turn (system / user / assistant). Struct variant
|
||||||
|
/// rather than newtype so the persisted JSON has an explicit
|
||||||
|
/// `text` field — that lets us use internal tagging on the
|
||||||
|
/// enum, which is incompatible with newtype-of-primitive
|
||||||
|
/// variants.
|
||||||
|
Text { text: String },
|
||||||
|
/// Mixed text + image user turn. Stage 5 introduces this when
|
||||||
|
/// Zed sends an `ImageContent` block alongside the user's prompt.
|
||||||
|
/// Providers that don't support vision should down-convert by
|
||||||
|
/// dropping image parts and concatenating text parts.
|
||||||
|
MultiPart { parts: Vec<MessagePart> },
|
||||||
|
/// Assistant turn that called one or more tools. Stage 3 starts
|
||||||
|
/// constructing this when the provider stream yields a
|
||||||
|
/// `ToolCallStart` / `ToolCallArgsDelta` sequence.
|
||||||
|
ToolCalls {
|
||||||
|
/// Optional text the assistant said alongside the tool calls.
|
||||||
|
text: Option<String>,
|
||||||
|
calls: Vec<ToolCall>,
|
||||||
|
},
|
||||||
|
/// Tool result. `tool_call_id` matches the assistant's call id.
|
||||||
|
/// Stage 3 constructs this after the tool runner finishes.
|
||||||
|
ToolResult {
|
||||||
|
tool_call_id: String,
|
||||||
|
content: String,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
/// One part of a [`MessageContent::MultiPart`] message.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
|
pub enum MessagePart {
|
||||||
|
Text { text: String },
|
||||||
|
Image(ImageData),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Inline image attachment. `data` is base64-encoded raw image
|
||||||
|
/// bytes; the encoder constructs an `image_url` data URI from it
|
||||||
|
/// at request time. `uri` carries any pointer the client supplied
|
||||||
|
/// (e.g. `file:///tmp/x.png`) — we keep it on the message for
|
||||||
|
/// debugging / future providers but the OpenAI encoder ignores it
|
||||||
|
/// when `data` is present (data wins, since it round-trips through
|
||||||
|
/// every wire format).
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ImageData {
|
||||||
|
pub mime_type: String,
|
||||||
|
/// Base64-encoded image bytes (no `data:` prefix, no padding
|
||||||
|
/// stripped — exactly what `ImageContent.data` carried).
|
||||||
|
pub data: String,
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub uri: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ToolCall {
|
||||||
|
/// Provider-assigned id that ties the call to its result. The
|
||||||
|
/// Qwen3 wire format we use today doesn't carry this on the
|
||||||
|
/// model side (calls and results are matched positionally inside
|
||||||
|
/// a turn), so the field looks unused in the prod build — but it
|
||||||
|
/// flows through to `MessageContent::ToolResult.tool_call_id` for
|
||||||
|
/// history bookkeeping and a future strict-OpenAI backend will
|
||||||
|
/// consume it directly.
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub id: String,
|
||||||
|
pub name: String,
|
||||||
|
/// JSON-encoded arguments. Kept as a string because providers
|
||||||
|
/// stream argument bytes incrementally and only validate at the
|
||||||
|
/// end; the agent decodes once the call is complete.
|
||||||
|
pub arguments: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ToolSpec {
|
||||||
|
pub name: String,
|
||||||
|
pub description: String,
|
||||||
|
/// JSON Schema of the arguments object.
|
||||||
|
pub parameters: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Events emitted by a provider during a streaming completion.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum CompletionEvent {
|
||||||
|
/// Incremental visible text from the assistant.
|
||||||
|
TextDelta(String),
|
||||||
|
/// Incremental "reasoning" / thought text, if the model emits one
|
||||||
|
/// (e.g. Qwen3 with `<think>` tags surfaced as a separate stream,
|
||||||
|
/// or OpenAI reasoning models).
|
||||||
|
ReasoningDelta(String),
|
||||||
|
/// A new tool call has started. Stage 2 ignores the payload; the
|
||||||
|
/// agent loop in Stage 3 reads `index` to correlate with
|
||||||
|
/// [`Self::ToolCallArgsDelta`], `id` for the eventual tool-result
|
||||||
|
/// turn, and `name` to dispatch the runner.
|
||||||
|
#[allow(dead_code)]
|
||||||
|
ToolCallStart {
|
||||||
|
index: usize,
|
||||||
|
id: String,
|
||||||
|
name: String,
|
||||||
|
},
|
||||||
|
/// More argument bytes for a tool call already announced via
|
||||||
|
/// [`Self::ToolCallStart`]. Stage 2 ignores; Stage 3 accumulates
|
||||||
|
/// the bytes by `index` until the call's arguments are complete.
|
||||||
|
#[allow(dead_code)]
|
||||||
|
ToolCallArgsDelta { index: usize, args_delta: String },
|
||||||
|
/// A `<tool_call>` block whose JSON couldn't be parsed even with
|
||||||
|
/// the qwen3 module's repair attempts. The agent surfaces this
|
||||||
|
/// as a Failed `SessionUpdate::ToolCall` card with the raw body
|
||||||
|
/// visible (so the editor renders structured failure UI rather
|
||||||
|
/// than dumping the body inline in the message pane), and feeds
|
||||||
|
/// a synthetic tool-error message back into history so the
|
||||||
|
/// model can self-correct on the next round.
|
||||||
|
MalformedToolCall { raw: String },
|
||||||
|
/// Stream finished. Carries the upstream `finish_reason` if it
|
||||||
|
/// gave one (`"stop"`, `"length"`, `"tool_calls"`, …).
|
||||||
|
Finish { reason: Option<String> },
|
||||||
|
/// Final usage stats, if the provider supplied them. Stage 2
|
||||||
|
/// matches the variant to drop it; Stage 6b (token metrics) is
|
||||||
|
/// when the payload starts being read.
|
||||||
|
#[allow(dead_code)]
|
||||||
|
Usage(UsageStats),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Token accounting reported by the provider at the end of a stream.
|
||||||
|
/// Stage 2 doesn't surface usage anywhere — the stable `PromptResponse`
|
||||||
|
/// has no usage field, and the unstable variant is gated. Stage 6b
|
||||||
|
/// turns these on with Prometheus metrics.
|
||||||
|
#[allow(dead_code)]
|
||||||
|
#[derive(Debug, Clone, Copy, Default)]
|
||||||
|
pub struct UsageStats {
|
||||||
|
pub prompt_tokens: u64,
|
||||||
|
pub completion_tokens: u64,
|
||||||
|
pub total_tokens: u64,
|
||||||
|
}
|
||||||
1002
crates/helexa-acp/src/provider/openai_chat.rs
Normal file
1002
crates/helexa-acp/src/provider/openai_chat.rs
Normal file
File diff suppressed because it is too large
Load Diff
987
crates/helexa-acp/src/provider/openai_responses.rs
Normal file
987
crates/helexa-acp/src/provider/openai_responses.rs
Normal file
@@ -0,0 +1,987 @@
|
|||||||
|
//! OpenAI Responses API (`POST /v1/responses`) provider.
|
||||||
|
//!
|
||||||
|
//! Mirror image of [`super::openai_chat`]: same `Provider` trait
|
||||||
|
//! impl, same back-pressured SSE decoder, but speaking OpenAI's
|
||||||
|
//! newer Responses surface instead of chat completions.
|
||||||
|
//!
|
||||||
|
//! Differences from the chat provider, all contained in this file:
|
||||||
|
//!
|
||||||
|
//! - **Request encoding**: history flattens into an `input` array
|
||||||
|
//! of typed items (`message`, `function_call`, `function_call_output`)
|
||||||
|
//! plus a top-level `instructions` field for the system prompt.
|
||||||
|
//! Multi-part user content stays in the same `[{type:"input_text"},
|
||||||
|
//! {type:"input_image"}]` shape neuron's `request_to_chat` already
|
||||||
|
//! accepts.
|
||||||
|
//! - **Streaming decoder**: events are named (`response.created`,
|
||||||
|
//! `response.output_text.delta`, `response.completed`, …) carried
|
||||||
|
//! on the SSE `event:` line. The chat path's `[DONE]` terminator
|
||||||
|
//! doesn't apply; the stream ends after `response.completed`.
|
||||||
|
//! - **Tool calls** plumb through the `response.output_item.added`
|
||||||
|
//! (item type `function_call`) → `response.function_call_arguments.delta`
|
||||||
|
//! → `response.function_call_arguments.done` event sequence. The
|
||||||
|
//! neuron candle harness doesn't synthesize these yet (tracked as
|
||||||
|
//! issue #6), but the decoder is wired so the day the upstream
|
||||||
|
//! does, downstream `CompletionEvent::ToolCall*` plumbing just
|
||||||
|
//! works.
|
||||||
|
//!
|
||||||
|
//! Tool-name handling: the model knows its tool descriptions via
|
||||||
|
//! the [`crate::qwen3`] system-prompt block exactly the way the chat
|
||||||
|
//! provider does. We don't echo them in the request body because
|
||||||
|
//! neuron currently ignores `tools` on /v1/responses (same as on
|
||||||
|
//! /v1/chat/completions). Once neuron honours request-side tool
|
||||||
|
//! definitions, both providers add them in the same place.
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use eventsource_stream::Eventsource;
|
||||||
|
use futures::{Stream, StreamExt, stream::BoxStream};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::{Value, json};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
CompletionEvent, CompletionRequest, Message, MessageContent, MessagePart, ModelInfo, Provider,
|
||||||
|
Role, UsageStats,
|
||||||
|
};
|
||||||
|
use crate::config::EndpointConfig;
|
||||||
|
|
||||||
|
pub struct OpenAIResponsesProvider {
|
||||||
|
endpoint: EndpointConfig,
|
||||||
|
#[allow(dead_code)] // Read in `complete()`'s HTTP path; tests don't stand up a server.
|
||||||
|
api_key: Option<String>,
|
||||||
|
#[allow(dead_code)]
|
||||||
|
http: reqwest::Client,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OpenAIResponsesProvider {
|
||||||
|
pub fn new(endpoint: EndpointConfig) -> anyhow::Result<Self> {
|
||||||
|
let api_key = endpoint.resolve_api_key()?;
|
||||||
|
let http = reqwest::Client::builder()
|
||||||
|
// Same generous timeout as the chat provider: cortex may
|
||||||
|
// need to cold-load a model before serving the first
|
||||||
|
// chunk, which can be tens of seconds. Cancellation
|
||||||
|
// handles early termination, not timeout.
|
||||||
|
.timeout(std::time::Duration::from_secs(600))
|
||||||
|
.build()?;
|
||||||
|
Ok(Self {
|
||||||
|
endpoint,
|
||||||
|
api_key,
|
||||||
|
http,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Provider for OpenAIResponsesProvider {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
&self.endpoint.name
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn list_models(&self) -> anyhow::Result<Vec<ModelInfo>> {
|
||||||
|
let mut req = self.http.get(self.endpoint.models_url());
|
||||||
|
if let Some(key) = &self.api_key {
|
||||||
|
req = req.bearer_auth(key);
|
||||||
|
}
|
||||||
|
let resp = req
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow::anyhow!("{} list_models: {e}", self.endpoint.name))?;
|
||||||
|
let status = resp.status();
|
||||||
|
if !status.is_success() {
|
||||||
|
let body = resp.text().await.unwrap_or_default();
|
||||||
|
anyhow::bail!(
|
||||||
|
"{} list_models returned {}: {}",
|
||||||
|
self.endpoint.name,
|
||||||
|
status,
|
||||||
|
body
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let body: WireModelsResponse = resp.json().await?;
|
||||||
|
Ok(body
|
||||||
|
.data
|
||||||
|
.into_iter()
|
||||||
|
.map(|m| ModelInfo {
|
||||||
|
id: m.id,
|
||||||
|
display_name: None,
|
||||||
|
})
|
||||||
|
.collect())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn complete(
|
||||||
|
&self,
|
||||||
|
request: CompletionRequest,
|
||||||
|
cancel: CancellationToken,
|
||||||
|
) -> anyhow::Result<BoxStream<'static, anyhow::Result<CompletionEvent>>> {
|
||||||
|
let body = encode_request(&request);
|
||||||
|
tracing::debug!(
|
||||||
|
endpoint = %self.endpoint.name,
|
||||||
|
url = %self.endpoint.responses_url(),
|
||||||
|
body = %serde_json::to_string(&body).unwrap_or_else(|_| "<unserializable>".into()),
|
||||||
|
"POST /responses"
|
||||||
|
);
|
||||||
|
let mut req = self.http.post(self.endpoint.responses_url()).json(&body);
|
||||||
|
if let Some(key) = &self.api_key {
|
||||||
|
req = req.bearer_auth(key);
|
||||||
|
}
|
||||||
|
let resp = req
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow::anyhow!("{} responses send: {e}", self.endpoint.name))?;
|
||||||
|
let status = resp.status();
|
||||||
|
if !status.is_success() {
|
||||||
|
let body = resp.text().await.unwrap_or_default();
|
||||||
|
anyhow::bail!(
|
||||||
|
"{} responses returned {}: {}",
|
||||||
|
self.endpoint.name,
|
||||||
|
status,
|
||||||
|
body
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let sse = resp.bytes_stream().eventsource();
|
||||||
|
let stream = decode_stream(sse, cancel);
|
||||||
|
Ok(Box::pin(stream))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Request encoding ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
fn encode_request(req: &CompletionRequest) -> Value {
|
||||||
|
// Pull the system messages out of history into a single
|
||||||
|
// `instructions` string — the Responses API expects them there,
|
||||||
|
// not inline as an `input` item. Multiple system messages
|
||||||
|
// concatenate with blank lines so we don't lose ordering.
|
||||||
|
let mut instructions: Vec<String> = Vec::new();
|
||||||
|
let mut input_items: Vec<Value> = Vec::new();
|
||||||
|
for msg in &req.messages {
|
||||||
|
if msg.role == Role::System
|
||||||
|
&& let MessageContent::Text { text } = &msg.content
|
||||||
|
{
|
||||||
|
instructions.push(text.clone());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if let Some(item) = encode_message_as_input_item(msg) {
|
||||||
|
input_items.push(item);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut body = json!({
|
||||||
|
"model": req.model,
|
||||||
|
"input": input_items,
|
||||||
|
"stream": true,
|
||||||
|
});
|
||||||
|
if let Value::Object(map) = &mut body {
|
||||||
|
if !instructions.is_empty() {
|
||||||
|
map.insert(
|
||||||
|
"instructions".into(),
|
||||||
|
Value::String(instructions.join("\n\n")),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if let Some(t) = req.temperature {
|
||||||
|
map.insert("temperature".into(), json!(t));
|
||||||
|
}
|
||||||
|
if let Some(p) = req.top_p {
|
||||||
|
map.insert("top_p".into(), json!(p));
|
||||||
|
}
|
||||||
|
if let Some(m) = req.max_tokens {
|
||||||
|
// Responses calls it `max_output_tokens`; preserve the
|
||||||
|
// semantic (response cap) when we translate.
|
||||||
|
map.insert("max_output_tokens".into(), json!(m));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
body
|
||||||
|
}
|
||||||
|
|
||||||
|
fn encode_message_as_input_item(msg: &Message) -> Option<Value> {
|
||||||
|
match (msg.role, &msg.content) {
|
||||||
|
(Role::System, _) => None, // handled out-of-band as `instructions`
|
||||||
|
(Role::User, MessageContent::Text { text }) => Some(json!({
|
||||||
|
"type": "message",
|
||||||
|
"role": "user",
|
||||||
|
"content": text,
|
||||||
|
})),
|
||||||
|
(Role::User, MessageContent::MultiPart { parts }) => Some(json!({
|
||||||
|
"type": "message",
|
||||||
|
"role": "user",
|
||||||
|
"content": encode_user_parts(parts),
|
||||||
|
})),
|
||||||
|
(Role::Assistant, MessageContent::Text { text }) => Some(json!({
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [{
|
||||||
|
"type": "output_text",
|
||||||
|
"text": text,
|
||||||
|
"annotations": [],
|
||||||
|
}],
|
||||||
|
})),
|
||||||
|
(Role::Assistant, MessageContent::ToolCalls { text, calls }) => {
|
||||||
|
// Assistant turns that called tools become a sequence of
|
||||||
|
// items: an optional `message` (any prose alongside the
|
||||||
|
// call) followed by one `function_call` per call. Mirrors
|
||||||
|
// OpenAI Responses' "each item is one structural slot"
|
||||||
|
// shape.
|
||||||
|
//
|
||||||
|
// We can't return multiple items from one call site, so
|
||||||
|
// we encode this by side-stuffing additional items into a
|
||||||
|
// single composite value and have the caller flatten —
|
||||||
|
// but that complicates the API. Easier: build the array
|
||||||
|
// ourselves in the caller path. For now, emit just the
|
||||||
|
// function_calls (the assistant's prose lives in the next
|
||||||
|
// turn's chat history anyway because the model isn't
|
||||||
|
// looking back at its own previous narration). If the
|
||||||
|
// text is non-empty AND we have calls, we lose the text;
|
||||||
|
// qwen3 rarely emits prose alongside tool calls so this
|
||||||
|
// is a deliberate simplification — revisit if it bites.
|
||||||
|
let _ = text;
|
||||||
|
// Take the first call only for the moment; multi-call
|
||||||
|
// turns would need the caller-flattening above.
|
||||||
|
let call = calls.first()?;
|
||||||
|
Some(json!({
|
||||||
|
"type": "function_call",
|
||||||
|
"call_id": call.id,
|
||||||
|
"name": call.name,
|
||||||
|
"arguments": call.arguments,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
(
|
||||||
|
Role::Tool,
|
||||||
|
MessageContent::ToolResult {
|
||||||
|
tool_call_id,
|
||||||
|
content,
|
||||||
|
},
|
||||||
|
) => Some(json!({
|
||||||
|
"type": "function_call_output",
|
||||||
|
"call_id": tool_call_id,
|
||||||
|
"output": content,
|
||||||
|
})),
|
||||||
|
(role, content) => {
|
||||||
|
tracing::warn!(
|
||||||
|
?role,
|
||||||
|
?content,
|
||||||
|
"openai_responses: unexpected (role, content) shape"
|
||||||
|
);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn encode_user_parts(parts: &[MessagePart]) -> Value {
|
||||||
|
let items: Vec<Value> = parts
|
||||||
|
.iter()
|
||||||
|
.map(|p| match p {
|
||||||
|
MessagePart::Text { text } => json!({"type": "input_text", "text": text}),
|
||||||
|
MessagePart::Image(img) => json!({
|
||||||
|
"type": "input_image",
|
||||||
|
"image_url": format!("data:{};base64,{}", img.mime_type, img.data),
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
Value::Array(items)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Wire types ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[allow(dead_code)] // fields read only when list_models runs against a real endpoint
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct WireModelsResponse {
|
||||||
|
data: Vec<WireModelObject>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct WireModelObject {
|
||||||
|
id: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSE event payload shapes. We only model the fields we care about;
|
||||||
|
// `#[serde(default)]` + `Option` everywhere else lets the upstream
|
||||||
|
// add optional fields without breaking deserialise.
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize, Serialize)]
|
||||||
|
struct OutputItemAddedEvent {
|
||||||
|
#[serde(default)]
|
||||||
|
output_index: u32,
|
||||||
|
item: OutputItem,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize, Serialize)]
|
||||||
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
|
enum OutputItem {
|
||||||
|
Message {
|
||||||
|
#[serde(default)]
|
||||||
|
id: Option<String>,
|
||||||
|
},
|
||||||
|
FunctionCall {
|
||||||
|
#[serde(default)]
|
||||||
|
id: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
call_id: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
name: Option<String>,
|
||||||
|
/// Some upstreams populate `arguments` already on the
|
||||||
|
/// `output_item.added` event for a fully-buffered tool call
|
||||||
|
/// (i.e. when the model finalised the call before the SSE
|
||||||
|
/// flush). Capture it so we can emit a single args delta.
|
||||||
|
#[serde(default)]
|
||||||
|
arguments: Option<String>,
|
||||||
|
},
|
||||||
|
/// `reasoning`, `web_search_call`, etc. We capture-and-ignore
|
||||||
|
/// any item we don't model; the decoder still emits the
|
||||||
|
/// outer events correctly.
|
||||||
|
#[serde(other)]
|
||||||
|
Unknown,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize, Serialize)]
|
||||||
|
struct OutputTextDeltaEvent {
|
||||||
|
#[serde(default)]
|
||||||
|
item_id: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
output_index: u32,
|
||||||
|
#[serde(default)]
|
||||||
|
delta: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize, Serialize)]
|
||||||
|
struct FunctionCallArgumentsDeltaEvent {
|
||||||
|
#[serde(default)]
|
||||||
|
item_id: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
output_index: u32,
|
||||||
|
#[serde(default)]
|
||||||
|
delta: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize, Serialize)]
|
||||||
|
struct ResponseCompletedEvent {
|
||||||
|
response: ResponseShell,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize, Serialize)]
|
||||||
|
struct ResponseShell {
|
||||||
|
#[serde(default)]
|
||||||
|
status: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
usage: Option<WireUsage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize, Serialize)]
|
||||||
|
struct WireUsage {
|
||||||
|
#[serde(default)]
|
||||||
|
input_tokens: u64,
|
||||||
|
#[serde(default)]
|
||||||
|
output_tokens: u64,
|
||||||
|
#[serde(default)]
|
||||||
|
total_tokens: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Streaming decoder ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Translate the named-event Responses SSE into the provider-agnostic
|
||||||
|
/// [`CompletionEvent`] stream the agent loop expects. The decoder
|
||||||
|
/// holds per-stream state — output_index → tool-call-index plus
|
||||||
|
/// the next available tool-call slot — so it can fire
|
||||||
|
/// `ToolCallStart` exactly once per item.
|
||||||
|
fn decode_stream<S>(
|
||||||
|
sse: S,
|
||||||
|
cancel: CancellationToken,
|
||||||
|
) -> impl Stream<Item = anyhow::Result<CompletionEvent>>
|
||||||
|
where
|
||||||
|
S: Stream<
|
||||||
|
Item = Result<
|
||||||
|
eventsource_stream::Event,
|
||||||
|
eventsource_stream::EventStreamError<reqwest::Error>,
|
||||||
|
>,
|
||||||
|
> + Send
|
||||||
|
+ 'static,
|
||||||
|
{
|
||||||
|
async_stream::stream! {
|
||||||
|
let mut sse = Box::pin(sse);
|
||||||
|
// Maps an output_index that's a function_call to the tool-call
|
||||||
|
// slot we hand downstream. Lets us correlate later
|
||||||
|
// `function_call_arguments.delta` events back to the index
|
||||||
|
// we already announced on `output_item.added`.
|
||||||
|
let mut tool_index_by_output: HashMap<u32, usize> = HashMap::new();
|
||||||
|
let mut next_tool_index: usize = 0;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
biased;
|
||||||
|
_ = cancel.cancelled() => {
|
||||||
|
tracing::debug!("openai_responses: cancellation requested, ending stream");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
next = sse.next() => {
|
||||||
|
let Some(event) = next else { break };
|
||||||
|
let event = match event {
|
||||||
|
Ok(e) => e,
|
||||||
|
Err(e) => {
|
||||||
|
yield Err(anyhow::anyhow!("SSE transport: {e}"));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
// Event name lives on `event.event`; data is JSON.
|
||||||
|
let event_name = event.event.as_str();
|
||||||
|
let data = event.data.as_str();
|
||||||
|
match event_name {
|
||||||
|
"response.output_text.delta" => {
|
||||||
|
match serde_json::from_str::<OutputTextDeltaEvent>(data) {
|
||||||
|
Ok(d) if !d.delta.is_empty() => {
|
||||||
|
yield Ok(CompletionEvent::TextDelta(d.delta));
|
||||||
|
}
|
||||||
|
Ok(_) => {}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(
|
||||||
|
error = %e,
|
||||||
|
raw = %data,
|
||||||
|
"openai_responses: failed to parse output_text.delta; skipping"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"response.output_item.added" => {
|
||||||
|
match serde_json::from_str::<OutputItemAddedEvent>(data) {
|
||||||
|
Ok(ev) => {
|
||||||
|
if let OutputItem::FunctionCall {
|
||||||
|
id,
|
||||||
|
call_id,
|
||||||
|
name,
|
||||||
|
arguments,
|
||||||
|
} = ev.item
|
||||||
|
{
|
||||||
|
let idx = next_tool_index;
|
||||||
|
next_tool_index += 1;
|
||||||
|
tool_index_by_output.insert(ev.output_index, idx);
|
||||||
|
// Prefer the user-facing
|
||||||
|
// `call_id` (what gets paired
|
||||||
|
// with tool results) over the
|
||||||
|
// internal item `id` when
|
||||||
|
// both are present. Falls
|
||||||
|
// back to a synthetic id so
|
||||||
|
// history bookkeeping never
|
||||||
|
// breaks.
|
||||||
|
let final_id = call_id
|
||||||
|
.or(id)
|
||||||
|
.unwrap_or_else(|| format!("call_{idx}"));
|
||||||
|
let final_name = name.unwrap_or_default();
|
||||||
|
yield Ok(CompletionEvent::ToolCallStart {
|
||||||
|
index: idx,
|
||||||
|
id: final_id,
|
||||||
|
name: final_name,
|
||||||
|
});
|
||||||
|
// Some upstreams attach the
|
||||||
|
// fully-buffered arguments on
|
||||||
|
// the `output_item.added`
|
||||||
|
// event itself (rare; happens
|
||||||
|
// when the model finalised
|
||||||
|
// before the SSE flush).
|
||||||
|
// Emit as a single args
|
||||||
|
// delta if present.
|
||||||
|
if let Some(args) = arguments
|
||||||
|
&& !args.is_empty()
|
||||||
|
{
|
||||||
|
yield Ok(CompletionEvent::ToolCallArgsDelta {
|
||||||
|
index: idx,
|
||||||
|
args_delta: args,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(
|
||||||
|
error = %e,
|
||||||
|
raw = %data,
|
||||||
|
"openai_responses: failed to parse output_item.added; skipping"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"response.function_call_arguments.delta" => {
|
||||||
|
match serde_json::from_str::<FunctionCallArgumentsDeltaEvent>(data) {
|
||||||
|
Ok(ev) => {
|
||||||
|
let Some(&idx) = tool_index_by_output.get(&ev.output_index)
|
||||||
|
else {
|
||||||
|
// Args delta for an item we
|
||||||
|
// never saw an `output_item.added`
|
||||||
|
// for. Could happen if the
|
||||||
|
// upstream reordered events;
|
||||||
|
// log + skip.
|
||||||
|
tracing::warn!(
|
||||||
|
output_index = ev.output_index,
|
||||||
|
"openai_responses: function_call_arguments.delta for unknown output_index"
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
if !ev.delta.is_empty() {
|
||||||
|
yield Ok(CompletionEvent::ToolCallArgsDelta {
|
||||||
|
index: idx,
|
||||||
|
args_delta: ev.delta,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(
|
||||||
|
error = %e,
|
||||||
|
raw = %data,
|
||||||
|
"openai_responses: failed to parse function_call_arguments.delta; skipping"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"response.completed" => {
|
||||||
|
// Final event. Pull usage + status off
|
||||||
|
// the response shell. Status maps:
|
||||||
|
// "completed" → no special handling
|
||||||
|
// (caller treats as EndTurn),
|
||||||
|
// "incomplete" → length stop.
|
||||||
|
let (reason, usage) =
|
||||||
|
match serde_json::from_str::<ResponseCompletedEvent>(data) {
|
||||||
|
Ok(ev) => {
|
||||||
|
let reason = match ev.response.status.as_deref() {
|
||||||
|
Some("incomplete") => Some("length".to_string()),
|
||||||
|
_ => Some("stop".to_string()),
|
||||||
|
};
|
||||||
|
let usage = ev.response.usage.map(|u| UsageStats {
|
||||||
|
prompt_tokens: u.input_tokens,
|
||||||
|
completion_tokens: u.output_tokens,
|
||||||
|
total_tokens: u.total_tokens,
|
||||||
|
});
|
||||||
|
(reason, usage)
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(
|
||||||
|
error = %e,
|
||||||
|
raw = %data,
|
||||||
|
"openai_responses: failed to parse response.completed; ending stream with EndTurn"
|
||||||
|
);
|
||||||
|
(Some("stop".to_string()), None)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
if let Some(u) = usage {
|
||||||
|
yield Ok(CompletionEvent::Usage(u));
|
||||||
|
}
|
||||||
|
yield Ok(CompletionEvent::Finish { reason });
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
// Bookkeeping events we don't need to surface:
|
||||||
|
// response.created, response.in_progress,
|
||||||
|
// response.content_part.added/.done,
|
||||||
|
// response.output_text.done,
|
||||||
|
// response.output_item.done,
|
||||||
|
// response.function_call_arguments.done,
|
||||||
|
// response.reasoning_*. Logged at debug for
|
||||||
|
// wire-tracing.
|
||||||
|
other => {
|
||||||
|
tracing::trace!(
|
||||||
|
event = other,
|
||||||
|
"openai_responses: bookkeeping event"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::provider::ToolCall;
|
||||||
|
use crate::provider::{ImageData, MessagePart};
|
||||||
|
use futures::stream;
|
||||||
|
use url::Url;
|
||||||
|
|
||||||
|
fn ep() -> EndpointConfig {
|
||||||
|
EndpointConfig {
|
||||||
|
name: "test".into(),
|
||||||
|
base_url: Url::parse("http://localhost:9999/v1").unwrap(),
|
||||||
|
wire_api: crate::config::WireApi::OpenAiResponses,
|
||||||
|
default_model: None,
|
||||||
|
api_key: None,
|
||||||
|
api_key_env: None,
|
||||||
|
max_tokens: None,
|
||||||
|
context_window: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── encode_request ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn system_messages_collapse_to_instructions() {
|
||||||
|
let req = CompletionRequest {
|
||||||
|
model: "m".into(),
|
||||||
|
messages: vec![
|
||||||
|
Message {
|
||||||
|
role: Role::System,
|
||||||
|
content: MessageContent::Text {
|
||||||
|
text: "you are helpful".into(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Message {
|
||||||
|
role: Role::User,
|
||||||
|
content: MessageContent::Text { text: "hi".into() },
|
||||||
|
},
|
||||||
|
],
|
||||||
|
tools: vec![],
|
||||||
|
temperature: Some(0.7),
|
||||||
|
top_p: None,
|
||||||
|
max_tokens: Some(256),
|
||||||
|
};
|
||||||
|
let body = encode_request(&req);
|
||||||
|
assert_eq!(body["model"], "m");
|
||||||
|
assert_eq!(body["instructions"], "you are helpful");
|
||||||
|
assert_eq!(body["stream"], true);
|
||||||
|
assert_eq!(body["max_output_tokens"], 256);
|
||||||
|
assert_eq!(body["temperature"], 0.7);
|
||||||
|
let input = body["input"].as_array().unwrap();
|
||||||
|
// System message NOT echoed in input — it's only in
|
||||||
|
// instructions.
|
||||||
|
assert_eq!(input.len(), 1);
|
||||||
|
assert_eq!(input[0]["type"], "message");
|
||||||
|
assert_eq!(input[0]["role"], "user");
|
||||||
|
assert_eq!(input[0]["content"], "hi");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn multiple_system_messages_concatenate() {
|
||||||
|
let req = CompletionRequest {
|
||||||
|
model: "m".into(),
|
||||||
|
messages: vec![
|
||||||
|
Message {
|
||||||
|
role: Role::System,
|
||||||
|
content: MessageContent::Text {
|
||||||
|
text: "first".into(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Message {
|
||||||
|
role: Role::System,
|
||||||
|
content: MessageContent::Text {
|
||||||
|
text: "second".into(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Message {
|
||||||
|
role: Role::User,
|
||||||
|
content: MessageContent::Text { text: "hi".into() },
|
||||||
|
},
|
||||||
|
],
|
||||||
|
tools: vec![],
|
||||||
|
temperature: None,
|
||||||
|
top_p: None,
|
||||||
|
max_tokens: None,
|
||||||
|
};
|
||||||
|
let body = encode_request(&req);
|
||||||
|
assert_eq!(body["instructions"], "first\n\nsecond");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn user_multipart_becomes_input_parts_array() {
|
||||||
|
let req = CompletionRequest {
|
||||||
|
model: "vl".into(),
|
||||||
|
messages: vec![Message {
|
||||||
|
role: Role::User,
|
||||||
|
content: MessageContent::MultiPart {
|
||||||
|
parts: vec![
|
||||||
|
MessagePart::Text {
|
||||||
|
text: "what's in this?".into(),
|
||||||
|
},
|
||||||
|
MessagePart::Image(ImageData {
|
||||||
|
mime_type: "image/png".into(),
|
||||||
|
data: "AAA=".into(),
|
||||||
|
uri: None,
|
||||||
|
}),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}],
|
||||||
|
tools: vec![],
|
||||||
|
temperature: None,
|
||||||
|
top_p: None,
|
||||||
|
max_tokens: None,
|
||||||
|
};
|
||||||
|
let body = encode_request(&req);
|
||||||
|
let content = &body["input"][0]["content"].as_array().unwrap().clone();
|
||||||
|
assert_eq!(content.len(), 2);
|
||||||
|
assert_eq!(content[0]["type"], "input_text");
|
||||||
|
assert_eq!(content[0]["text"], "what's in this?");
|
||||||
|
assert_eq!(content[1]["type"], "input_image");
|
||||||
|
assert_eq!(content[1]["image_url"], "data:image/png;base64,AAA=");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn assistant_text_becomes_output_text_content_part() {
|
||||||
|
let req = CompletionRequest {
|
||||||
|
model: "m".into(),
|
||||||
|
messages: vec![
|
||||||
|
Message {
|
||||||
|
role: Role::User,
|
||||||
|
content: MessageContent::Text { text: "hi".into() },
|
||||||
|
},
|
||||||
|
Message {
|
||||||
|
role: Role::Assistant,
|
||||||
|
content: MessageContent::Text {
|
||||||
|
text: "hello there".into(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Message {
|
||||||
|
role: Role::User,
|
||||||
|
content: MessageContent::Text {
|
||||||
|
text: "more".into(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
tools: vec![],
|
||||||
|
temperature: None,
|
||||||
|
top_p: None,
|
||||||
|
max_tokens: None,
|
||||||
|
};
|
||||||
|
let body = encode_request(&req);
|
||||||
|
let input = body["input"].as_array().unwrap();
|
||||||
|
assert_eq!(input.len(), 3);
|
||||||
|
assert_eq!(input[1]["type"], "message");
|
||||||
|
assert_eq!(input[1]["role"], "assistant");
|
||||||
|
assert_eq!(input[1]["content"][0]["type"], "output_text");
|
||||||
|
assert_eq!(input[1]["content"][0]["text"], "hello there");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn tool_calls_and_results_round_trip_via_function_call_items() {
|
||||||
|
let req = CompletionRequest {
|
||||||
|
model: "m".into(),
|
||||||
|
messages: vec![
|
||||||
|
Message {
|
||||||
|
role: Role::Assistant,
|
||||||
|
content: MessageContent::ToolCalls {
|
||||||
|
text: None,
|
||||||
|
calls: vec![ToolCall {
|
||||||
|
id: "call_42".into(),
|
||||||
|
name: "read_file".into(),
|
||||||
|
arguments: r#"{"path":"/etc/hostname"}"#.into(),
|
||||||
|
}],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Message {
|
||||||
|
role: Role::Tool,
|
||||||
|
content: MessageContent::ToolResult {
|
||||||
|
tool_call_id: "call_42".into(),
|
||||||
|
content: "host".into(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
tools: vec![],
|
||||||
|
temperature: None,
|
||||||
|
top_p: None,
|
||||||
|
max_tokens: None,
|
||||||
|
};
|
||||||
|
let body = encode_request(&req);
|
||||||
|
let input = body["input"].as_array().unwrap();
|
||||||
|
assert_eq!(input.len(), 2);
|
||||||
|
assert_eq!(input[0]["type"], "function_call");
|
||||||
|
assert_eq!(input[0]["call_id"], "call_42");
|
||||||
|
assert_eq!(input[0]["name"], "read_file");
|
||||||
|
assert_eq!(input[0]["arguments"], r#"{"path":"/etc/hostname"}"#);
|
||||||
|
assert_eq!(input[1]["type"], "function_call_output");
|
||||||
|
assert_eq!(input[1]["call_id"], "call_42");
|
||||||
|
assert_eq!(input[1]["output"], "host");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── decode_stream ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
fn sse_event(name: &str, data: &str) -> eventsource_stream::Event {
|
||||||
|
eventsource_stream::Event {
|
||||||
|
id: String::new(),
|
||||||
|
retry: None,
|
||||||
|
event: name.into(),
|
||||||
|
data: data.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn collect_events(
|
||||||
|
items: Vec<eventsource_stream::Event>,
|
||||||
|
) -> Vec<anyhow::Result<CompletionEvent>> {
|
||||||
|
let sse = stream::iter(
|
||||||
|
items
|
||||||
|
.into_iter()
|
||||||
|
.map(Ok::<_, eventsource_stream::EventStreamError<reqwest::Error>>),
|
||||||
|
);
|
||||||
|
let decoded = decode_stream(sse, CancellationToken::new());
|
||||||
|
decoded.collect().await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn decodes_text_then_finish() {
|
||||||
|
let events = collect_events(vec![
|
||||||
|
sse_event("response.created", "{}"),
|
||||||
|
sse_event(
|
||||||
|
"response.output_text.delta",
|
||||||
|
r#"{"item_id":"msg_1","output_index":0,"delta":"hel"}"#,
|
||||||
|
),
|
||||||
|
sse_event(
|
||||||
|
"response.output_text.delta",
|
||||||
|
r#"{"item_id":"msg_1","output_index":0,"delta":"lo"}"#,
|
||||||
|
),
|
||||||
|
sse_event(
|
||||||
|
"response.completed",
|
||||||
|
r#"{"response":{"status":"completed","usage":{"input_tokens":3,"output_tokens":2,"total_tokens":5}}}"#,
|
||||||
|
),
|
||||||
|
])
|
||||||
|
.await;
|
||||||
|
let events: Vec<CompletionEvent> = events.into_iter().map(|r| r.unwrap()).collect();
|
||||||
|
let mut iter = events.into_iter();
|
||||||
|
assert!(matches!(iter.next(), Some(CompletionEvent::TextDelta(t)) if t == "hel"));
|
||||||
|
assert!(matches!(iter.next(), Some(CompletionEvent::TextDelta(t)) if t == "lo"));
|
||||||
|
assert!(matches!(iter.next(), Some(CompletionEvent::Usage(u)) if u.total_tokens == 5));
|
||||||
|
assert!(matches!(
|
||||||
|
iter.next(),
|
||||||
|
Some(CompletionEvent::Finish { reason: Some(r) }) if r == "stop"
|
||||||
|
));
|
||||||
|
assert!(iter.next().is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn empty_delta_is_dropped() {
|
||||||
|
let events = collect_events(vec![
|
||||||
|
sse_event(
|
||||||
|
"response.output_text.delta",
|
||||||
|
r#"{"item_id":"m","output_index":0,"delta":""}"#,
|
||||||
|
),
|
||||||
|
sse_event(
|
||||||
|
"response.completed",
|
||||||
|
r#"{"response":{"status":"completed"}}"#,
|
||||||
|
),
|
||||||
|
])
|
||||||
|
.await;
|
||||||
|
let mut completion_events = events.into_iter().map(|r| r.unwrap());
|
||||||
|
// First event MUST be the Finish — the empty delta dropped.
|
||||||
|
assert!(matches!(
|
||||||
|
completion_events.next(),
|
||||||
|
Some(CompletionEvent::Finish { .. })
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn incomplete_status_maps_to_length_finish_reason() {
|
||||||
|
let events = collect_events(vec![sse_event(
|
||||||
|
"response.completed",
|
||||||
|
r#"{"response":{"status":"incomplete"}}"#,
|
||||||
|
)])
|
||||||
|
.await;
|
||||||
|
let events: Vec<CompletionEvent> = events.into_iter().map(|r| r.unwrap()).collect();
|
||||||
|
assert!(matches!(
|
||||||
|
events.last(),
|
||||||
|
Some(CompletionEvent::Finish { reason: Some(r) }) if r == "length"
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn function_call_items_emit_toolcall_events() {
|
||||||
|
let events = collect_events(vec![
|
||||||
|
sse_event(
|
||||||
|
"response.output_item.added",
|
||||||
|
r#"{"output_index":0,"item":{"type":"function_call","id":"item_1","call_id":"call_xyz","name":"read_file"}}"#,
|
||||||
|
),
|
||||||
|
sse_event(
|
||||||
|
"response.function_call_arguments.delta",
|
||||||
|
r#"{"item_id":"item_1","output_index":0,"delta":"{\"path"}"#,
|
||||||
|
),
|
||||||
|
sse_event(
|
||||||
|
"response.function_call_arguments.delta",
|
||||||
|
r#"{"item_id":"item_1","output_index":0,"delta":"\":\"/etc/hostname\"}"}"#,
|
||||||
|
),
|
||||||
|
sse_event("response.completed", r#"{"response":{"status":"completed"}}"#),
|
||||||
|
])
|
||||||
|
.await;
|
||||||
|
let events: Vec<CompletionEvent> = events.into_iter().map(|r| r.unwrap()).collect();
|
||||||
|
let mut iter = events.into_iter();
|
||||||
|
assert!(matches!(
|
||||||
|
iter.next(),
|
||||||
|
Some(CompletionEvent::ToolCallStart { index: 0, ref id, ref name })
|
||||||
|
if id == "call_xyz" && name == "read_file"
|
||||||
|
));
|
||||||
|
assert!(matches!(
|
||||||
|
iter.next(),
|
||||||
|
Some(CompletionEvent::ToolCallArgsDelta { index: 0, ref args_delta })
|
||||||
|
if args_delta == r#"{"path"#
|
||||||
|
));
|
||||||
|
assert!(matches!(
|
||||||
|
iter.next(),
|
||||||
|
Some(CompletionEvent::ToolCallArgsDelta { index: 0, ref args_delta })
|
||||||
|
if args_delta == r#"":"/etc/hostname"}"#
|
||||||
|
));
|
||||||
|
assert!(matches!(iter.next(), Some(CompletionEvent::Finish { .. })));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn function_call_added_with_inline_arguments_emits_single_args_delta() {
|
||||||
|
// Some upstreams (rare) include the fully-buffered arguments
|
||||||
|
// on the `output_item.added` event when the model finalised
|
||||||
|
// the call before SSE flush. Verify both ToolCallStart and a
|
||||||
|
// single args delta fire.
|
||||||
|
let events = collect_events(vec![
|
||||||
|
sse_event(
|
||||||
|
"response.output_item.added",
|
||||||
|
r#"{"output_index":0,"item":{"type":"function_call","call_id":"call_a","name":"f","arguments":"{\"x\":1}"}}"#,
|
||||||
|
),
|
||||||
|
sse_event("response.completed", r#"{"response":{"status":"completed"}}"#),
|
||||||
|
])
|
||||||
|
.await;
|
||||||
|
let events: Vec<CompletionEvent> = events.into_iter().map(|r| r.unwrap()).collect();
|
||||||
|
let mut iter = events.into_iter();
|
||||||
|
assert!(matches!(
|
||||||
|
iter.next(),
|
||||||
|
Some(CompletionEvent::ToolCallStart { .. })
|
||||||
|
));
|
||||||
|
assert!(matches!(
|
||||||
|
iter.next(),
|
||||||
|
Some(CompletionEvent::ToolCallArgsDelta { index: 0, ref args_delta })
|
||||||
|
if args_delta == r#"{"x":1}"#
|
||||||
|
));
|
||||||
|
assert!(matches!(iter.next(), Some(CompletionEvent::Finish { .. })));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn cancellation_ends_stream_promptly() {
|
||||||
|
// Hand the decoder an empty stream + a triggered cancellation
|
||||||
|
// token; it should terminate without yielding anything.
|
||||||
|
let sse = stream::iter(Vec::<
|
||||||
|
Result<eventsource_stream::Event, eventsource_stream::EventStreamError<reqwest::Error>>,
|
||||||
|
>::new());
|
||||||
|
let cancel = CancellationToken::new();
|
||||||
|
cancel.cancel();
|
||||||
|
let decoded = decode_stream(sse, cancel);
|
||||||
|
let events: Vec<_> = decoded.collect().await;
|
||||||
|
assert!(events.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn malformed_event_payload_is_skipped() {
|
||||||
|
let events = collect_events(vec![
|
||||||
|
sse_event("response.output_text.delta", "{not valid json"),
|
||||||
|
sse_event(
|
||||||
|
"response.output_text.delta",
|
||||||
|
r#"{"item_id":"m","output_index":0,"delta":"ok"}"#,
|
||||||
|
),
|
||||||
|
sse_event(
|
||||||
|
"response.completed",
|
||||||
|
r#"{"response":{"status":"completed"}}"#,
|
||||||
|
),
|
||||||
|
])
|
||||||
|
.await;
|
||||||
|
let events: Vec<CompletionEvent> = events.into_iter().map(|r| r.unwrap()).collect();
|
||||||
|
// First text delta dropped; second one fires.
|
||||||
|
assert!(
|
||||||
|
events
|
||||||
|
.iter()
|
||||||
|
.any(|e| matches!(e, CompletionEvent::TextDelta(t) if t == "ok"))
|
||||||
|
);
|
||||||
|
// No errors yielded (parse failures are warn-and-skip).
|
||||||
|
assert!(
|
||||||
|
events
|
||||||
|
.iter()
|
||||||
|
.all(|e| !matches!(e, CompletionEvent::Finish { reason: None }))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn provider_construction_is_cheap() {
|
||||||
|
let _ = OpenAIResponsesProvider::new(ep()).unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
1018
crates/helexa-acp/src/qwen3.rs
Normal file
1018
crates/helexa-acp/src/qwen3.rs
Normal file
File diff suppressed because it is too large
Load Diff
188
crates/helexa-acp/src/session.rs
Normal file
188
crates/helexa-acp/src/session.rs
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
//! Per-session state for the ACP agent loop.
|
||||||
|
//!
|
||||||
|
//! Concurrency:
|
||||||
|
//!
|
||||||
|
//! - [`SessionStore`] is an `Arc<RwLock<HashMap<SessionId, …>>>`. The map
|
||||||
|
//! itself is read-mostly: it changes only on `session/new` and never
|
||||||
|
//! shrinks during Stage 2, so an `RwLock` keeps concurrent reads
|
||||||
|
//! contention-free.
|
||||||
|
//! - Each session is wrapped in its own `Arc<Mutex<SessionState>>`. Holding
|
||||||
|
//! one session's lock doesn't block requests against any other session,
|
||||||
|
//! which matters once a client opens multiple sessions in parallel.
|
||||||
|
//!
|
||||||
|
//! All operations hold a lock only long enough to copy out (or mutate) the
|
||||||
|
//! state they need — never across an `await` that drives the upstream
|
||||||
|
//! provider stream.
|
||||||
|
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use agent_client_protocol::schema::{SessionId, SessionModeId};
|
||||||
|
use tokio::sync::{Mutex, RwLock};
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
|
|
||||||
|
use crate::provider::Message;
|
||||||
|
|
||||||
|
/// Mode id advertised as the gated default. Writes / bash prompt for
|
||||||
|
/// permission via `session/request_permission`.
|
||||||
|
pub const MODE_DEFAULT: &str = "default";
|
||||||
|
|
||||||
|
/// Mode id advertised as "auto-allow everything". Matches the
|
||||||
|
/// favorite name (`bypassPermissions`) Zed clients tend to reference.
|
||||||
|
pub const MODE_BYPASS: &str = "bypassPermissions";
|
||||||
|
|
||||||
|
/// Mode id for read-and-plan-only operation. The model may read files
|
||||||
|
/// and list directories freely, may write *only* into the per-project
|
||||||
|
/// plan directory under `$XDG_DATA_HOME/helexa-acp/plans/<project-id>/`,
|
||||||
|
/// and cannot run shell commands. Designed for "draft the
|
||||||
|
/// implementation plan, then I'll review and let you execute" flows.
|
||||||
|
pub const MODE_PLAN: &str = "plan";
|
||||||
|
|
||||||
|
/// State carried for a single ACP session.
|
||||||
|
///
|
||||||
|
/// Mutated under `Mutex<SessionState>`; never share a clone across
|
||||||
|
/// tasks expecting to see the same `cancel` token — clone the token
|
||||||
|
/// explicitly when handing it to the streaming task.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct SessionState {
|
||||||
|
/// Conversation history in chronological order (user / assistant
|
||||||
|
/// turns). The system prompt is *not* stored here — it's built
|
||||||
|
/// fresh per request so any cwd / config changes take effect.
|
||||||
|
pub history: Vec<Message>,
|
||||||
|
/// Working directory the client opened the session against. Used
|
||||||
|
/// by [`crate::prompt::build_system_prompt`] and (Stage 3) by
|
||||||
|
/// filesystem tools.
|
||||||
|
pub cwd: PathBuf,
|
||||||
|
/// Currently-selected model id. Format is either a bare model id
|
||||||
|
/// (resolved against the default endpoint) or `endpoint:model`.
|
||||||
|
/// Mutated by `session/set_model` in Stage 4; Stage 2 sets it
|
||||||
|
/// once at session creation and never changes it.
|
||||||
|
pub model_id: String,
|
||||||
|
/// Cancellation handle for the in-flight prompt, if any. A fresh
|
||||||
|
/// token is installed at the start of every `session/prompt`
|
||||||
|
/// request; `session/cancel` fires this one. Between prompts the
|
||||||
|
/// token is "spent" — firing it does nothing — which is fine,
|
||||||
|
/// `session/cancel` is a no-op when there's nothing to cancel.
|
||||||
|
pub cancel: CancellationToken,
|
||||||
|
/// Permission gating mode. Stage 3 advertises two ids in
|
||||||
|
/// `NewSessionResponse.modes`: [`MODE_DEFAULT`] (writes / bash
|
||||||
|
/// prompt the user) and [`MODE_BYPASS`] (auto-allow). Mutated by
|
||||||
|
/// `session/set_mode`.
|
||||||
|
pub mode_id: SessionModeId,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SessionState {
|
||||||
|
pub fn new(cwd: PathBuf, model_id: String) -> Self {
|
||||||
|
Self {
|
||||||
|
history: Vec::new(),
|
||||||
|
cwd,
|
||||||
|
model_id,
|
||||||
|
cancel: CancellationToken::new(),
|
||||||
|
mode_id: SessionModeId::new(MODE_DEFAULT),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Concurrent map of live sessions.
|
||||||
|
///
|
||||||
|
/// Cloning is cheap (`Arc` bump). Pass clones into every handler that
|
||||||
|
/// needs session access; never hold a clone across an `.await` that
|
||||||
|
/// could outlive the request.
|
||||||
|
pub type SessionStore = Arc<RwLock<HashMap<SessionId, Arc<Mutex<SessionState>>>>>;
|
||||||
|
|
||||||
|
/// Fresh, empty session store.
|
||||||
|
pub fn new_store() -> SessionStore {
|
||||||
|
Arc::new(RwLock::new(HashMap::new()))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Look up a session by id. Returns `None` if no such session is registered.
|
||||||
|
pub async fn get(store: &SessionStore, id: &SessionId) -> Option<Arc<Mutex<SessionState>>> {
|
||||||
|
store.read().await.get(id).cloned()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Register a fresh session. Overwrites any prior entry with the same id
|
||||||
|
/// (which should never happen — ids are uniquely generated by the agent).
|
||||||
|
pub async fn insert(store: &SessionStore, id: SessionId, state: SessionState) {
|
||||||
|
store.write().await.insert(id, Arc::new(Mutex::new(state)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::provider::{MessageContent, Role};
|
||||||
|
|
||||||
|
fn id(s: &str) -> SessionId {
|
||||||
|
SessionId::new(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn insert_then_get_round_trip() {
|
||||||
|
let store = new_store();
|
||||||
|
let state = SessionState::new(PathBuf::from("/tmp"), "m".into());
|
||||||
|
insert(&store, id("s1"), state).await;
|
||||||
|
let got = get(&store, &id("s1")).await.expect("session present");
|
||||||
|
let locked = got.lock().await;
|
||||||
|
assert_eq!(locked.cwd, PathBuf::from("/tmp"));
|
||||||
|
assert_eq!(locked.model_id, "m");
|
||||||
|
assert!(locked.history.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn missing_session_is_none() {
|
||||||
|
let store = new_store();
|
||||||
|
assert!(get(&store, &id("nope")).await.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn history_is_per_session() {
|
||||||
|
let store = new_store();
|
||||||
|
insert(
|
||||||
|
&store,
|
||||||
|
id("a"),
|
||||||
|
SessionState::new(PathBuf::from("/a"), "m".into()),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
insert(
|
||||||
|
&store,
|
||||||
|
id("b"),
|
||||||
|
SessionState::new(PathBuf::from("/b"), "m".into()),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// Appending to a's history must not affect b's.
|
||||||
|
get(&store, &id("a"))
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
|
.history
|
||||||
|
.push(Message {
|
||||||
|
role: Role::User,
|
||||||
|
content: MessageContent::Text {
|
||||||
|
text: "hello".into(),
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
get(&store, &id("a"))
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
|
.history
|
||||||
|
.len(),
|
||||||
|
1
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
get(&store, &id("b"))
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
|
.history
|
||||||
|
.len(),
|
||||||
|
0
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
462
crates/helexa-acp/src/store.rs
Normal file
462
crates/helexa-acp/src/store.rs
Normal file
@@ -0,0 +1,462 @@
|
|||||||
|
//! On-disk session persistence for `session/load` support.
|
||||||
|
//!
|
||||||
|
//! Storage layout:
|
||||||
|
//!
|
||||||
|
//! ```text
|
||||||
|
//! $XDG_DATA_HOME/helexa-acp/sessions/{session_id}.json
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! (Fallback to `~/.local/share/helexa-acp/sessions/` when
|
||||||
|
//! `$XDG_DATA_HOME` is unset.) One JSON file per session. Writes
|
||||||
|
//! happen at the end of every `session/prompt` round through
|
||||||
|
//! [`save`], using tempfile-plus-rename so a crash mid-write can't
|
||||||
|
//! corrupt the store. Reads happen on `session/load` via [`load`].
|
||||||
|
//!
|
||||||
|
//! No compaction, no rotation: files accumulate until the user
|
||||||
|
//! cleans them up. That's deliberate — disk is cheap, and the
|
||||||
|
//! resume-on-restart workflow matters more than tidiness. The
|
||||||
|
//! [`SESSIONS_DIRNAME`] subdirectory is created lazily on first
|
||||||
|
//! save so an unprivileged install path never errors at startup.
|
||||||
|
|
||||||
|
use std::path::PathBuf;
|
||||||
|
use std::time::SystemTime;
|
||||||
|
|
||||||
|
use agent_client_protocol::schema::SessionId;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::provider::Message;
|
||||||
|
|
||||||
|
const APP_DIRNAME: &str = "helexa-acp";
|
||||||
|
const SESSIONS_DIRNAME: &str = "sessions";
|
||||||
|
const PLANS_DIRNAME: &str = "plans";
|
||||||
|
|
||||||
|
/// The shape persisted to disk for one session. Only what we can't
|
||||||
|
/// rebuild from the running config goes in here: the conversation
|
||||||
|
/// history, the mode toggle, the model id, and the cwd-at-creation.
|
||||||
|
///
|
||||||
|
/// `created_at` / `updated_at` are seconds-since-epoch — cheap to
|
||||||
|
/// compare, no third-party time crate, and stable across runs.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct PersistedSession {
|
||||||
|
pub session_id: String,
|
||||||
|
pub cwd: PathBuf,
|
||||||
|
pub model_id: String,
|
||||||
|
pub mode_id: String,
|
||||||
|
pub history: Vec<Message>,
|
||||||
|
pub created_at: u64,
|
||||||
|
pub updated_at: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Resolve the directory that holds session JSON files. Honors
|
||||||
|
/// `$XDG_DATA_HOME`; falls back to `~/.local/share/helexa-acp/sessions/`.
|
||||||
|
/// Returns `None` if neither is resolvable (no `HOME` set — possible
|
||||||
|
/// in stripped-down container environments).
|
||||||
|
pub fn sessions_dir() -> Option<PathBuf> {
|
||||||
|
let base = std::env::var("XDG_DATA_HOME")
|
||||||
|
.ok()
|
||||||
|
.filter(|s| !s.is_empty())
|
||||||
|
.map(PathBuf::from)
|
||||||
|
.or_else(|| {
|
||||||
|
std::env::var("HOME")
|
||||||
|
.ok()
|
||||||
|
.map(|h| PathBuf::from(h).join(".local").join("share"))
|
||||||
|
})?;
|
||||||
|
Some(base.join(APP_DIRNAME).join(SESSIONS_DIRNAME))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Atomic save into the default sessions directory.
|
||||||
|
pub fn save(session: &PersistedSession) -> anyhow::Result<()> {
|
||||||
|
let dir = sessions_dir()
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("can't resolve XDG_DATA_HOME or HOME for session store"))?;
|
||||||
|
save_to_dir(&dir, session)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load from the default sessions directory.
|
||||||
|
pub fn load(session_id: &SessionId) -> anyhow::Result<PersistedSession> {
|
||||||
|
let dir = sessions_dir()
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("can't resolve XDG_DATA_HOME or HOME for session store"))?;
|
||||||
|
load_from_dir(&dir, session_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Atomic save into an explicit directory. Writes to
|
||||||
|
/// `{id}.json.tmp` then renames over `{id}.json`. Creates the
|
||||||
|
/// target directory if it doesn't exist. Split from [`save`] so
|
||||||
|
/// unit tests can target a per-test scratch dir without mutating
|
||||||
|
/// process-global env vars.
|
||||||
|
pub fn save_to_dir(dir: &std::path::Path, session: &PersistedSession) -> anyhow::Result<()> {
|
||||||
|
std::fs::create_dir_all(dir).map_err(|e| anyhow::anyhow!("create {}: {e}", dir.display()))?;
|
||||||
|
let safe = sanitize_id(&session.session_id);
|
||||||
|
let final_path = dir.join(format!("{safe}.json"));
|
||||||
|
let tmp_path = dir.join(format!("{safe}.json.tmp"));
|
||||||
|
let json = serde_json::to_string_pretty(session)?;
|
||||||
|
std::fs::write(&tmp_path, json)
|
||||||
|
.map_err(|e| anyhow::anyhow!("write {}: {e}", tmp_path.display()))?;
|
||||||
|
std::fs::rename(&tmp_path, &final_path)
|
||||||
|
.map_err(|e| anyhow::anyhow!("rename → {}: {e}", final_path.display()))?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load from an explicit directory. Returns a friendly error
|
||||||
|
/// message when the session id has no file on disk so the caller
|
||||||
|
/// can map it to a clean ACP error response.
|
||||||
|
pub fn load_from_dir(
|
||||||
|
dir: &std::path::Path,
|
||||||
|
session_id: &SessionId,
|
||||||
|
) -> anyhow::Result<PersistedSession> {
|
||||||
|
let safe = sanitize_id(session_id.0.as_ref());
|
||||||
|
let path = dir.join(format!("{safe}.json"));
|
||||||
|
let bytes = std::fs::read(&path).map_err(|e| {
|
||||||
|
if e.kind() == std::io::ErrorKind::NotFound {
|
||||||
|
anyhow::anyhow!("no persisted session at {}", path.display())
|
||||||
|
} else {
|
||||||
|
anyhow::anyhow!("read {}: {e}", path.display())
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
let session: PersistedSession = serde_json::from_slice(&bytes)
|
||||||
|
.map_err(|e| anyhow::anyhow!("parse {}: {e}", path.display()))?;
|
||||||
|
Ok(session)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// List all persisted sessions, optionally filtered by `cwd`. Used
|
||||||
|
/// by the `session/list` handler so a client (Zed) can find the
|
||||||
|
/// session that belongs to the workspace it's reopening.
|
||||||
|
///
|
||||||
|
/// `filter_cwd = None` returns every session on disk. `Some(path)`
|
||||||
|
/// returns only sessions whose persisted `cwd` is exactly equal.
|
||||||
|
///
|
||||||
|
/// Files that fail to parse are skipped with a warning rather than
|
||||||
|
/// aborting the whole list — one corrupt session shouldn't make
|
||||||
|
/// the resume picker unusable.
|
||||||
|
pub fn list(filter_cwd: Option<&std::path::Path>) -> anyhow::Result<Vec<PersistedSession>> {
|
||||||
|
let dir = sessions_dir()
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("can't resolve XDG_DATA_HOME or HOME for session store"))?;
|
||||||
|
list_in_dir(&dir, filter_cwd)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Explicit-dir variant for tests, mirroring [`save_to_dir`] /
|
||||||
|
/// [`load_from_dir`].
|
||||||
|
pub fn list_in_dir(
|
||||||
|
dir: &std::path::Path,
|
||||||
|
filter_cwd: Option<&std::path::Path>,
|
||||||
|
) -> anyhow::Result<Vec<PersistedSession>> {
|
||||||
|
let read = match std::fs::read_dir(dir) {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(Vec::new()),
|
||||||
|
Err(e) => return Err(anyhow::anyhow!("read_dir {}: {e}", dir.display())),
|
||||||
|
};
|
||||||
|
let mut out = Vec::new();
|
||||||
|
for entry in read.flatten() {
|
||||||
|
let path = entry.path();
|
||||||
|
if path.extension().and_then(|s| s.to_str()) != Some("json") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
match std::fs::read(&path).and_then(|bytes| {
|
||||||
|
serde_json::from_slice::<PersistedSession>(&bytes).map_err(std::io::Error::other)
|
||||||
|
}) {
|
||||||
|
Ok(session) => {
|
||||||
|
if let Some(want) = filter_cwd
|
||||||
|
&& session.cwd != want
|
||||||
|
{
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
out.push(session);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(
|
||||||
|
path = %path.display(),
|
||||||
|
error = %e,
|
||||||
|
"store: skipping unparseable session file"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Most-recent first by updated_at.
|
||||||
|
out.sort_by_key(|s| std::cmp::Reverse(s.updated_at));
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Seconds-since-epoch, saturating to 0 if the system clock is
|
||||||
|
/// behind epoch (which shouldn't happen but the type system
|
||||||
|
/// requires a fallible read).
|
||||||
|
pub fn now_secs() -> u64 {
|
||||||
|
SystemTime::now()
|
||||||
|
.duration_since(SystemTime::UNIX_EPOCH)
|
||||||
|
.map(|d| d.as_secs())
|
||||||
|
.unwrap_or(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Root directory for plan-mode artefacts. Mirrors [`sessions_dir`]
|
||||||
|
/// but under `…/helexa-acp/plans/` so plans and conversation
|
||||||
|
/// transcripts are siblings, not nested.
|
||||||
|
pub fn plans_root() -> Option<PathBuf> {
|
||||||
|
sessions_dir().and_then(|s| s.parent().map(|p| p.join(PLANS_DIRNAME)))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Per-project plan directory:
|
||||||
|
/// `$XDG_DATA_HOME/helexa-acp/plans/<project-id>/`. The id derives
|
||||||
|
/// from the session's cwd so plans for the same project survive
|
||||||
|
/// across cwd-changes (a `/home/foo/git/bar` ↔ symlinked
|
||||||
|
/// `/srv/checkout/bar` would technically diverge, accepted as a
|
||||||
|
/// won't-fix corner case).
|
||||||
|
pub fn plan_dir_for(cwd: &std::path::Path) -> Option<PathBuf> {
|
||||||
|
plans_root().map(|root| root.join(project_id_for(cwd)))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Deterministic, human-readable project identifier. Format:
|
||||||
|
/// `<basename>-<8-hex>` where the 8-hex suffix is FNV-1a of the
|
||||||
|
/// full path. Basename keeps the path skim-readable when poking
|
||||||
|
/// around `$XDG_DATA_HOME` by hand; the hash suffix disambiguates
|
||||||
|
/// repos that share a final path component (e.g. multiple
|
||||||
|
/// `/.../checkout/beat` checkouts).
|
||||||
|
///
|
||||||
|
/// FNV-1a rather than `std::collections::hash::DefaultHasher`
|
||||||
|
/// because the latter (SipHash) reseeds per process, so it'd give
|
||||||
|
/// us a different project_id on every run.
|
||||||
|
pub fn project_id_for(cwd: &std::path::Path) -> String {
|
||||||
|
let basename = cwd
|
||||||
|
.file_name()
|
||||||
|
.and_then(|s| s.to_str())
|
||||||
|
.unwrap_or("unknown");
|
||||||
|
let sanitised: String = basename
|
||||||
|
.chars()
|
||||||
|
.map(|c| {
|
||||||
|
if c.is_ascii_alphanumeric() || c == '-' || c == '_' {
|
||||||
|
c
|
||||||
|
} else {
|
||||||
|
'_'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let hash = fnv1a_32(cwd.to_string_lossy().as_bytes());
|
||||||
|
format!("{sanitised}-{hash:08x}")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// FNV-1a (32-bit). Deterministic, no third-party crate. Used for
|
||||||
|
/// project ids only — not cryptographic.
|
||||||
|
fn fnv1a_32(bytes: &[u8]) -> u32 {
|
||||||
|
let mut h: u32 = 0x811c_9dc5;
|
||||||
|
for b in bytes {
|
||||||
|
h ^= u32::from(*b);
|
||||||
|
h = h.wrapping_mul(0x0100_0193);
|
||||||
|
}
|
||||||
|
h
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Format seconds-since-epoch as an ISO 8601 / RFC 3339 string
|
||||||
|
/// (`YYYY-MM-DDTHH:MM:SSZ`) for `SessionInfo.updated_at`. Returns
|
||||||
|
/// `None` for values outside the representable range, in which
|
||||||
|
/// case the caller should omit the field.
|
||||||
|
pub fn unix_to_iso8601(secs: u64) -> Option<String> {
|
||||||
|
use chrono::TimeZone;
|
||||||
|
let dt = chrono::Utc.timestamp_opt(secs as i64, 0).single()?;
|
||||||
|
Some(dt.to_rfc3339_opts(chrono::SecondsFormat::Secs, true))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Strip anything that isn't a safe filename character so a
|
||||||
|
/// mischievous (or just unconventional) session id can't escape
|
||||||
|
/// the sessions directory.
|
||||||
|
fn sanitize_id(id: &str) -> String {
|
||||||
|
id.chars()
|
||||||
|
.map(|c| {
|
||||||
|
if c.is_ascii_alphanumeric() || c == '-' || c == '_' {
|
||||||
|
c
|
||||||
|
} else {
|
||||||
|
'_'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::provider::{MessageContent, Role};
|
||||||
|
|
||||||
|
/// Unique scratch dir per test invocation. We use this dir
|
||||||
|
/// directly with the `*_to_dir` / `*_from_dir` functions so
|
||||||
|
/// the tests never mutate `$XDG_DATA_HOME` — that env var
|
||||||
|
/// would race across the parallel test harness.
|
||||||
|
fn unique_dir() -> PathBuf {
|
||||||
|
let base = std::env::var("CARGO_TARGET_TMPDIR")
|
||||||
|
.ok()
|
||||||
|
.map(PathBuf::from)
|
||||||
|
.unwrap_or_else(std::env::temp_dir);
|
||||||
|
let pid = std::process::id();
|
||||||
|
let nanos = SystemTime::now()
|
||||||
|
.duration_since(SystemTime::UNIX_EPOCH)
|
||||||
|
.map(|d| d.subsec_nanos())
|
||||||
|
.unwrap_or(0);
|
||||||
|
let dir = base.join(format!("helexa-acp-store-test-{pid}-{nanos}"));
|
||||||
|
std::fs::create_dir_all(&dir).expect("create test dir");
|
||||||
|
dir
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sample(id: &str) -> PersistedSession {
|
||||||
|
PersistedSession {
|
||||||
|
session_id: id.into(),
|
||||||
|
cwd: PathBuf::from("/home/me/proj"),
|
||||||
|
model_id: "Qwen/Qwen3.6-27B".into(),
|
||||||
|
mode_id: "default".into(),
|
||||||
|
history: vec![
|
||||||
|
Message {
|
||||||
|
role: Role::User,
|
||||||
|
content: MessageContent::Text {
|
||||||
|
text: "hello".into(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Message {
|
||||||
|
role: Role::Assistant,
|
||||||
|
content: MessageContent::Text { text: "hi".into() },
|
||||||
|
},
|
||||||
|
],
|
||||||
|
created_at: 1_700_000_000,
|
||||||
|
updated_at: 1_700_000_001,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn round_trip_save_then_load() {
|
||||||
|
let dir = unique_dir();
|
||||||
|
save_to_dir(&dir, &sample("hxa-1")).expect("save");
|
||||||
|
let loaded = load_from_dir(&dir, &SessionId::new("hxa-1")).expect("load");
|
||||||
|
assert_eq!(loaded.session_id, "hxa-1");
|
||||||
|
assert_eq!(loaded.cwd, PathBuf::from("/home/me/proj"));
|
||||||
|
assert_eq!(loaded.history.len(), 2);
|
||||||
|
let _ = std::fs::remove_dir_all(&dir);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn load_missing_session_errors_with_not_found_message() {
|
||||||
|
let dir = unique_dir();
|
||||||
|
let err = load_from_dir(&dir, &SessionId::new("nope")).unwrap_err();
|
||||||
|
let msg = format!("{err}");
|
||||||
|
assert!(
|
||||||
|
msg.contains("no persisted session"),
|
||||||
|
"want NotFound, got: {msg}"
|
||||||
|
);
|
||||||
|
let _ = std::fs::remove_dir_all(&dir);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn save_overwrites_existing_atomically() {
|
||||||
|
let dir = unique_dir();
|
||||||
|
save_to_dir(&dir, &sample("hxa-1")).expect("save");
|
||||||
|
let mut updated = sample("hxa-1");
|
||||||
|
updated.history.push(Message {
|
||||||
|
role: Role::User,
|
||||||
|
content: MessageContent::Text {
|
||||||
|
text: "third turn".into(),
|
||||||
|
},
|
||||||
|
});
|
||||||
|
updated.updated_at = 1_700_000_500;
|
||||||
|
save_to_dir(&dir, &updated).expect("re-save");
|
||||||
|
let loaded = load_from_dir(&dir, &SessionId::new("hxa-1")).expect("load");
|
||||||
|
assert_eq!(loaded.history.len(), 3);
|
||||||
|
assert_eq!(loaded.updated_at, 1_700_000_500);
|
||||||
|
let _ = std::fs::remove_dir_all(&dir);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn save_then_load_preserves_tool_calls_and_results() {
|
||||||
|
use crate::provider::ToolCall;
|
||||||
|
let dir = unique_dir();
|
||||||
|
let mut session = sample("hxa-2");
|
||||||
|
session.history.push(Message {
|
||||||
|
role: Role::Assistant,
|
||||||
|
content: MessageContent::ToolCalls {
|
||||||
|
text: Some("calling".into()),
|
||||||
|
calls: vec![ToolCall {
|
||||||
|
id: "call_0".into(),
|
||||||
|
name: "read_file".into(),
|
||||||
|
arguments: r#"{"path":"/etc/hostname"}"#.into(),
|
||||||
|
}],
|
||||||
|
},
|
||||||
|
});
|
||||||
|
session.history.push(Message {
|
||||||
|
role: Role::Tool,
|
||||||
|
content: MessageContent::ToolResult {
|
||||||
|
tool_call_id: "call_0".into(),
|
||||||
|
content: "host".into(),
|
||||||
|
},
|
||||||
|
});
|
||||||
|
save_to_dir(&dir, &session).expect("save");
|
||||||
|
let loaded = load_from_dir(&dir, &SessionId::new("hxa-2")).expect("load");
|
||||||
|
assert_eq!(loaded.history.len(), 4);
|
||||||
|
match &loaded.history[2].content {
|
||||||
|
MessageContent::ToolCalls { calls, .. } => {
|
||||||
|
assert_eq!(calls[0].name, "read_file");
|
||||||
|
}
|
||||||
|
other => panic!("expected ToolCalls, got {other:?}"),
|
||||||
|
}
|
||||||
|
let _ = std::fs::remove_dir_all(&dir);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn list_filters_by_cwd_and_sorts_recent_first() {
|
||||||
|
let dir = unique_dir();
|
||||||
|
let mut a = sample("a");
|
||||||
|
a.cwd = PathBuf::from("/home/me/proj-x");
|
||||||
|
a.updated_at = 1_700_000_010;
|
||||||
|
let mut b = sample("b");
|
||||||
|
b.cwd = PathBuf::from("/home/me/proj-x");
|
||||||
|
b.updated_at = 1_700_000_020;
|
||||||
|
let mut c = sample("c");
|
||||||
|
c.cwd = PathBuf::from("/home/me/elsewhere");
|
||||||
|
c.updated_at = 1_700_000_030;
|
||||||
|
save_to_dir(&dir, &a).unwrap();
|
||||||
|
save_to_dir(&dir, &b).unwrap();
|
||||||
|
save_to_dir(&dir, &c).unwrap();
|
||||||
|
|
||||||
|
let proj_x = PathBuf::from("/home/me/proj-x");
|
||||||
|
let list = list_in_dir(&dir, Some(&proj_x)).unwrap();
|
||||||
|
let ids: Vec<&str> = list.iter().map(|s| s.session_id.as_str()).collect();
|
||||||
|
// Filtered to proj-x; b before a because b is more recent.
|
||||||
|
assert_eq!(ids, vec!["b", "a"]);
|
||||||
|
|
||||||
|
let all = list_in_dir(&dir, None).unwrap();
|
||||||
|
assert_eq!(all.len(), 3);
|
||||||
|
// Global list still sorted recent-first across all cwds.
|
||||||
|
assert_eq!(all[0].session_id, "c");
|
||||||
|
|
||||||
|
let _ = std::fs::remove_dir_all(&dir);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn list_returns_empty_for_missing_dir() {
|
||||||
|
let dir = unique_dir().join("does-not-exist");
|
||||||
|
let list = list_in_dir(&dir, None).unwrap();
|
||||||
|
assert!(list.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn list_skips_unparseable_files() {
|
||||||
|
let dir = unique_dir();
|
||||||
|
save_to_dir(&dir, &sample("good")).unwrap();
|
||||||
|
std::fs::write(dir.join("garbage.json"), b"{not valid json").unwrap();
|
||||||
|
let list = list_in_dir(&dir, None).unwrap();
|
||||||
|
// Garbage skipped; good survives.
|
||||||
|
assert_eq!(list.len(), 1);
|
||||||
|
assert_eq!(list[0].session_id, "good");
|
||||||
|
let _ = std::fs::remove_dir_all(&dir);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn iso8601_formats_unix_seconds() {
|
||||||
|
// 2024-01-01T00:00:00Z is 1704067200 unix seconds.
|
||||||
|
assert_eq!(
|
||||||
|
unix_to_iso8601(1_704_067_200),
|
||||||
|
Some("2024-01-01T00:00:00Z".into())
|
||||||
|
);
|
||||||
|
assert_eq!(unix_to_iso8601(0), Some("1970-01-01T00:00:00Z".into()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn sanitize_id_rejects_path_traversal() {
|
||||||
|
// `../../etc/passwd` — 6 non-alnum chars before "etc"
|
||||||
|
// (`.`, `.`, `/`, `.`, `.`, `/`), one between, none
|
||||||
|
// after, none before nothing. Every disallowed char
|
||||||
|
// collapses to `_`.
|
||||||
|
assert_eq!(sanitize_id("../../etc/passwd"), "______etc_passwd");
|
||||||
|
assert_eq!(sanitize_id("ok-name_42"), "ok-name_42");
|
||||||
|
}
|
||||||
|
}
|
||||||
1469
crates/helexa-acp/src/tool_runner.rs
Normal file
1469
crates/helexa-acp/src/tool_runner.rs
Normal file
File diff suppressed because it is too large
Load Diff
300
crates/helexa-acp/src/tools.rs
Normal file
300
crates/helexa-acp/src/tools.rs
Normal file
@@ -0,0 +1,300 @@
|
|||||||
|
//! Tool schemas sent to the upstream model on every completion.
|
||||||
|
//!
|
||||||
|
//! These are the OpenAI-function-style declarations the LLM sees in
|
||||||
|
//! `CompletionRequest.tools`; the runtime dispatch happens in
|
||||||
|
//! [`crate::tool_runner`]. Keeping declarations and execution in
|
||||||
|
//! separate modules makes it easy to add a tool without touching the
|
||||||
|
//! runner, and vice versa.
|
||||||
|
//!
|
||||||
|
//! Stage 3 ships five: filesystem read / write / edit, directory
|
||||||
|
//! listing, and `bash`. Image generation, web fetch, MCP-derived
|
||||||
|
//! tools, etc. are out of scope here.
|
||||||
|
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
use crate::provider::ToolSpec;
|
||||||
|
|
||||||
|
pub const READ_FILE: &str = "read_file";
|
||||||
|
pub const WRITE_FILE: &str = "write_file";
|
||||||
|
pub const EDIT_FILE: &str = "edit_file";
|
||||||
|
pub const LIST_DIR: &str = "list_dir";
|
||||||
|
pub const BASH: &str = "bash";
|
||||||
|
|
||||||
|
/// Build the static tool list passed to the model on every prompt.
|
||||||
|
/// Cheap — the JSON Schema fragments are constructed each call but
|
||||||
|
/// the bodies are small constants. If this ever shows up in a
|
||||||
|
/// profile we can `OnceLock` the Vec.
|
||||||
|
pub fn all_tools() -> Vec<ToolSpec> {
|
||||||
|
vec![
|
||||||
|
ToolSpec {
|
||||||
|
name: READ_FILE.to_string(),
|
||||||
|
description: "Read the contents of a text file. Returns the file's text.".to_string(),
|
||||||
|
parameters: json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Absolute path to the file."
|
||||||
|
},
|
||||||
|
"line": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Optional 1-based line number to start reading from.",
|
||||||
|
"minimum": 1
|
||||||
|
},
|
||||||
|
"limit": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Optional maximum number of lines to read.",
|
||||||
|
"minimum": 1
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["path"],
|
||||||
|
"additionalProperties": false
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
ToolSpec {
|
||||||
|
name: WRITE_FILE.to_string(),
|
||||||
|
description: "Write text content to a file, replacing any existing contents. \
|
||||||
|
Creates the file (and parent directories) if needed."
|
||||||
|
.to_string(),
|
||||||
|
parameters: json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Absolute path to the file."
|
||||||
|
},
|
||||||
|
"content": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Full new contents of the file."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["path", "content"],
|
||||||
|
"additionalProperties": false
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
ToolSpec {
|
||||||
|
name: EDIT_FILE.to_string(),
|
||||||
|
description: "Replace one exact substring in a file with another. \
|
||||||
|
Fails if `old_text` does not appear in the file, or appears more than once. \
|
||||||
|
Use multiple edit_file calls for multiple edits."
|
||||||
|
.to_string(),
|
||||||
|
parameters: json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Absolute path to the file."
|
||||||
|
},
|
||||||
|
"old_text": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Exact text fragment to replace. Must be unique within the file."
|
||||||
|
},
|
||||||
|
"new_text": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Replacement text."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["path", "old_text", "new_text"],
|
||||||
|
"additionalProperties": false
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
ToolSpec {
|
||||||
|
name: LIST_DIR.to_string(),
|
||||||
|
description:
|
||||||
|
"List the entries of a directory. Returns names and a (f|d|l) kind per entry."
|
||||||
|
.to_string(),
|
||||||
|
parameters: json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Absolute path to the directory."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["path"],
|
||||||
|
"additionalProperties": false
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
ToolSpec {
|
||||||
|
name: BASH.to_string(),
|
||||||
|
description: "Run a shell command via `sh -c`. \
|
||||||
|
Returns combined stdout+stderr and the exit status. \
|
||||||
|
The command runs in the session's working directory unless `cwd` is given."
|
||||||
|
.to_string(),
|
||||||
|
parameters: json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"command": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Shell command line, evaluated by `sh -c`."
|
||||||
|
},
|
||||||
|
"cwd": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Optional absolute path to run the command from."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["command"],
|
||||||
|
"additionalProperties": false
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Try to infer which tool was intended from the shape of an
|
||||||
|
/// `arguments` object alone. Used by the agent when the model
|
||||||
|
/// emits a `<tool_call>` whose JSON has the right arguments but a
|
||||||
|
/// missing or invalid top-level `name` field — a recurring
|
||||||
|
/// Qwen3.6-27B failure mode.
|
||||||
|
///
|
||||||
|
/// Returns `Some(name)` only when the argument keys uniquely match
|
||||||
|
/// exactly one tool in the catalogue. Ambiguous shapes (`{path}`
|
||||||
|
/// alone could be either [`READ_FILE`] or [`LIST_DIR`]) return
|
||||||
|
/// `None` so the caller surfaces a Failed-card and lets the model
|
||||||
|
/// retry rather than guessing wrong.
|
||||||
|
///
|
||||||
|
/// Inference table (key set → tool):
|
||||||
|
///
|
||||||
|
/// | Keys | Tool |
|
||||||
|
/// |---------------------------------------|--------------|
|
||||||
|
/// | `{command}` or `{command, cwd}` | `bash` |
|
||||||
|
/// | `{path, content}` | `write_file` |
|
||||||
|
/// | `{path, old_text, new_text}` | `edit_file` |
|
||||||
|
/// | `{path}` / `{path, line}` / `{path, line, limit}` | *ambiguous* — None |
|
||||||
|
/// | (anything else) | None |
|
||||||
|
pub fn infer_tool_name(arguments: &serde_json::Value) -> Option<&'static str> {
|
||||||
|
let obj = arguments.as_object()?;
|
||||||
|
let keys: std::collections::HashSet<&str> = obj.keys().map(|s| s.as_str()).collect();
|
||||||
|
|
||||||
|
// `command` is unique to bash. Allow the optional `cwd` arg
|
||||||
|
// alongside but nothing else (any unrecognised keys → bail and
|
||||||
|
// let the model retry rather than misroute).
|
||||||
|
if keys.contains("command") && keys.iter().all(|k| matches!(*k, "command" | "cwd")) {
|
||||||
|
return Some(BASH);
|
||||||
|
}
|
||||||
|
// `content` is unique to write_file.
|
||||||
|
if keys.contains("content") && keys.contains("path") && keys.len() == 2 {
|
||||||
|
return Some(WRITE_FILE);
|
||||||
|
}
|
||||||
|
// `old_text` + `new_text` are unique to edit_file.
|
||||||
|
if keys.contains("old_text")
|
||||||
|
&& keys.contains("new_text")
|
||||||
|
&& keys.contains("path")
|
||||||
|
&& keys.len() == 3
|
||||||
|
{
|
||||||
|
return Some(EDIT_FILE);
|
||||||
|
}
|
||||||
|
// `{path}` / `{path, line}` / `{path, line, limit}` overlap
|
||||||
|
// between read_file (file contents) and list_dir (directory
|
||||||
|
// contents). No safe inference — refuse.
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn all_tools_has_five_named_entries() {
|
||||||
|
let tools = all_tools();
|
||||||
|
let names: Vec<&str> = tools.iter().map(|t| t.name.as_str()).collect();
|
||||||
|
assert_eq!(
|
||||||
|
names,
|
||||||
|
vec![READ_FILE, WRITE_FILE, EDIT_FILE, LIST_DIR, BASH]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn infer_bash_from_command_only() {
|
||||||
|
let args = serde_json::json!({"command": "ls /tmp"});
|
||||||
|
assert_eq!(infer_tool_name(&args), Some(BASH));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn infer_bash_from_command_and_cwd() {
|
||||||
|
let args = serde_json::json!({"command": "ls", "cwd": "/tmp"});
|
||||||
|
assert_eq!(infer_tool_name(&args), Some(BASH));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn infer_bash_from_mkdir_like_real_failure() {
|
||||||
|
// Lifted verbatim from the agent failure that motivated
|
||||||
|
// this helper (helexa-acp.log @ 10:03:11).
|
||||||
|
let args = serde_json::json!({
|
||||||
|
"command": "mkdir -p /home/grenade/git/beat/beat/doc/plan/{01-discovery,02-segmentation,03-description,04-summary,05-output}"
|
||||||
|
});
|
||||||
|
assert_eq!(infer_tool_name(&args), Some(BASH));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn infer_write_file() {
|
||||||
|
let args = serde_json::json!({"path": "/tmp/x", "content": "hi"});
|
||||||
|
assert_eq!(infer_tool_name(&args), Some(WRITE_FILE));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn infer_edit_file() {
|
||||||
|
let args = serde_json::json!({
|
||||||
|
"path": "/tmp/x", "old_text": "a", "new_text": "b"
|
||||||
|
});
|
||||||
|
assert_eq!(infer_tool_name(&args), Some(EDIT_FILE));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn refuse_ambiguous_path_only() {
|
||||||
|
let args = serde_json::json!({"path": "/tmp/x"});
|
||||||
|
assert_eq!(infer_tool_name(&args), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn refuse_ambiguous_path_with_optionals() {
|
||||||
|
// read_file accepts these optionals; list_dir doesn't —
|
||||||
|
// but Qwen wouldn't reliably emit them either, so we
|
||||||
|
// can't use their presence to disambiguate. Refuse.
|
||||||
|
let args = serde_json::json!({"path": "/tmp/x", "line": 1, "limit": 50});
|
||||||
|
assert_eq!(infer_tool_name(&args), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn refuse_command_with_extra_unknown_keys() {
|
||||||
|
// Defence in depth: an unrecognised key alongside
|
||||||
|
// `command` means we don't really know what tool the
|
||||||
|
// model wanted; refuse rather than guess.
|
||||||
|
let args = serde_json::json!({"command": "ls", "extra": "?"});
|
||||||
|
assert_eq!(infer_tool_name(&args), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn refuse_empty_args() {
|
||||||
|
let args = serde_json::json!({});
|
||||||
|
assert_eq!(infer_tool_name(&args), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn refuse_non_object_args() {
|
||||||
|
let args = serde_json::json!("not an object");
|
||||||
|
assert_eq!(infer_tool_name(&args), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn every_tool_has_an_object_parameter_schema() {
|
||||||
|
for tool in all_tools() {
|
||||||
|
let ty = tool.parameters.get("type").and_then(|v| v.as_str());
|
||||||
|
assert_eq!(
|
||||||
|
ty,
|
||||||
|
Some("object"),
|
||||||
|
"tool {} parameters.type must be \"object\"",
|
||||||
|
tool.name
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
tool.parameters.get("properties").is_some(),
|
||||||
|
"tool {} missing properties",
|
||||||
|
tool.name
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
tool.parameters.get("required").is_some(),
|
||||||
|
"tool {} missing required list",
|
||||||
|
tool.name
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
107
crates/neuron/Cargo.toml
Normal file
107
crates/neuron/Cargo.toml
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
[package]
|
||||||
|
name = "neuron"
|
||||||
|
version.workspace = true
|
||||||
|
edition.workspace = true
|
||||||
|
license.workspace = true
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
name = "neuron"
|
||||||
|
path = "src/lib.rs"
|
||||||
|
|
||||||
|
[[bin]]
|
||||||
|
name = "neuron"
|
||||||
|
path = "src/main.rs"
|
||||||
|
|
||||||
|
[features]
|
||||||
|
default = []
|
||||||
|
# Enables CUDA acceleration in candle and the cudarc/nccl bindings the
|
||||||
|
# TP worker pool uses. Without this feature, candle compiles for CPU
|
||||||
|
# only, Device::new_cuda calls fall back to CPU, and TP Init/sanity
|
||||||
|
# requests return Error{kind="cuda_feature_not_enabled"}.
|
||||||
|
cuda = [
|
||||||
|
"candle-core/cuda",
|
||||||
|
"candle-core/nccl",
|
||||||
|
"candle-nn/cuda",
|
||||||
|
"candle-transformers/cuda",
|
||||||
|
"dep:cudarc",
|
||||||
|
"dep:half",
|
||||||
|
"dep:cudaforge",
|
||||||
|
]
|
||||||
|
# Use cuDNN for convolution / attention kernels. Requires CUDA.
|
||||||
|
cudnn = [
|
||||||
|
"cuda",
|
||||||
|
"candle-core/cudnn",
|
||||||
|
"candle-nn/cudnn",
|
||||||
|
"candle-transformers/cudnn",
|
||||||
|
]
|
||||||
|
# FlashAttention kernels. Requires CUDA.
|
||||||
|
flash-attn = [
|
||||||
|
"cuda",
|
||||||
|
"candle-transformers/flash-attn",
|
||||||
|
]
|
||||||
|
# Reserved for GPU-only integration tests in later stages.
|
||||||
|
cuda-integration = ["cuda"]
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
cortex-core.workspace = true
|
||||||
|
tokio.workspace = true
|
||||||
|
axum.workspace = true
|
||||||
|
serde.workspace = true
|
||||||
|
serde_json.workspace = true
|
||||||
|
reqwest.workspace = true
|
||||||
|
tracing.workspace = true
|
||||||
|
tracing-subscriber.workspace = true
|
||||||
|
anyhow.workspace = true
|
||||||
|
async-trait.workspace = true
|
||||||
|
clap.workspace = true
|
||||||
|
thiserror.workspace = true
|
||||||
|
futures.workspace = true
|
||||||
|
tokio-stream.workspace = true
|
||||||
|
figment.workspace = true
|
||||||
|
toml.workspace = true
|
||||||
|
|
||||||
|
# candle for in-process inference. CUDA support is gated behind the
|
||||||
|
# crate's `cuda` feature (default off) so the workspace builds on
|
||||||
|
# non-CUDA hosts and CI runners.
|
||||||
|
candle-core = "0.10.2"
|
||||||
|
candle-nn = "0.10.2"
|
||||||
|
candle-transformers = "0.10.2"
|
||||||
|
# Direct dep on cudarc (matching candle's transitive version) so the
|
||||||
|
# TP worker pool can call cudarc::nccl::{Comm, Id} directly. Gated on
|
||||||
|
# the `cuda` feature; same toolchain requirement as candle's CUDA path.
|
||||||
|
cudarc = { version = "0.19", optional = true, default-features = false, features = ["nccl", "cuda-version-from-build-system"] }
|
||||||
|
# Used by the AllReduce CustomOp1 to type-dispatch on bf16/f16 candle
|
||||||
|
# storages. Matches candle-core's pinned major version to avoid double-
|
||||||
|
# compiling the `half` crate at conflicting versions.
|
||||||
|
half = { version = "2.5", optional = true }
|
||||||
|
tokenizers = { version = "0.22", default-features = false, features = ["onig"] }
|
||||||
|
hf-hub = { version = "0.4", features = ["tokio"] }
|
||||||
|
# Jinja-compatible template renderer for the model's
|
||||||
|
# `tokenizer_config.json::chat_template`. Hugging Face's chat
|
||||||
|
# templates use a strict subset of Jinja2 that minijinja supports
|
||||||
|
# out of the box. ~80KB compiled; pure Rust, no async surface.
|
||||||
|
# Features: `builtins` for the `is defined` / `default` filters HF
|
||||||
|
# templates use; `json` for `tojson` (some Qwen3 templates emit
|
||||||
|
# tool definitions via tojson); `serde` so we can hand it a
|
||||||
|
# serde_json::Value as the context.
|
||||||
|
minijinja = { version = "2", features = ["builtins", "json", "serde"] }
|
||||||
|
# Direct dep on `safetensors` (re-exported by candle but its `TensorView`
|
||||||
|
# / `slice::IndexOp` types are public-but-not-re-exported). Used by the
|
||||||
|
# tp `fused_load` module to read per-rank slices of fused QKV tensors
|
||||||
|
# without materialising the full tensor on device.
|
||||||
|
safetensors = "0.7"
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
tokio = { workspace = true, features = ["test-util"] }
|
||||||
|
reqwest.workspace = true
|
||||||
|
tempfile = "3"
|
||||||
|
|
||||||
|
[build-dependencies]
|
||||||
|
# Used by `build.rs` to compile `src/cuda/*.cu` into `libneuroncuda.a`
|
||||||
|
# under the `cuda` feature. Matches mistralrs's upstream build setup
|
||||||
|
# (their `mistralrs-core/build.rs` uses the same constructor).
|
||||||
|
cudaforge = { version = "0.1", optional = true }
|
||||||
|
|
||||||
|
[package.metadata.docs.rs]
|
||||||
|
# Skip the CUDA path on docs.rs (it lacks nvcc).
|
||||||
|
no-default-features = true
|
||||||
66
crates/neuron/build.rs
Normal file
66
crates/neuron/build.rs
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
//! Build script: compile the CUDA kernels in `src/cuda/*.cu` into a
|
||||||
|
//! static library and link it under the `cuda` feature.
|
||||||
|
//!
|
||||||
|
//! Patterned on `EricLBuehler/mistral.rs::mistralrs-core/build.rs` —
|
||||||
|
//! same `cudaforge::KernelBuilder` invocation, same NVCC flag set.
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
{
|
||||||
|
use std::path::PathBuf;
|
||||||
|
println!("cargo:rerun-if-changed=build.rs");
|
||||||
|
println!("cargo:rerun-if-changed=src/cuda/");
|
||||||
|
|
||||||
|
let build_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap());
|
||||||
|
|
||||||
|
let mut builder = cudaforge::KernelBuilder::new()
|
||||||
|
.source_glob("src/cuda/*.cu")
|
||||||
|
.out_dir(&build_dir)
|
||||||
|
.arg("-std=c++17")
|
||||||
|
.arg("-O3")
|
||||||
|
.arg("-U__CUDA_NO_HALF_OPERATORS__")
|
||||||
|
.arg("-U__CUDA_NO_HALF_CONVERSIONS__")
|
||||||
|
.arg("-U__CUDA_NO_HALF2_OPERATORS__")
|
||||||
|
.arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
|
||||||
|
.arg("--expt-relaxed-constexpr")
|
||||||
|
.arg("--expt-extended-lambda")
|
||||||
|
.arg("--use_fast_math")
|
||||||
|
.arg("--compiler-options")
|
||||||
|
.arg("-fPIC");
|
||||||
|
|
||||||
|
// sm_<80 doesn't have bf16 intrinsics for WMMA — gate the
|
||||||
|
// bf16-only kernels off in that case. (Mirrors upstream.)
|
||||||
|
if let Some(compute_cap) = builder.get_compute_cap()
|
||||||
|
&& compute_cap < 80
|
||||||
|
{
|
||||||
|
builder = builder.arg("-DNO_BF16_KERNEL");
|
||||||
|
}
|
||||||
|
|
||||||
|
let target = std::env::var("TARGET").unwrap();
|
||||||
|
let out_file = if target.contains("msvc") {
|
||||||
|
build_dir.join("neuroncuda.lib")
|
||||||
|
} else {
|
||||||
|
build_dir.join("libneuroncuda.a")
|
||||||
|
};
|
||||||
|
|
||||||
|
builder
|
||||||
|
.build_lib(out_file)
|
||||||
|
.expect("neuron cuda build failed");
|
||||||
|
println!("cargo:rustc-link-search={}", build_dir.display());
|
||||||
|
println!("cargo:rustc-link-lib=neuroncuda");
|
||||||
|
println!("cargo:rustc-link-lib=dylib=cudart");
|
||||||
|
|
||||||
|
if target.contains("msvc") {
|
||||||
|
// No extra runtime library needed.
|
||||||
|
} else if target.contains("apple")
|
||||||
|
|| target.contains("freebsd")
|
||||||
|
|| target.contains("openbsd")
|
||||||
|
{
|
||||||
|
println!("cargo:rustc-link-lib=dylib=c++");
|
||||||
|
} else if target.contains("android") {
|
||||||
|
println!("cargo:rustc-link-lib=dylib=c++_shared");
|
||||||
|
} else {
|
||||||
|
println!("cargo:rustc-link-lib=dylib=stdc++");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
93
crates/neuron/src/activation.rs
Normal file
93
crates/neuron/src/activation.rs
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
//! Activation-time pre-warm progress tracking.
|
||||||
|
//!
|
||||||
|
//! Wraps the [`ActivationStatus`] snapshot in an async RwLock so the
|
||||||
|
//! background pre-warm task can update it per-model while the
|
||||||
|
//! `/health` handler reads coherent snapshots. The tracker exists
|
||||||
|
//! because `default_models` loading moved from synchronous-before-bind
|
||||||
|
//! to background-after-bind on 2026-05-26: the listener is up
|
||||||
|
//! immediately, but `/health` now needs to tell callers which of the
|
||||||
|
//! configured defaults are still warming.
|
||||||
|
|
||||||
|
use cortex_core::discovery::{ActivationState, ActivationStatus, PreWarmFailure};
|
||||||
|
use cortex_core::harness::ModelSpec;
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
|
||||||
|
/// Shared, async-safe handle to the daemon's activation progress.
|
||||||
|
///
|
||||||
|
/// Construct once in `main` with the configured `default_models` so
|
||||||
|
/// the initial `pending` list matches the spec; clone the `Arc` into
|
||||||
|
/// the `NeuronState` for HTTP handlers and into the spawned pre-warm
|
||||||
|
/// task for updates.
|
||||||
|
pub struct ActivationTracker {
|
||||||
|
inner: RwLock<ActivationStatus>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ActivationTracker {
|
||||||
|
/// Build a tracker primed with one entry per spec. An empty spec
|
||||||
|
/// list yields a `Ready` tracker — no point reporting PreWarming
|
||||||
|
/// when there's nothing queued.
|
||||||
|
pub fn new(default_models: &[ModelSpec]) -> Self {
|
||||||
|
let pending: Vec<String> = default_models.iter().map(|s| s.model_id.clone()).collect();
|
||||||
|
let state = if pending.is_empty() {
|
||||||
|
ActivationState::Ready
|
||||||
|
} else {
|
||||||
|
ActivationState::PreWarming
|
||||||
|
};
|
||||||
|
Self {
|
||||||
|
inner: RwLock::new(ActivationStatus {
|
||||||
|
state,
|
||||||
|
pending,
|
||||||
|
in_progress: None,
|
||||||
|
completed: vec![],
|
||||||
|
failed: vec![],
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Mark a model as in-progress: remove it from `pending`, set as
|
||||||
|
/// `in_progress`. Called immediately before `registry.load_model`.
|
||||||
|
pub async fn start_loading(&self, model_id: &str) {
|
||||||
|
let mut s = self.inner.write().await;
|
||||||
|
s.pending.retain(|m| m != model_id);
|
||||||
|
s.in_progress = Some(model_id.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Mark a model as completed: clear `in_progress` (if it matches),
|
||||||
|
/// append to `completed`.
|
||||||
|
pub async fn complete_loading(&self, model_id: &str) {
|
||||||
|
let mut s = self.inner.write().await;
|
||||||
|
if s.in_progress.as_deref() == Some(model_id) {
|
||||||
|
s.in_progress = None;
|
||||||
|
}
|
||||||
|
s.completed.push(model_id.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Mark a model as failed: clear `in_progress` (if it matches),
|
||||||
|
/// append a `PreWarmFailure` carrying the rendered error chain.
|
||||||
|
pub async fn fail_loading(&self, model_id: &str, error: &str) {
|
||||||
|
let mut s = self.inner.write().await;
|
||||||
|
if s.in_progress.as_deref() == Some(model_id) {
|
||||||
|
s.in_progress = None;
|
||||||
|
}
|
||||||
|
s.failed.push(PreWarmFailure {
|
||||||
|
model_id: model_id.to_string(),
|
||||||
|
error: error.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Flip the high-level `state` to `Ready` once the pre-warm task
|
||||||
|
/// is done iterating. Pending should be empty by this point; if a
|
||||||
|
/// caller bails early it's a stuck activation and the operator
|
||||||
|
/// will see entries in `pending` even with `state=ready` — that's
|
||||||
|
/// a useful diagnostic, not an inconsistency to scrub.
|
||||||
|
pub async fn mark_ready(&self) {
|
||||||
|
let mut s = self.inner.write().await;
|
||||||
|
s.state = ActivationState::Ready;
|
||||||
|
s.in_progress = None;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Cheap clone of the current state for the `/health` handler.
|
||||||
|
pub async fn snapshot(&self) -> ActivationStatus {
|
||||||
|
self.inner.read().await.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
483
crates/neuron/src/api.rs
Normal file
483
crates/neuron/src/api.rs
Normal file
@@ -0,0 +1,483 @@
|
|||||||
|
//! HTTP API handlers for the neuron daemon.
|
||||||
|
|
||||||
|
use crate::activation::ActivationTracker;
|
||||||
|
use crate::harness::HarnessRegistry;
|
||||||
|
use crate::harness::candle::{CandleHarness, InferenceError};
|
||||||
|
use crate::harness::preflight::PreflightError;
|
||||||
|
use crate::health::HealthCache;
|
||||||
|
use crate::wire::{openai_chat, openai_responses};
|
||||||
|
use axum::Router;
|
||||||
|
use axum::extract::{Path, State};
|
||||||
|
use axum::http::StatusCode;
|
||||||
|
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||||
|
use axum::response::{IntoResponse, Json};
|
||||||
|
use axum::routing::{get, post};
|
||||||
|
use cortex_core::discovery::{DiscoveryResponse, HealthResponse};
|
||||||
|
use cortex_core::harness::ModelSpec;
|
||||||
|
use cortex_core::openai::{ChatCompletionRequest, MessageContent};
|
||||||
|
use cortex_core::responses::{ResponsesRequest, ResponsesUsage};
|
||||||
|
use futures::stream::{self, StreamExt};
|
||||||
|
use serde_json::{Value, json};
|
||||||
|
use std::convert::Infallible;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
use tokio_stream::wrappers::ReceiverStream;
|
||||||
|
|
||||||
|
/// Shared state for the neuron HTTP server.
|
||||||
|
pub struct NeuronState {
|
||||||
|
pub discovery: DiscoveryResponse,
|
||||||
|
pub health_cache: Arc<HealthCache>,
|
||||||
|
pub registry: RwLock<HarnessRegistry>,
|
||||||
|
/// Typed handle to the candle harness for inference routes. Cached at
|
||||||
|
/// startup so `/v1/chat/completions` doesn't have to hold the registry
|
||||||
|
/// read lock or perform dyn-Trait dispatch per request.
|
||||||
|
pub candle: Option<Arc<CandleHarness>>,
|
||||||
|
/// Activation-time pre-warm progress. Updated by the background
|
||||||
|
/// `load_default_models` task, read by the `/health` handler.
|
||||||
|
pub activation: Arc<ActivationTracker>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build the neuron API router.
|
||||||
|
pub fn neuron_routes() -> Router<Arc<NeuronState>> {
|
||||||
|
Router::new()
|
||||||
|
.route("/discovery", get(discovery_handler))
|
||||||
|
.route("/health", get(health_handler))
|
||||||
|
.route("/models", get(list_models))
|
||||||
|
.route("/models/load", post(load_model))
|
||||||
|
.route("/models/unload", post(unload_model))
|
||||||
|
.route("/models/{model_id}/endpoint", get(model_endpoint))
|
||||||
|
.route("/v1/chat/completions", post(chat_completions))
|
||||||
|
.route("/v1/responses", post(responses))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn discovery_handler(State(state): State<Arc<NeuronState>>) -> Json<DiscoveryResponse> {
|
||||||
|
Json(state.discovery.clone())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn health_handler(State(state): State<Arc<NeuronState>>) -> Json<HealthResponse> {
|
||||||
|
// HealthCache owns the uptime + per-device readings; the activation
|
||||||
|
// tracker owns the pre-warm progress. We compose the response here
|
||||||
|
// so the cache stays a thin runtime-state cache and doesn't need to
|
||||||
|
// know about activation lifecycle.
|
||||||
|
let mut snapshot = state.health_cache.snapshot().await;
|
||||||
|
snapshot.activation = state.activation.snapshot().await;
|
||||||
|
Json(snapshot)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn list_models(State(state): State<Arc<NeuronState>>) -> impl IntoResponse {
|
||||||
|
let registry = state.registry.read().await;
|
||||||
|
match registry.list_all_models().await {
|
||||||
|
Ok(models) => Json(json!(models)).into_response(),
|
||||||
|
Err(e) => (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(json!({"error": format!("{e:#}")})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn load_model(
|
||||||
|
State(state): State<Arc<NeuronState>>,
|
||||||
|
Json(spec): Json<ModelSpec>,
|
||||||
|
) -> impl IntoResponse {
|
||||||
|
let registry = state.registry.read().await;
|
||||||
|
match registry.load_model(&spec).await {
|
||||||
|
Ok(()) => Json(json!({"status": "loaded"})).into_response(),
|
||||||
|
Err(e) => {
|
||||||
|
// If the underlying failure is a structured preflight
|
||||||
|
// rejection, surface it as 422 Unprocessable Entity with
|
||||||
|
// the typed JSON body. The kind/model_id/suggestion/etc.
|
||||||
|
// fields let cortex (and operators reading the response
|
||||||
|
// directly) act on the failure without parsing free text.
|
||||||
|
if let Some(pf) = e.downcast_ref::<PreflightError>() {
|
||||||
|
tracing::warn!(
|
||||||
|
model = %spec.model_id,
|
||||||
|
reason = preflight_kind(pf),
|
||||||
|
detail = %pf,
|
||||||
|
"load_model rejected by preflight"
|
||||||
|
);
|
||||||
|
return (
|
||||||
|
StatusCode::UNPROCESSABLE_ENTITY,
|
||||||
|
Json(json!({ "error": pf })),
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
|
}
|
||||||
|
// Log the full anyhow chain server-side so journalctl shows
|
||||||
|
// the underlying failure (hf-hub timeout, permission denied,
|
||||||
|
// disk full, etc.) without needing to inspect the HTTP
|
||||||
|
// response body separately.
|
||||||
|
tracing::warn!(
|
||||||
|
model = %spec.model_id,
|
||||||
|
error = %format!("{e:#}"),
|
||||||
|
"load_model failed"
|
||||||
|
);
|
||||||
|
(
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
Json(json!({"error": format!("{e:#}")})),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Short kebab-case tag for a preflight failure, used as a structured
|
||||||
|
/// log field for journalctl-side filtering. Mirrors the same helper in
|
||||||
|
/// `startup.rs`; duplicated to keep the module surfaces independent.
|
||||||
|
fn preflight_kind(err: &PreflightError) -> &'static str {
|
||||||
|
match err {
|
||||||
|
PreflightError::RepoFetchFailed { .. } => "repo_fetch_failed",
|
||||||
|
PreflightError::EmptyRepo { .. } => "empty_repo",
|
||||||
|
PreflightError::TpRequiresSafetensors { .. } => "tp_requires_safetensors",
|
||||||
|
PreflightError::QuantNotFound { .. } => "quant_not_found",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn unload_model(
|
||||||
|
State(state): State<Arc<NeuronState>>,
|
||||||
|
Json(body): Json<Value>,
|
||||||
|
) -> impl IntoResponse {
|
||||||
|
let model_id = match body.get("model_id").and_then(|v| v.as_str()) {
|
||||||
|
Some(id) => id.to_string(),
|
||||||
|
None => {
|
||||||
|
return (
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
Json(json!({"error": "missing model_id"})),
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let registry = state.registry.read().await;
|
||||||
|
match registry.unload_model(&model_id).await {
|
||||||
|
Ok(()) => Json(json!({"status": "unloaded"})).into_response(),
|
||||||
|
Err(e) => (
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
Json(json!({"error": format!("{e:#}")})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn model_endpoint(
|
||||||
|
State(state): State<Arc<NeuronState>>,
|
||||||
|
Path(model_id): Path<String>,
|
||||||
|
) -> impl IntoResponse {
|
||||||
|
let registry = state.registry.read().await;
|
||||||
|
match registry.inference_endpoint(&model_id).await {
|
||||||
|
Some(url) => Json(json!({"url": url})).into_response(),
|
||||||
|
None => (
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
Json(json!({"error": format!("model '{}' not loaded", model_id)})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// OpenAI-compatible chat completions. Dispatches to streaming SSE when
|
||||||
|
/// `stream: true` is set on the request; otherwise returns a single
|
||||||
|
/// `ChatCompletionResponse`.
|
||||||
|
async fn chat_completions(
|
||||||
|
State(state): State<Arc<NeuronState>>,
|
||||||
|
headers: axum::http::HeaderMap,
|
||||||
|
Json(req): Json<ChatCompletionRequest>,
|
||||||
|
) -> impl IntoResponse {
|
||||||
|
let Some(candle) = state.candle.as_ref().map(Arc::clone) else {
|
||||||
|
return (
|
||||||
|
StatusCode::SERVICE_UNAVAILABLE,
|
||||||
|
Json(json!({"error": "candle harness not enabled on this neuron"})),
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
|
};
|
||||||
|
|
||||||
|
// Reasoning-content opt-in. Off by default → naïve clients
|
||||||
|
// (Zed's commit-message generator, vanilla OpenAI clients)
|
||||||
|
// never see `<think>` blocks. On when the caller sends
|
||||||
|
// `x-include-thinking: true` (helexa-acp does this so its
|
||||||
|
// own ThinkParser keeps working unchanged).
|
||||||
|
let include_thinking = headers
|
||||||
|
.get("x-include-thinking")
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.map(|s| matches!(s.trim().to_ascii_lowercase().as_str(), "1" | "true" | "yes"))
|
||||||
|
.unwrap_or(false);
|
||||||
|
let chat_config = openai_chat::ChatProjectionConfig {
|
||||||
|
include_thinking,
|
||||||
|
reasoning_markers: None, // filled in from the loaded model inside candle
|
||||||
|
};
|
||||||
|
|
||||||
|
if req.stream.unwrap_or(false) {
|
||||||
|
match candle.chat_completion_stream_with(req, chat_config).await {
|
||||||
|
Ok(rx) => {
|
||||||
|
// Each chunk → one SSE `data: {json}` line. After the
|
||||||
|
// channel closes, append the OpenAI [DONE] terminator.
|
||||||
|
let body_stream = ReceiverStream::new(rx).map(|chunk| {
|
||||||
|
let body = serde_json::to_string(&chunk).unwrap_or_default();
|
||||||
|
Ok::<_, Infallible>(Event::default().data(body))
|
||||||
|
});
|
||||||
|
let done_stream =
|
||||||
|
stream::once(async { Ok::<_, Infallible>(Event::default().data("[DONE]")) });
|
||||||
|
Sse::new(body_stream.chain(done_stream))
|
||||||
|
.keep_alive(KeepAlive::default())
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
Err(InferenceError::ModelNotLoaded(id)) => (
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
Json(json!({"error": format!("model '{id}' not loaded on this neuron")})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
Err(InferenceError::PromptTooLong { prompt_len, max }) => (
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
Json(json!({
|
||||||
|
"error": format!("prompt has {prompt_len} tokens but max is {max}"),
|
||||||
|
"code": "prompt_too_long",
|
||||||
|
"prompt_len": prompt_len,
|
||||||
|
"max": max,
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
Err(InferenceError::InsufficientVram {
|
||||||
|
free_mb,
|
||||||
|
required_mb,
|
||||||
|
}) => (
|
||||||
|
StatusCode::SERVICE_UNAVAILABLE,
|
||||||
|
Json(json!({
|
||||||
|
"error": format!(
|
||||||
|
"insufficient free VRAM: {free_mb} MiB free, need at least {required_mb} MiB"
|
||||||
|
),
|
||||||
|
"code": "insufficient_vram",
|
||||||
|
"free_mb": free_mb,
|
||||||
|
"required_mb": required_mb,
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
Err(InferenceError::Other(e)) => (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(json!({"error": format!("{e:#}")})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
match candle.chat_completion(req).await {
|
||||||
|
Ok(resp) => Json(resp).into_response(),
|
||||||
|
Err(InferenceError::ModelNotLoaded(id)) => (
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
Json(json!({"error": format!("model '{id}' not loaded on this neuron")})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
Err(InferenceError::PromptTooLong { prompt_len, max }) => (
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
Json(json!({
|
||||||
|
"error": format!("prompt has {prompt_len} tokens but max is {max}"),
|
||||||
|
"code": "prompt_too_long",
|
||||||
|
"prompt_len": prompt_len,
|
||||||
|
"max": max,
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
Err(InferenceError::InsufficientVram {
|
||||||
|
free_mb,
|
||||||
|
required_mb,
|
||||||
|
}) => (
|
||||||
|
StatusCode::SERVICE_UNAVAILABLE,
|
||||||
|
Json(json!({
|
||||||
|
"error": format!(
|
||||||
|
"insufficient free VRAM: {free_mb} MiB free, need at least {required_mb} MiB"
|
||||||
|
),
|
||||||
|
"code": "insufficient_vram",
|
||||||
|
"free_mb": free_mb,
|
||||||
|
"required_mb": required_mb,
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
Err(InferenceError::Other(e)) => (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(json!({"error": format!("{e:#}")})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// OpenAI Responses API (`POST /v1/responses`). Translates the
|
||||||
|
/// Responses-shaped request into a chat-completions one the candle
|
||||||
|
/// harness already understands, then re-projects the harness's
|
||||||
|
/// event stream into the Responses event family.
|
||||||
|
async fn responses(
|
||||||
|
State(state): State<Arc<NeuronState>>,
|
||||||
|
Json(req): Json<ResponsesRequest>,
|
||||||
|
) -> impl IntoResponse {
|
||||||
|
let Some(candle) = state.candle.as_ref().map(Arc::clone) else {
|
||||||
|
return (
|
||||||
|
StatusCode::SERVICE_UNAVAILABLE,
|
||||||
|
Json(json!({"error": "candle harness not enabled on this neuron"})),
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
|
};
|
||||||
|
|
||||||
|
let stream_requested = req.stream;
|
||||||
|
let model_id = req.model.clone();
|
||||||
|
let response_id = mint_response_id();
|
||||||
|
let message_item_id = mint_message_item_id();
|
||||||
|
|
||||||
|
// Translate Responses → chat completions. The only failure
|
||||||
|
// mode today is `previous_response_id` set, which we reject
|
||||||
|
// with 400 — stateful conversations need a persistence layer
|
||||||
|
// we haven't built.
|
||||||
|
let mut chat_req = match openai_responses::request_to_chat(req) {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(openai_responses::TranslateError::ChainedConversationNotSupported) => {
|
||||||
|
return (
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
Json(json!({
|
||||||
|
"error": "previous_response_id is not supported on this neuron",
|
||||||
|
"code": "chained_conversation_not_supported"
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
chat_req.stream = Some(stream_requested);
|
||||||
|
|
||||||
|
if stream_requested {
|
||||||
|
match candle
|
||||||
|
.responses_stream(chat_req, response_id, message_item_id)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(rx) => {
|
||||||
|
// Each ResponseStreamFrame → one SSE event carrying
|
||||||
|
// both an event name and JSON data. The Responses
|
||||||
|
// API doesn't use a `[DONE]` terminator — clients
|
||||||
|
// see the `response.completed` event as the end of
|
||||||
|
// the stream.
|
||||||
|
let body_stream = ReceiverStream::new(rx).map(|frame| {
|
||||||
|
let body = serde_json::to_string(&frame.data).unwrap_or_else(|_| "{}".into());
|
||||||
|
Ok::<_, Infallible>(Event::default().event(frame.event_name).data(body))
|
||||||
|
});
|
||||||
|
Sse::new(body_stream)
|
||||||
|
.keep_alive(KeepAlive::default())
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
Err(e) => inference_error_response(e),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Non-streaming: drive the existing chat completion path
|
||||||
|
// and translate the result. We don't currently re-tokenise
|
||||||
|
// to compute usage; the harness returns it via the chat
|
||||||
|
// response and we pass it through.
|
||||||
|
match candle.chat_completion(chat_req).await {
|
||||||
|
Ok(chat_resp) => {
|
||||||
|
// Extract the assistant text (chat completions
|
||||||
|
// always emits one choice on the candle path).
|
||||||
|
let text = chat_resp
|
||||||
|
.choices
|
||||||
|
.first()
|
||||||
|
.map(|c| match &c.message.content {
|
||||||
|
MessageContent::Text(t) => t.clone(),
|
||||||
|
MessageContent::Parts(_) => {
|
||||||
|
// Candle output is always text today;
|
||||||
|
// a Parts response would be surprising.
|
||||||
|
// Empty-string fallback is safer than
|
||||||
|
// a panic.
|
||||||
|
String::new()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.unwrap_or_default();
|
||||||
|
let finish = chat_resp
|
||||||
|
.choices
|
||||||
|
.first()
|
||||||
|
.and_then(|c| c.finish_reason.as_deref())
|
||||||
|
.map(finish_reason_from_str)
|
||||||
|
.unwrap_or(crate::wire::FinishReason::Stop);
|
||||||
|
let usage = chat_resp.usage.as_ref().map(|u| ResponsesUsage {
|
||||||
|
input_tokens: u.prompt_tokens,
|
||||||
|
output_tokens: u.completion_tokens,
|
||||||
|
total_tokens: u.prompt_tokens + u.completion_tokens,
|
||||||
|
});
|
||||||
|
let meta = openai_responses::ResponseMeta {
|
||||||
|
response_id: mint_response_id(),
|
||||||
|
created_at: unix_now_secs(),
|
||||||
|
model_id,
|
||||||
|
message_item_id: mint_message_item_id(),
|
||||||
|
};
|
||||||
|
let _ = chat_resp; // make the borrow-checker happy if `text` consumed it
|
||||||
|
let resp = openai_responses::build_response(&meta, text, finish, usage);
|
||||||
|
Json(resp).into_response()
|
||||||
|
}
|
||||||
|
Err(e) => inference_error_response(e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn finish_reason_from_str(s: &str) -> crate::wire::FinishReason {
|
||||||
|
use crate::wire::FinishReason;
|
||||||
|
match s {
|
||||||
|
"length" => FinishReason::Length,
|
||||||
|
"tool_calls" => FinishReason::ToolCalls,
|
||||||
|
_ => FinishReason::Stop,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Centralised mapping from [`InferenceError`] to an HTTP response.
|
||||||
|
/// Lifted out so the chat-completions and responses handlers stay
|
||||||
|
/// readable and changes to error-code semantics happen in one spot.
|
||||||
|
fn inference_error_response(err: InferenceError) -> axum::response::Response {
|
||||||
|
match err {
|
||||||
|
InferenceError::ModelNotLoaded(id) => (
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
Json(json!({"error": format!("model '{id}' not loaded on this neuron")})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
InferenceError::PromptTooLong { prompt_len, max } => (
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
Json(json!({
|
||||||
|
"error": format!("prompt has {prompt_len} tokens but max is {max}"),
|
||||||
|
"code": "prompt_too_long",
|
||||||
|
"prompt_len": prompt_len,
|
||||||
|
"max": max,
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
InferenceError::InsufficientVram {
|
||||||
|
free_mb,
|
||||||
|
required_mb,
|
||||||
|
} => (
|
||||||
|
StatusCode::SERVICE_UNAVAILABLE,
|
||||||
|
Json(json!({
|
||||||
|
"error": format!(
|
||||||
|
"insufficient free VRAM: {free_mb} MiB free, need at least {required_mb} MiB"
|
||||||
|
),
|
||||||
|
"code": "insufficient_vram",
|
||||||
|
"free_mb": free_mb,
|
||||||
|
"required_mb": required_mb,
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
InferenceError::Other(e) => (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(json!({"error": format!("{e:#}")})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn mint_response_id() -> String {
|
||||||
|
format!("resp_{:x}", unix_subsec_nanos())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn mint_message_item_id() -> String {
|
||||||
|
format!("msg_{:x}", unix_subsec_nanos())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unix_now_secs() -> u64 {
|
||||||
|
SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.map(|d| d.as_secs())
|
||||||
|
.unwrap_or(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unix_subsec_nanos() -> u64 {
|
||||||
|
SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.map(|d| d.as_nanos() as u64)
|
||||||
|
.unwrap_or(0)
|
||||||
|
}
|
||||||
67
crates/neuron/src/config.rs
Normal file
67
crates/neuron/src/config.rs
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
//! Neuron configuration loaded from neuron.toml.
|
||||||
|
|
||||||
|
use cortex_core::harness::{HarnessConfig, ModelSpec};
|
||||||
|
use figment::{
|
||||||
|
Figment,
|
||||||
|
providers::{Env, Format, Toml},
|
||||||
|
};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct NeuronConfig {
|
||||||
|
#[serde(default = "default_port")]
|
||||||
|
pub port: u16,
|
||||||
|
#[serde(default)]
|
||||||
|
pub harnesses: Vec<HarnessConfig>,
|
||||||
|
/// Per-harness configuration. Currently only `candle` is recognised.
|
||||||
|
#[serde(default)]
|
||||||
|
pub harness: HarnessSettings,
|
||||||
|
/// Models to auto-load when the neuron service activates. Each entry
|
||||||
|
/// is loaded sequentially before the HTTP listener binds. A failure
|
||||||
|
/// on any single entry logs a warning and proceeds — broken entries
|
||||||
|
/// don't prevent the rest of the fleet from starting.
|
||||||
|
#[serde(default)]
|
||||||
|
pub default_models: Vec<ModelSpec>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Settings for individual harness implementations. Each harness owns
|
||||||
|
/// its own sub-table so users only configure the harnesses they enable.
|
||||||
|
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||||
|
pub struct HarnessSettings {
|
||||||
|
#[serde(default)]
|
||||||
|
pub candle: CandleHarnessConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||||
|
pub struct CandleHarnessConfig {
|
||||||
|
/// HuggingFace cache directory for model weights.
|
||||||
|
/// When unset, defers to hf-hub's default (~/.cache/huggingface).
|
||||||
|
#[serde(default)]
|
||||||
|
pub hf_cache: Option<PathBuf>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_port() -> u16 {
|
||||||
|
13131
|
||||||
|
}
|
||||||
|
|
||||||
|
impl NeuronConfig {
|
||||||
|
pub fn load(path: impl AsRef<Path>) -> Result<Self, Box<figment::Error>> {
|
||||||
|
Figment::new()
|
||||||
|
.merge(Toml::file(path))
|
||||||
|
.merge(Env::prefixed("NEURON_").split("__"))
|
||||||
|
.extract()
|
||||||
|
.map_err(Box::new)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for NeuronConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
port: 13131,
|
||||||
|
harnesses: vec![],
|
||||||
|
harness: HarnessSettings::default(),
|
||||||
|
default_models: vec![],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
84
crates/neuron/src/cuda/ffi.rs
Normal file
84
crates/neuron/src/cuda/ffi.rs
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
//! FFI declarations for the CUDA kernels in `gdn.cu`.
|
||||||
|
//!
|
||||||
|
//! Subset of `EricLBuehler/mistral.rs::mistralrs-core/src/cuda/ffi.rs`
|
||||||
|
//! covering only the Gated DeltaNet kernels we currently use. Other
|
||||||
|
//! kernels in the upstream file (MoE GEMM, top-k, Mamba selective
|
||||||
|
//! scan, etc.) would land here too as we absorb them.
|
||||||
|
//!
|
||||||
|
//! All function declarations are MIT-licensed from upstream and
|
||||||
|
//! unchanged apart from this header.
|
||||||
|
|
||||||
|
use std::ffi::c_void;
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
unsafe extern "C" {
|
||||||
|
// GDN (Gated Delta Net) kernels for qwen3_5 / Qwen3-Next.
|
||||||
|
pub(crate) fn gated_delta_rule_recurrence(
|
||||||
|
q: *const f32,
|
||||||
|
k: *const f32,
|
||||||
|
v: *const f32,
|
||||||
|
g: *const f32,
|
||||||
|
beta: *const f32,
|
||||||
|
state: *mut f32,
|
||||||
|
output: *mut f32,
|
||||||
|
bh: i32,
|
||||||
|
seq_len: i32,
|
||||||
|
k_dim: i32,
|
||||||
|
v_dim: i32,
|
||||||
|
stream: i64,
|
||||||
|
);
|
||||||
|
|
||||||
|
/// Chunked GDN recurrence for prefill (processes tokens in BT=64 chunks).
|
||||||
|
pub(crate) fn chunked_gated_delta_rule_recurrence(
|
||||||
|
q: *const f32,
|
||||||
|
k: *const f32,
|
||||||
|
v: *const f32,
|
||||||
|
g: *const f32,
|
||||||
|
beta: *const f32,
|
||||||
|
state: *mut f32,
|
||||||
|
output: *mut f32,
|
||||||
|
bh: i32,
|
||||||
|
seq_len: i32,
|
||||||
|
k_dim: i32,
|
||||||
|
v_dim: i32,
|
||||||
|
stream: i64,
|
||||||
|
);
|
||||||
|
|
||||||
|
pub(crate) fn causal_conv1d_update(
|
||||||
|
x: *const c_void,
|
||||||
|
weight: *const c_void,
|
||||||
|
conv_state: *mut c_void,
|
||||||
|
output: *mut c_void,
|
||||||
|
batch_size: i32,
|
||||||
|
conv_dim: i32,
|
||||||
|
kernel_size: i32,
|
||||||
|
dtype: i32,
|
||||||
|
stream: i64,
|
||||||
|
);
|
||||||
|
|
||||||
|
pub(crate) fn causal_conv1d_full(
|
||||||
|
x: *const c_void,
|
||||||
|
weight: *const c_void,
|
||||||
|
conv_state_out: *mut c_void,
|
||||||
|
output: *mut c_void,
|
||||||
|
batch_size: i32,
|
||||||
|
conv_dim: i32,
|
||||||
|
seq_len: i32,
|
||||||
|
kernel_size: i32,
|
||||||
|
dtype: i32,
|
||||||
|
stream: i64,
|
||||||
|
);
|
||||||
|
|
||||||
|
pub(crate) fn fused_gdn_gating(
|
||||||
|
b: *const c_void,
|
||||||
|
a: *const c_void,
|
||||||
|
a_log: *const f32,
|
||||||
|
dt_bias: *const f32,
|
||||||
|
beta_out: *mut c_void,
|
||||||
|
g_out: *mut c_void,
|
||||||
|
total_elements: i32,
|
||||||
|
num_heads: i32,
|
||||||
|
dtype: i32,
|
||||||
|
stream: i64,
|
||||||
|
);
|
||||||
|
}
|
||||||
711
crates/neuron/src/cuda/gdn.cu
Normal file
711
crates/neuron/src/cuda/gdn.cu
Normal file
@@ -0,0 +1,711 @@
|
|||||||
|
// Gated DeltaNet CUDA kernels for Qwen3-Next (`model_type = "qwen3_5"`).
|
||||||
|
//
|
||||||
|
// Ported verbatim from `EricLBuehler/mistral.rs` under MIT terms.
|
||||||
|
// Upstream path: mistralrs-core/src/cuda/gdn.cu. Local edits in this
|
||||||
|
// file are limited to this banner; the kernels are unchanged so a
|
||||||
|
// diff against upstream stays minimal.
|
||||||
|
//
|
||||||
|
// Five kernels exposed via `extern "C"` shims at the bottom:
|
||||||
|
// - gated_delta_rule_recurrence (per-token decode)
|
||||||
|
// - chunked_gated_delta_rule_recurrence (BT=64 chunked prefill)
|
||||||
|
// - causal_conv1d_update (single-token conv decode)
|
||||||
|
// - causal_conv1d_full (multi-token conv prefill)
|
||||||
|
// - fused_gdn_gating (beta = sigmoid(b);
|
||||||
|
// g = -exp(A_log) * softplus(a + dt_bias))
|
||||||
|
|
||||||
|
#include "cuda_bf16.h"
|
||||||
|
#include "cuda_fp16.h"
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Kernel 1: gated_delta_rule_recurrence (optimized)
|
||||||
|
//
|
||||||
|
// V-tiled recurrence with compile-time K dimension for register residency.
|
||||||
|
// Grid: (ceil(V/BV), B*H), Block: (BV,). Each thread owns BK registers of
|
||||||
|
// state. Shared memory holds k_buf and q_buf (2*BK floats).
|
||||||
|
//
|
||||||
|
// Optimizations over naive version:
|
||||||
|
// - Template BK -> float s[BK] lives in true registers (1 cycle vs ~30)
|
||||||
|
// - #pragma unroll on all k-loops -> full ILP
|
||||||
|
// - Fused decay+kv_mem pass and fused state_update+output pass
|
||||||
|
// - __fmaf_rn intrinsics for guaranteed fused multiply-add
|
||||||
|
// - BV=64 threads -> 2 warps, 6 blocks/SM on Ampere
|
||||||
|
//
|
||||||
|
// q,k: [BH, S, K] v: [BH, S, V] g,beta: [BH, S]
|
||||||
|
// state: [BH, K, V] (in/out) output: [BH, S, V]
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
// Optimized kernel: BK known at compile time -> registers + full unrolling
|
||||||
|
template <int BK, int BV>
|
||||||
|
__global__ void gated_delta_rule_recurrence_kernel_tiled(
|
||||||
|
const float *__restrict__ q, // [BH, S, K]
|
||||||
|
const float *__restrict__ k, // [BH, S, K]
|
||||||
|
const float *__restrict__ v, // [BH, S, V]
|
||||||
|
const float *__restrict__ g, // [BH, S]
|
||||||
|
const float *__restrict__ beta, // [BH, S]
|
||||||
|
float *__restrict__ state, // [BH, K, V]
|
||||||
|
float *__restrict__ output, // [BH, S, V]
|
||||||
|
int seq_len, int v_dim) {
|
||||||
|
|
||||||
|
const int v_tile = blockIdx.x; // which V-tile
|
||||||
|
const int bh = blockIdx.y; // batch*head index
|
||||||
|
const int tid = threadIdx.x; // thread within tile [0, BV)
|
||||||
|
const int v_idx = v_tile * BV + tid; // global V index
|
||||||
|
|
||||||
|
if (v_idx >= v_dim)
|
||||||
|
return;
|
||||||
|
|
||||||
|
// Pointers for this (batch, head)
|
||||||
|
const float *q_bh = q + bh * seq_len * BK;
|
||||||
|
const float *k_bh = k + bh * seq_len * BK;
|
||||||
|
const float *v_bh = v + bh * seq_len * v_dim;
|
||||||
|
const float *g_bh = g + bh * seq_len;
|
||||||
|
const float *beta_bh = beta + bh * seq_len;
|
||||||
|
float *state_bh = state + bh * BK * v_dim;
|
||||||
|
float *out_bh = output + bh * seq_len * v_dim;
|
||||||
|
|
||||||
|
// Shared memory: k_buf[BK] + q_buf[BK]
|
||||||
|
__shared__ float k_buf[BK];
|
||||||
|
__shared__ float q_buf[BK];
|
||||||
|
|
||||||
|
// Load state column into registers — BK is compile-time, so this is
|
||||||
|
// a true register array (not spilled to local memory)
|
||||||
|
float s[BK];
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < BK; j++) {
|
||||||
|
s[j] = state_bh[j * v_dim + v_idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int t = 0; t < seq_len; t++) {
|
||||||
|
// Collaboratively load k_t into shared memory
|
||||||
|
// BK / BV loads per thread (e.g. 128/64 = 2)
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = tid; j < BK; j += BV) {
|
||||||
|
k_buf[j] = k_bh[t * BK + j];
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// Load scalars for this timestep
|
||||||
|
float decay = expf(g_bh[t]);
|
||||||
|
float beta_t = beta_bh[t];
|
||||||
|
float v_t = v_bh[t * v_dim + v_idx];
|
||||||
|
|
||||||
|
// Fused pass 1: decay state + compute kv_mem
|
||||||
|
float kv_mem = 0.0f;
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < BK; j++) {
|
||||||
|
s[j] *= decay;
|
||||||
|
kv_mem = __fmaf_rn(s[j], k_buf[j], kv_mem);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delta rule
|
||||||
|
float delta = (v_t - kv_mem) * beta_t;
|
||||||
|
|
||||||
|
// Collaboratively load q_t into shared memory
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = tid; j < BK; j += BV) {
|
||||||
|
q_buf[j] = q_bh[t * BK + j];
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// Fused pass 2: update state + compute output
|
||||||
|
float y_t = 0.0f;
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < BK; j++) {
|
||||||
|
s[j] = __fmaf_rn(k_buf[j], delta, s[j]);
|
||||||
|
y_t = __fmaf_rn(s[j], q_buf[j], y_t);
|
||||||
|
}
|
||||||
|
|
||||||
|
out_bh[t * v_dim + v_idx] = y_t;
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write state back
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < BK; j++) {
|
||||||
|
state_bh[j * v_dim + v_idx] = s[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback kernel: runtime k_dim, still V-tiled for occupancy
|
||||||
|
template <int BV, int MAX_K>
|
||||||
|
__global__ void gated_delta_rule_recurrence_kernel_fallback(
|
||||||
|
const float *__restrict__ q, const float *__restrict__ k,
|
||||||
|
const float *__restrict__ v, const float *__restrict__ g,
|
||||||
|
const float *__restrict__ beta, float *__restrict__ state,
|
||||||
|
float *__restrict__ output, int seq_len, int k_dim, int v_dim) {
|
||||||
|
|
||||||
|
const int v_tile = blockIdx.x;
|
||||||
|
const int bh = blockIdx.y;
|
||||||
|
const int tid = threadIdx.x;
|
||||||
|
const int v_idx = v_tile * BV + tid;
|
||||||
|
|
||||||
|
if (v_idx >= v_dim)
|
||||||
|
return;
|
||||||
|
|
||||||
|
const float *q_bh = q + bh * seq_len * k_dim;
|
||||||
|
const float *k_bh = k + bh * seq_len * k_dim;
|
||||||
|
const float *v_bh = v + bh * seq_len * v_dim;
|
||||||
|
const float *g_bh = g + bh * seq_len;
|
||||||
|
const float *beta_bh = beta + bh * seq_len;
|
||||||
|
float *state_bh = state + bh * k_dim * v_dim;
|
||||||
|
float *out_bh = output + bh * seq_len * v_dim;
|
||||||
|
|
||||||
|
extern __shared__ float shared[];
|
||||||
|
float *k_buf = shared;
|
||||||
|
float *q_buf = shared + k_dim;
|
||||||
|
|
||||||
|
float s[MAX_K];
|
||||||
|
for (int j = 0; j < k_dim; j++) {
|
||||||
|
s[j] = state_bh[j * v_dim + v_idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int t = 0; t < seq_len; t++) {
|
||||||
|
for (int j = tid; j < k_dim; j += BV) {
|
||||||
|
k_buf[j] = k_bh[t * k_dim + j];
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
float decay = expf(g_bh[t]);
|
||||||
|
float beta_t = beta_bh[t];
|
||||||
|
float v_t = v_bh[t * v_dim + v_idx];
|
||||||
|
|
||||||
|
float kv_mem = 0.0f;
|
||||||
|
for (int j = 0; j < k_dim; j++) {
|
||||||
|
s[j] *= decay;
|
||||||
|
kv_mem = __fmaf_rn(s[j], k_buf[j], kv_mem);
|
||||||
|
}
|
||||||
|
|
||||||
|
float delta = (v_t - kv_mem) * beta_t;
|
||||||
|
|
||||||
|
for (int j = tid; j < k_dim; j += BV) {
|
||||||
|
q_buf[j] = q_bh[t * k_dim + j];
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
float y_t = 0.0f;
|
||||||
|
for (int j = 0; j < k_dim; j++) {
|
||||||
|
s[j] = __fmaf_rn(k_buf[j], delta, s[j]);
|
||||||
|
y_t = __fmaf_rn(s[j], q_buf[j], y_t);
|
||||||
|
}
|
||||||
|
|
||||||
|
out_bh[t * v_dim + v_idx] = y_t;
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int j = 0; j < k_dim; j++) {
|
||||||
|
state_bh[j * v_dim + v_idx] = s[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" void gated_delta_rule_recurrence(const float *q, const float *k,
|
||||||
|
const float *v, const float *g,
|
||||||
|
const float *beta, float *state,
|
||||||
|
float *output, int bh, int seq_len,
|
||||||
|
int k_dim, int v_dim,
|
||||||
|
int64_t stream) {
|
||||||
|
|
||||||
|
const cudaStream_t custream = (cudaStream_t)stream;
|
||||||
|
|
||||||
|
if (k_dim == 128) {
|
||||||
|
// Fast path for Qwen3-Next (k_dim=128)
|
||||||
|
constexpr int BK = 128;
|
||||||
|
constexpr int BV = 64;
|
||||||
|
dim3 grid((v_dim + BV - 1) / BV, bh);
|
||||||
|
dim3 block(BV);
|
||||||
|
gated_delta_rule_recurrence_kernel_tiled<BK, BV>
|
||||||
|
<<<grid, block, 0, custream>>>(q, k, v, g, beta, state, output, seq_len,
|
||||||
|
v_dim);
|
||||||
|
} else if (k_dim == 64) {
|
||||||
|
// Fast path for models with k_dim=64
|
||||||
|
constexpr int BK = 64;
|
||||||
|
constexpr int BV = 64;
|
||||||
|
dim3 grid((v_dim + BV - 1) / BV, bh);
|
||||||
|
dim3 block(BV);
|
||||||
|
gated_delta_rule_recurrence_kernel_tiled<BK, BV>
|
||||||
|
<<<grid, block, 0, custream>>>(q, k, v, g, beta, state, output, seq_len,
|
||||||
|
v_dim);
|
||||||
|
} else {
|
||||||
|
// Fallback for other k_dim values (runtime loop, still V-tiled)
|
||||||
|
constexpr int BV = 64;
|
||||||
|
constexpr int MAX_K = 256;
|
||||||
|
dim3 grid((v_dim + BV - 1) / BV, bh);
|
||||||
|
dim3 block(BV);
|
||||||
|
size_t smem = 2 * k_dim * sizeof(float);
|
||||||
|
gated_delta_rule_recurrence_kernel_fallback<BV, MAX_K>
|
||||||
|
<<<grid, block, smem, custream>>>(q, k, v, g, beta, state, output,
|
||||||
|
seq_len, k_dim, v_dim);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Kernel 1b: chunked_gated_delta_rule_recurrence (prefill optimization)
|
||||||
|
//
|
||||||
|
// Processes prefill tokens in BT-token chunks instead of one at a time.
|
||||||
|
// Within each chunk: parallel prefix sum of g, cooperative kk_dot computation,
|
||||||
|
// forward substitution (triangular solve), output computation, and state
|
||||||
|
// update.
|
||||||
|
//
|
||||||
|
// Same thread model as Kernel 1: one block per (v_tile, batch*head),
|
||||||
|
// one thread per V-column. Each thread owns BK registers of state.
|
||||||
|
//
|
||||||
|
// Shared memory holds:
|
||||||
|
// k_chunk[BT * BK] -- key vectors for current chunk
|
||||||
|
// kk_dot[BT * BT] -- dot(k[i], k[j]) lower-triangular matrix
|
||||||
|
// gcum[BT] -- cumulative sum of g within chunk
|
||||||
|
// beta_s[BT] -- beta values for chunk
|
||||||
|
// q_buf[BK] -- q vector (loaded one row at a time)
|
||||||
|
//
|
||||||
|
// q,k: [BH, S, K] v: [BH, S, V] g,beta: [BH, S]
|
||||||
|
// state: [BH, K, V] (in/out) output: [BH, S, V]
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
template <int BT, int BK, int BV>
|
||||||
|
__global__ void
|
||||||
|
chunked_gated_delta_rule_kernel(const float *__restrict__ q, // [BH, S, K]
|
||||||
|
const float *__restrict__ k, // [BH, S, K]
|
||||||
|
const float *__restrict__ v, // [BH, S, V]
|
||||||
|
const float *__restrict__ g, // [BH, S]
|
||||||
|
const float *__restrict__ beta, // [BH, S]
|
||||||
|
float *__restrict__ state, // [BH, K, V]
|
||||||
|
float *__restrict__ output, // [BH, S, V]
|
||||||
|
int seq_len, int v_dim) {
|
||||||
|
|
||||||
|
const int v_tile = blockIdx.x;
|
||||||
|
const int bh = blockIdx.y;
|
||||||
|
const int tid = threadIdx.x;
|
||||||
|
const int v_idx = v_tile * BV + tid;
|
||||||
|
|
||||||
|
if (v_idx >= v_dim)
|
||||||
|
return;
|
||||||
|
|
||||||
|
const int num_chunks = (seq_len + BT - 1) / BT;
|
||||||
|
|
||||||
|
// Pointers for this (batch, head)
|
||||||
|
const float *q_bh = q + bh * seq_len * BK;
|
||||||
|
const float *k_bh = k + bh * seq_len * BK;
|
||||||
|
const float *v_bh = v + bh * seq_len * v_dim;
|
||||||
|
const float *g_bh = g + bh * seq_len;
|
||||||
|
const float *beta_bh = beta + bh * seq_len;
|
||||||
|
float *state_bh = state + bh * BK * v_dim;
|
||||||
|
float *out_bh = output + bh * seq_len * v_dim;
|
||||||
|
|
||||||
|
// Dynamic shared memory layout
|
||||||
|
extern __shared__ float smem[];
|
||||||
|
float *k_chunk = smem; // [BT * BK]
|
||||||
|
float *kk_dot = smem + BT * BK; // [BT * BT]
|
||||||
|
float *gcum = smem + BT * BK + BT * BT; // [BT]
|
||||||
|
float *beta_s = gcum + BT; // [BT]
|
||||||
|
float *q_buf = beta_s + BT; // [BK]
|
||||||
|
|
||||||
|
// Load state column into registers
|
||||||
|
float s[BK];
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < BK; j++) {
|
||||||
|
s[j] = state_bh[j * v_dim + v_idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Per-thread register array for corrected deltas
|
||||||
|
float delta[BT];
|
||||||
|
|
||||||
|
for (int c = 0; c < num_chunks; c++) {
|
||||||
|
const int chunk_start = c * BT;
|
||||||
|
const int chunk_len = min(BT, seq_len - chunk_start);
|
||||||
|
|
||||||
|
// === Phase 1: Cooperative load of k, beta, g into shared memory ===
|
||||||
|
for (int t = 0; t < chunk_len; t++) {
|
||||||
|
for (int j = tid; j < BK; j += BV) {
|
||||||
|
k_chunk[t * BK + j] = k_bh[(chunk_start + t) * BK + j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (tid < chunk_len) {
|
||||||
|
beta_s[tid] = beta_bh[chunk_start + tid];
|
||||||
|
gcum[tid] = g_bh[chunk_start + tid];
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// === Phase 1b: Parallel prefix sum of g (Hillis-Steele) ===
|
||||||
|
for (int stride = 1; stride < BT; stride <<= 1) {
|
||||||
|
float prev = 0.0f;
|
||||||
|
if (tid < chunk_len && (int)tid >= stride)
|
||||||
|
prev = gcum[tid - stride];
|
||||||
|
__syncthreads();
|
||||||
|
if (tid < chunk_len && (int)tid >= stride)
|
||||||
|
gcum[tid] += prev;
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
// === Phase 2: Compute kk_dot[i][j] = dot(k[i], k[j]) for j < i ===
|
||||||
|
// Only lower-triangular entries needed (strictly lower)
|
||||||
|
for (int idx = tid; idx < chunk_len * chunk_len; idx += BV) {
|
||||||
|
int i = idx / chunk_len;
|
||||||
|
int j = idx % chunk_len;
|
||||||
|
if (j < i) {
|
||||||
|
float dot = 0.0f;
|
||||||
|
for (int d = 0; d < BK; d++) {
|
||||||
|
dot = __fmaf_rn(k_chunk[i * BK + d], k_chunk[j * BK + d], dot);
|
||||||
|
}
|
||||||
|
kk_dot[i * BT + j] = dot;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// === Phase 3: Forward substitution (per V-column, in registers) ===
|
||||||
|
// Computes corrected delta values via triangular solve
|
||||||
|
for (int i = 0; i < chunk_len; i++) {
|
||||||
|
float v_i = v_bh[(chunk_start + i) * v_dim + v_idx];
|
||||||
|
float decay_i = expf(gcum[i]);
|
||||||
|
float beta_i = beta_s[i];
|
||||||
|
|
||||||
|
// Inter-chunk contribution: state @ k[i] with decay
|
||||||
|
float kv_mem = 0.0f;
|
||||||
|
#pragma unroll
|
||||||
|
for (int d = 0; d < BK; d++) {
|
||||||
|
kv_mem = __fmaf_rn(s[d] * decay_i, k_chunk[i * BK + d], kv_mem);
|
||||||
|
}
|
||||||
|
|
||||||
|
float rhs = beta_i * (v_i - kv_mem);
|
||||||
|
|
||||||
|
// Subtract lower-triangular contributions (intra-chunk)
|
||||||
|
for (int j = 0; j < i; j++) {
|
||||||
|
float a_ij = beta_i * kk_dot[i * BT + j] * expf(gcum[i] - gcum[j]);
|
||||||
|
rhs -= a_ij * delta[j];
|
||||||
|
}
|
||||||
|
delta[i] = rhs;
|
||||||
|
}
|
||||||
|
|
||||||
|
// === Phase 4: Output computation (per V-column) ===
|
||||||
|
for (int i = 0; i < chunk_len; i++) {
|
||||||
|
// Cooperatively load q[i] into shared
|
||||||
|
for (int j = tid; j < BK; j += BV) {
|
||||||
|
q_buf[j] = q_bh[(chunk_start + i) * BK + j];
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
float decay_i = expf(gcum[i]);
|
||||||
|
|
||||||
|
// Inter-chunk contribution: q[i] @ (state * decay)
|
||||||
|
float o_val = 0.0f;
|
||||||
|
#pragma unroll
|
||||||
|
for (int d = 0; d < BK; d++) {
|
||||||
|
o_val = __fmaf_rn(q_buf[d], s[d] * decay_i, o_val);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Intra-chunk contribution: sum_{j<=i} dot(q[i], k[j]) * delta[j] *
|
||||||
|
// exp(gcum[i] - gcum[j])
|
||||||
|
for (int j = 0; j <= i; j++) {
|
||||||
|
float qk_dot = 0.0f;
|
||||||
|
for (int d = 0; d < BK; d++) {
|
||||||
|
qk_dot = __fmaf_rn(q_buf[d], k_chunk[j * BK + d], qk_dot);
|
||||||
|
}
|
||||||
|
o_val += qk_dot * delta[j] * expf(gcum[i] - gcum[j]);
|
||||||
|
}
|
||||||
|
|
||||||
|
out_bh[(chunk_start + i) * v_dim + v_idx] = o_val;
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
// === Phase 5: State update for next chunk ===
|
||||||
|
float g_total = gcum[chunk_len - 1];
|
||||||
|
#pragma unroll
|
||||||
|
for (int d = 0; d < BK; d++) {
|
||||||
|
float s_new = s[d] * expf(g_total);
|
||||||
|
for (int t = 0; t < chunk_len; t++) {
|
||||||
|
s_new += k_chunk[t * BK + d] * delta[t] * expf(g_total - gcum[t]);
|
||||||
|
}
|
||||||
|
s[d] = s_new;
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write final state back
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < BK; j++) {
|
||||||
|
state_bh[j * v_dim + v_idx] = s[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" void chunked_gated_delta_rule_recurrence(
|
||||||
|
const float *q, const float *k, const float *v, const float *g,
|
||||||
|
const float *beta, float *state, float *output, int bh, int seq_len,
|
||||||
|
int k_dim, int v_dim, int64_t stream) {
|
||||||
|
|
||||||
|
const cudaStream_t custream = (cudaStream_t)stream;
|
||||||
|
|
||||||
|
if (k_dim == 128) {
|
||||||
|
constexpr int BT = 64;
|
||||||
|
constexpr int BK = 128;
|
||||||
|
constexpr int BV = 64;
|
||||||
|
// Shared memory: BT*BK + BT*BT + BT + BT + BK floats
|
||||||
|
size_t smem = (BT * BK + BT * BT + 2 * BT + BK) * sizeof(float);
|
||||||
|
|
||||||
|
// Request extended shared memory
|
||||||
|
auto kernel = chunked_gated_delta_rule_kernel<BT, BK, BV>;
|
||||||
|
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||||
|
smem);
|
||||||
|
|
||||||
|
dim3 grid((v_dim + BV - 1) / BV, bh);
|
||||||
|
dim3 block(BV);
|
||||||
|
kernel<<<grid, block, smem, custream>>>(q, k, v, g, beta, state, output,
|
||||||
|
seq_len, v_dim);
|
||||||
|
} else if (k_dim == 64) {
|
||||||
|
constexpr int BT = 64;
|
||||||
|
constexpr int BK = 64;
|
||||||
|
constexpr int BV = 64;
|
||||||
|
size_t smem = (BT * BK + BT * BT + 2 * BT + BK) * sizeof(float);
|
||||||
|
|
||||||
|
auto kernel = chunked_gated_delta_rule_kernel<BT, BK, BV>;
|
||||||
|
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||||
|
smem);
|
||||||
|
|
||||||
|
dim3 grid((v_dim + BV - 1) / BV, bh);
|
||||||
|
dim3 block(BV);
|
||||||
|
kernel<<<grid, block, smem, custream>>>(q, k, v, g, beta, state, output,
|
||||||
|
seq_len, v_dim);
|
||||||
|
} else {
|
||||||
|
// Fallback: use the sequential kernel for unsupported k_dim
|
||||||
|
gated_delta_rule_recurrence(q, k, v, g, beta, state, output, bh, seq_len,
|
||||||
|
k_dim, v_dim, stream);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Kernel 2a: causal_conv1d_update (decode path, single step)
|
||||||
|
//
|
||||||
|
// Each thread handles one channel: shift conv_state left by 1,
|
||||||
|
// insert new value, dot product with weight, apply SiLU.
|
||||||
|
//
|
||||||
|
// x: [B, conv_dim, 1] weight: [conv_dim, kernel_size]
|
||||||
|
// conv_state: [B, conv_dim, kernel_size] (in/out)
|
||||||
|
// output: [B, conv_dim, 1]
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void causal_conv1d_update_kernel(
|
||||||
|
const T *__restrict__ x, // [B, conv_dim, 1]
|
||||||
|
const T *__restrict__ weight, // [conv_dim, kernel_size]
|
||||||
|
T *__restrict__ conv_state, // [B, conv_dim, kernel_size]
|
||||||
|
T *__restrict__ output, // [B, conv_dim, 1]
|
||||||
|
int batch_size, int conv_dim, int kernel_size) {
|
||||||
|
|
||||||
|
const int ch = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
const int b = blockIdx.y;
|
||||||
|
|
||||||
|
if (ch >= conv_dim || b >= batch_size)
|
||||||
|
return;
|
||||||
|
|
||||||
|
// Pointer to this batch/channel's conv state
|
||||||
|
T *cs = conv_state + (b * conv_dim + ch) * kernel_size;
|
||||||
|
const T *w = weight + ch * kernel_size;
|
||||||
|
|
||||||
|
// Shift state left by 1
|
||||||
|
for (int i = 0; i < kernel_size - 1; i++) {
|
||||||
|
cs[i] = cs[i + 1];
|
||||||
|
}
|
||||||
|
// Insert new value
|
||||||
|
cs[kernel_size - 1] = x[b * conv_dim + ch];
|
||||||
|
|
||||||
|
// Dot product with weight
|
||||||
|
float acc = 0.0f;
|
||||||
|
for (int i = 0; i < kernel_size; i++) {
|
||||||
|
acc += (float)cs[i] * (float)w[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// SiLU activation: x * sigmoid(x)
|
||||||
|
float sig = 1.0f / (1.0f + expf(-acc));
|
||||||
|
float result = acc * sig;
|
||||||
|
|
||||||
|
output[b * conv_dim + ch] = (T)result;
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" void causal_conv1d_update(const void *x, const void *weight,
|
||||||
|
void *conv_state, void *output,
|
||||||
|
int batch_size, int conv_dim,
|
||||||
|
int kernel_size, int dtype,
|
||||||
|
int64_t stream) {
|
||||||
|
const cudaStream_t custream = (cudaStream_t)stream;
|
||||||
|
dim3 block(256);
|
||||||
|
dim3 grid((conv_dim + 255) / 256, batch_size);
|
||||||
|
|
||||||
|
if (dtype == 0) {
|
||||||
|
// f16
|
||||||
|
causal_conv1d_update_kernel<__half><<<grid, block, 0, custream>>>(
|
||||||
|
(const __half *)x, (const __half *)weight, (__half *)conv_state,
|
||||||
|
(__half *)output, batch_size, conv_dim, kernel_size);
|
||||||
|
} else {
|
||||||
|
// bf16
|
||||||
|
causal_conv1d_update_kernel<__nv_bfloat16><<<grid, block, 0, custream>>>(
|
||||||
|
(const __nv_bfloat16 *)x, (const __nv_bfloat16 *)weight,
|
||||||
|
(__nv_bfloat16 *)conv_state, (__nv_bfloat16 *)output, batch_size,
|
||||||
|
conv_dim, kernel_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Kernel 2b: causal_conv1d_full (prefill path)
|
||||||
|
//
|
||||||
|
// Each thread handles one (channel, position): causal window with
|
||||||
|
// zero-padding, dot product with weight, SiLU.
|
||||||
|
// A second pass writes the conv_state from the last kernel_size positions.
|
||||||
|
//
|
||||||
|
// x: [B, conv_dim, S] weight: [conv_dim, kernel_size]
|
||||||
|
// conv_state_out: [B, conv_dim, kernel_size] output: [B, conv_dim, S]
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void causal_conv1d_full_kernel(
|
||||||
|
const T *__restrict__ x, // [B, conv_dim, S]
|
||||||
|
const T *__restrict__ weight, // [conv_dim, kernel_size]
|
||||||
|
T *__restrict__ output, // [B, conv_dim, S]
|
||||||
|
int batch_size, int conv_dim, int seq_len, int kernel_size) {
|
||||||
|
|
||||||
|
const int ch = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
const int pos = blockIdx.y;
|
||||||
|
const int b = blockIdx.z;
|
||||||
|
|
||||||
|
if (ch >= conv_dim || pos >= seq_len || b >= batch_size)
|
||||||
|
return;
|
||||||
|
|
||||||
|
const T *x_bch = x + (b * conv_dim + ch) * seq_len;
|
||||||
|
const T *w = weight + ch * kernel_size;
|
||||||
|
|
||||||
|
// Causal convolution: sum over kernel_size window ending at pos
|
||||||
|
float acc = 0.0f;
|
||||||
|
for (int i = 0; i < kernel_size; i++) {
|
||||||
|
int src_pos = pos - (kernel_size - 1) + i;
|
||||||
|
float x_val = (src_pos >= 0) ? (float)x_bch[src_pos] : 0.0f;
|
||||||
|
acc += x_val * (float)w[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// SiLU
|
||||||
|
float sig = 1.0f / (1.0f + expf(-acc));
|
||||||
|
float result = acc * sig;
|
||||||
|
|
||||||
|
output[(b * conv_dim + ch) * seq_len + pos] = (T)result;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void save_conv_state_kernel(
|
||||||
|
const T *__restrict__ x, // [B, conv_dim, S]
|
||||||
|
T *__restrict__ conv_state_out, // [B, conv_dim, kernel_size]
|
||||||
|
int batch_size, int conv_dim, int seq_len, int kernel_size) {
|
||||||
|
|
||||||
|
const int ch = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
const int b = blockIdx.y;
|
||||||
|
|
||||||
|
if (ch >= conv_dim || b >= batch_size)
|
||||||
|
return;
|
||||||
|
|
||||||
|
const T *x_bch = x + (b * conv_dim + ch) * seq_len;
|
||||||
|
T *cs = conv_state_out + (b * conv_dim + ch) * kernel_size;
|
||||||
|
|
||||||
|
// Save last kernel_size positions (zero-pad if seq_len < kernel_size)
|
||||||
|
int pad = kernel_size - seq_len;
|
||||||
|
for (int i = 0; i < kernel_size; i++) {
|
||||||
|
if (i < pad) {
|
||||||
|
cs[i] = (T)0.0f;
|
||||||
|
} else {
|
||||||
|
cs[i] = x_bch[seq_len - kernel_size + i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" void causal_conv1d_full(const void *x, const void *weight,
|
||||||
|
void *conv_state_out, void *output,
|
||||||
|
int batch_size, int conv_dim, int seq_len,
|
||||||
|
int kernel_size, int dtype, int64_t stream) {
|
||||||
|
const cudaStream_t custream = (cudaStream_t)stream;
|
||||||
|
|
||||||
|
// Main convolution kernel
|
||||||
|
dim3 block(256);
|
||||||
|
dim3 grid((conv_dim + 255) / 256, seq_len, batch_size);
|
||||||
|
|
||||||
|
if (dtype == 0) {
|
||||||
|
causal_conv1d_full_kernel<__half><<<grid, block, 0, custream>>>(
|
||||||
|
(const __half *)x, (const __half *)weight, (__half *)output, batch_size,
|
||||||
|
conv_dim, seq_len, kernel_size);
|
||||||
|
// Save conv state
|
||||||
|
dim3 grid2((conv_dim + 255) / 256, batch_size);
|
||||||
|
save_conv_state_kernel<__half><<<grid2, block, 0, custream>>>(
|
||||||
|
(const __half *)x, (__half *)conv_state_out, batch_size, conv_dim,
|
||||||
|
seq_len, kernel_size);
|
||||||
|
} else {
|
||||||
|
causal_conv1d_full_kernel<__nv_bfloat16><<<grid, block, 0, custream>>>(
|
||||||
|
(const __nv_bfloat16 *)x, (const __nv_bfloat16 *)weight,
|
||||||
|
(__nv_bfloat16 *)output, batch_size, conv_dim, seq_len, kernel_size);
|
||||||
|
dim3 grid2((conv_dim + 255) / 256, batch_size);
|
||||||
|
save_conv_state_kernel<__nv_bfloat16><<<grid2, block, 0, custream>>>(
|
||||||
|
(const __nv_bfloat16 *)x, (__nv_bfloat16 *)conv_state_out, batch_size,
|
||||||
|
conv_dim, seq_len, kernel_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Kernel 3: fused_gdn_gating
|
||||||
|
//
|
||||||
|
// Fuses: beta = sigmoid(b), g = -exp(a_log) * softplus(a + dt_bias)
|
||||||
|
// a_log and dt_bias are per-head (broadcast over batch*seq).
|
||||||
|
//
|
||||||
|
// b, a: [total] a_log, dt_bias: [num_heads]
|
||||||
|
// beta_out, g_out: [total]
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void
|
||||||
|
fused_gdn_gating_kernel(const T *__restrict__ b, // [total]
|
||||||
|
const T *__restrict__ a, // [total]
|
||||||
|
const float *__restrict__ a_log, // [num_heads]
|
||||||
|
const float *__restrict__ dt_bias, // [num_heads]
|
||||||
|
T *__restrict__ beta_out, // [total]
|
||||||
|
T *__restrict__ g_out, // [total]
|
||||||
|
int total_elements, int num_heads) {
|
||||||
|
|
||||||
|
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
if (idx >= total_elements)
|
||||||
|
return;
|
||||||
|
|
||||||
|
// Head index: elements are laid out as [..., num_heads]
|
||||||
|
int head_idx = idx % num_heads;
|
||||||
|
|
||||||
|
// beta = sigmoid(b)
|
||||||
|
float b_val = (float)b[idx];
|
||||||
|
float beta = 1.0f / (1.0f + expf(-b_val));
|
||||||
|
|
||||||
|
// g = -exp(a_log) * softplus(a + dt_bias)
|
||||||
|
float a_val = (float)a[idx];
|
||||||
|
float a_log_val = a_log[head_idx];
|
||||||
|
float dt_bias_val = dt_bias[head_idx];
|
||||||
|
|
||||||
|
float sp_input = a_val + dt_bias_val;
|
||||||
|
float softplus_val = logf(1.0f + expf(sp_input));
|
||||||
|
float g_val = -expf(a_log_val) * softplus_val;
|
||||||
|
|
||||||
|
beta_out[idx] = (T)beta;
|
||||||
|
g_out[idx] = (T)g_val;
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" void fused_gdn_gating(const void *b, const void *a,
|
||||||
|
const float *a_log, const float *dt_bias,
|
||||||
|
void *beta_out, void *g_out,
|
||||||
|
int total_elements, int num_heads, int dtype,
|
||||||
|
int64_t stream) {
|
||||||
|
const cudaStream_t custream = (cudaStream_t)stream;
|
||||||
|
dim3 block(256);
|
||||||
|
dim3 grid((total_elements + 255) / 256);
|
||||||
|
|
||||||
|
if (dtype == 0) {
|
||||||
|
fused_gdn_gating_kernel<__half><<<grid, block, 0, custream>>>(
|
||||||
|
(const __half *)b, (const __half *)a, a_log, dt_bias,
|
||||||
|
(__half *)beta_out, (__half *)g_out, total_elements, num_heads);
|
||||||
|
} else {
|
||||||
|
fused_gdn_gating_kernel<__nv_bfloat16><<<grid, block, 0, custream>>>(
|
||||||
|
(const __nv_bfloat16 *)b, (const __nv_bfloat16 *)a, a_log, dt_bias,
|
||||||
|
(__nv_bfloat16 *)beta_out, (__nv_bfloat16 *)g_out, total_elements,
|
||||||
|
num_heads);
|
||||||
|
}
|
||||||
|
}
|
||||||
486
crates/neuron/src/cuda/gdn.rs
Normal file
486
crates/neuron/src/cuda/gdn.rs
Normal file
@@ -0,0 +1,486 @@
|
|||||||
|
//! Rust wrappers around the Gated DeltaNet CUDA kernels in `gdn.cu`.
|
||||||
|
//!
|
||||||
|
//! Ported verbatim from `EricLBuehler/mistral.rs` under MIT terms.
|
||||||
|
//! Upstream path: `mistralrs-core/src/cuda/gdn.rs`. The only edits in
|
||||||
|
//! this file are this header comment — the FFI path module name is
|
||||||
|
//! `crate::cuda::ffi`, identical to upstream's layout.
|
||||||
|
|
||||||
|
#![allow(clippy::cast_possible_truncation)]
|
||||||
|
|
||||||
|
use candle_core::{Result, Tensor};
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
use candle_core::DType;
|
||||||
|
|
||||||
|
/// CUDA-accelerated gated delta rule recurrence.
|
||||||
|
///
|
||||||
|
/// Inputs (all contiguous, f32):
|
||||||
|
/// q, k: [BH, S, K] v: [BH, S, V] g, beta: [BH, S]
|
||||||
|
/// state: [BH, K, V] (mutated in place)
|
||||||
|
///
|
||||||
|
/// Returns: output [BH, S, V]
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub fn gated_delta_rule_recurrence_cuda(
|
||||||
|
q: &Tensor,
|
||||||
|
k: &Tensor,
|
||||||
|
v: &Tensor,
|
||||||
|
g: &Tensor,
|
||||||
|
beta: &Tensor,
|
||||||
|
state: &mut Tensor,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
use candle::cuda_backend::cudarc::driver::DevicePtr;
|
||||||
|
use candle_core as candle;
|
||||||
|
|
||||||
|
let (bh, seq_len, k_dim) = q.dims3()?;
|
||||||
|
let v_dim = v.dim(2)?;
|
||||||
|
|
||||||
|
let dev = q.device().as_cuda_device()?;
|
||||||
|
|
||||||
|
let (q_s, q_l) = q.storage_and_layout();
|
||||||
|
let q_s = match &*q_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("q must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let q_offset = q_l.start_offset();
|
||||||
|
|
||||||
|
let (k_s, k_l) = k.storage_and_layout();
|
||||||
|
let k_s = match &*k_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("k must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let k_offset = k_l.start_offset();
|
||||||
|
|
||||||
|
let (v_s, v_l) = v.storage_and_layout();
|
||||||
|
let v_s = match &*v_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("v must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let v_offset = v_l.start_offset();
|
||||||
|
|
||||||
|
let (g_s, g_l) = g.storage_and_layout();
|
||||||
|
let g_s = match &*g_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("g must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let g_offset = g_l.start_offset();
|
||||||
|
|
||||||
|
let (beta_s, beta_l) = beta.storage_and_layout();
|
||||||
|
let beta_s = match &*beta_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("beta must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let beta_offset = beta_l.start_offset();
|
||||||
|
|
||||||
|
let (state_s, state_l) = state.storage_and_layout();
|
||||||
|
let state_s = match &*state_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("state must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let state_offset = state_l.start_offset();
|
||||||
|
|
||||||
|
let output_buf = unsafe { dev.alloc::<f32>(bh * seq_len * v_dim) }?;
|
||||||
|
|
||||||
|
let stream = dev.cuda_stream().cu_stream() as i64;
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
crate::cuda::ffi::gated_delta_rule_recurrence(
|
||||||
|
q_s.slice(q_offset..).device_ptr(q_s.stream()).0 as *const f32,
|
||||||
|
k_s.slice(k_offset..).device_ptr(k_s.stream()).0 as *const f32,
|
||||||
|
v_s.slice(v_offset..).device_ptr(v_s.stream()).0 as *const f32,
|
||||||
|
g_s.slice(g_offset..).device_ptr(g_s.stream()).0 as *const f32,
|
||||||
|
beta_s.slice(beta_offset..).device_ptr(beta_s.stream()).0 as *const f32,
|
||||||
|
state_s.slice(state_offset..).device_ptr(state_s.stream()).0 as *mut f32,
|
||||||
|
output_buf.device_ptr(output_buf.stream()).0 as *mut f32,
|
||||||
|
bh as i32,
|
||||||
|
seq_len as i32,
|
||||||
|
k_dim as i32,
|
||||||
|
v_dim as i32,
|
||||||
|
stream,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// The kernel wrote state in-place via the raw pointer; rewrap
|
||||||
|
// (state tensor's underlying CudaSlice was modified directly)
|
||||||
|
|
||||||
|
let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone());
|
||||||
|
Ok(Tensor::from((
|
||||||
|
candle::Storage::Cuda(output_storage),
|
||||||
|
(bh, seq_len, v_dim),
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
#[allow(unused)]
|
||||||
|
pub fn gated_delta_rule_recurrence_cuda(
|
||||||
|
_q: &Tensor,
|
||||||
|
_k: &Tensor,
|
||||||
|
_v: &Tensor,
|
||||||
|
_g: &Tensor,
|
||||||
|
_beta: &Tensor,
|
||||||
|
_state: &mut Tensor,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
candle_core::bail!("gated_delta_rule_recurrence_cuda requires the cuda feature")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// CUDA-accelerated chunked gated delta rule recurrence (prefill optimization).
|
||||||
|
///
|
||||||
|
/// Processes prefill tokens in 64-token chunks instead of one at a time.
|
||||||
|
/// Same interface as `gated_delta_rule_recurrence_cuda`.
|
||||||
|
///
|
||||||
|
/// Inputs (all contiguous, f32):
|
||||||
|
/// q, k: [BH, S, K] v: [BH, S, V] g, beta: [BH, S]
|
||||||
|
/// state: [BH, K, V] (mutated in place)
|
||||||
|
///
|
||||||
|
/// Returns: output [BH, S, V]
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub fn chunked_gated_delta_rule_recurrence_cuda(
|
||||||
|
q: &Tensor,
|
||||||
|
k: &Tensor,
|
||||||
|
v: &Tensor,
|
||||||
|
g: &Tensor,
|
||||||
|
beta: &Tensor,
|
||||||
|
state: &mut Tensor,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
use candle::cuda_backend::cudarc::driver::DevicePtr;
|
||||||
|
use candle_core as candle;
|
||||||
|
|
||||||
|
let (bh, seq_len, k_dim) = q.dims3()?;
|
||||||
|
let v_dim = v.dim(2)?;
|
||||||
|
|
||||||
|
let dev = q.device().as_cuda_device()?;
|
||||||
|
|
||||||
|
let (q_s, q_l) = q.storage_and_layout();
|
||||||
|
let q_s = match &*q_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("q must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let q_offset = q_l.start_offset();
|
||||||
|
|
||||||
|
let (k_s, k_l) = k.storage_and_layout();
|
||||||
|
let k_s = match &*k_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("k must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let k_offset = k_l.start_offset();
|
||||||
|
|
||||||
|
let (v_s, v_l) = v.storage_and_layout();
|
||||||
|
let v_s = match &*v_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("v must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let v_offset = v_l.start_offset();
|
||||||
|
|
||||||
|
let (g_s, g_l) = g.storage_and_layout();
|
||||||
|
let g_s = match &*g_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("g must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let g_offset = g_l.start_offset();
|
||||||
|
|
||||||
|
let (beta_s, beta_l) = beta.storage_and_layout();
|
||||||
|
let beta_s = match &*beta_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("beta must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let beta_offset = beta_l.start_offset();
|
||||||
|
|
||||||
|
let (state_s, state_l) = state.storage_and_layout();
|
||||||
|
let state_s = match &*state_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("state must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let state_offset = state_l.start_offset();
|
||||||
|
|
||||||
|
let output_buf = unsafe { dev.alloc::<f32>(bh * seq_len * v_dim) }?;
|
||||||
|
|
||||||
|
let stream = dev.cuda_stream().cu_stream() as i64;
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
crate::cuda::ffi::chunked_gated_delta_rule_recurrence(
|
||||||
|
q_s.slice(q_offset..).device_ptr(q_s.stream()).0 as *const f32,
|
||||||
|
k_s.slice(k_offset..).device_ptr(k_s.stream()).0 as *const f32,
|
||||||
|
v_s.slice(v_offset..).device_ptr(v_s.stream()).0 as *const f32,
|
||||||
|
g_s.slice(g_offset..).device_ptr(g_s.stream()).0 as *const f32,
|
||||||
|
beta_s.slice(beta_offset..).device_ptr(beta_s.stream()).0 as *const f32,
|
||||||
|
state_s.slice(state_offset..).device_ptr(state_s.stream()).0 as *mut f32,
|
||||||
|
output_buf.device_ptr(output_buf.stream()).0 as *mut f32,
|
||||||
|
bh as i32,
|
||||||
|
seq_len as i32,
|
||||||
|
k_dim as i32,
|
||||||
|
v_dim as i32,
|
||||||
|
stream,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone());
|
||||||
|
Ok(Tensor::from((
|
||||||
|
candle::Storage::Cuda(output_storage),
|
||||||
|
(bh, seq_len, v_dim),
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
#[allow(unused)]
|
||||||
|
pub fn chunked_gated_delta_rule_recurrence_cuda(
|
||||||
|
_q: &Tensor,
|
||||||
|
_k: &Tensor,
|
||||||
|
_v: &Tensor,
|
||||||
|
_g: &Tensor,
|
||||||
|
_beta: &Tensor,
|
||||||
|
_state: &mut Tensor,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
candle_core::bail!("chunked_gated_delta_rule_recurrence_cuda requires the cuda feature")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// CUDA-accelerated causal conv1d (both update and full paths).
|
||||||
|
///
|
||||||
|
/// For update (is_update=true):
|
||||||
|
/// x: [B, conv_dim, 1] weight: [conv_dim, kernel_size]
|
||||||
|
/// conv_state: [B, conv_dim, kernel_size] (mutated in place for update)
|
||||||
|
/// Returns: (output [B, conv_dim, 1], updated conv_state)
|
||||||
|
///
|
||||||
|
/// For full (is_update=false):
|
||||||
|
/// x: [B, conv_dim, S] weight: [conv_dim, kernel_size]
|
||||||
|
/// Returns: (output [B, conv_dim, S], new conv_state [B, conv_dim, kernel_size])
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub fn causal_conv1d_cuda(
|
||||||
|
x: &Tensor,
|
||||||
|
weight: &Tensor,
|
||||||
|
conv_state: &Tensor,
|
||||||
|
kernel_size: usize,
|
||||||
|
is_update: bool,
|
||||||
|
) -> Result<(Tensor, Tensor)> {
|
||||||
|
use candle::cuda_backend::cudarc::driver::DevicePtr;
|
||||||
|
use candle_core as candle;
|
||||||
|
use core::ffi::c_void;
|
||||||
|
fn cuda_fwd<
|
||||||
|
T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
|
||||||
|
>(
|
||||||
|
x: &Tensor,
|
||||||
|
weight: &Tensor,
|
||||||
|
conv_state: &Tensor,
|
||||||
|
kernel_size: usize,
|
||||||
|
is_update: bool,
|
||||||
|
dtype_code: i32,
|
||||||
|
) -> Result<(Tensor, Tensor)> {
|
||||||
|
let dev = x.device().as_cuda_device()?;
|
||||||
|
let (batch_size, conv_dim, seq_len) = x.dims3()?;
|
||||||
|
|
||||||
|
let (x_s, x_l) = x.storage_and_layout();
|
||||||
|
let x_s = match &*x_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
|
||||||
|
_ => candle::bail!("x must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let x_offset = x_l.start_offset();
|
||||||
|
|
||||||
|
let (w_s, w_l) = weight.storage_and_layout();
|
||||||
|
let w_s = match &*w_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
|
||||||
|
_ => candle::bail!("weight must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let w_offset = w_l.start_offset();
|
||||||
|
|
||||||
|
let stream = dev.cuda_stream().cu_stream() as i64;
|
||||||
|
|
||||||
|
if is_update {
|
||||||
|
// Clone conv_state so the kernel can mutate it in place
|
||||||
|
let conv_state_new = conv_state.clone();
|
||||||
|
|
||||||
|
let output_buf = unsafe { dev.alloc::<T>(batch_size * conv_dim) }?;
|
||||||
|
|
||||||
|
// Scope the borrow of conv_state_new so we can move it later
|
||||||
|
{
|
||||||
|
let (cs_s, cs_l) = conv_state_new.storage_and_layout();
|
||||||
|
let cs_s = match &*cs_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
|
||||||
|
_ => candle::bail!("conv_state must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let cs_offset = cs_l.start_offset();
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
crate::cuda::ffi::causal_conv1d_update(
|
||||||
|
x_s.slice(x_offset..).device_ptr(x_s.stream()).0 as *const c_void,
|
||||||
|
w_s.slice(w_offset..).device_ptr(w_s.stream()).0 as *const c_void,
|
||||||
|
cs_s.slice(cs_offset..).device_ptr(cs_s.stream()).0 as *mut c_void,
|
||||||
|
output_buf.device_ptr(output_buf.stream()).0 as *mut c_void,
|
||||||
|
batch_size as i32,
|
||||||
|
conv_dim as i32,
|
||||||
|
kernel_size as i32,
|
||||||
|
dtype_code,
|
||||||
|
stream,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone());
|
||||||
|
let output = Tensor::from((
|
||||||
|
candle::Storage::Cuda(output_storage),
|
||||||
|
(batch_size, conv_dim, 1usize),
|
||||||
|
));
|
||||||
|
|
||||||
|
Ok((output, conv_state_new))
|
||||||
|
} else {
|
||||||
|
// Full path: allocate new conv_state and output
|
||||||
|
let output_buf = unsafe { dev.alloc::<T>(batch_size * conv_dim * seq_len) }?;
|
||||||
|
let cs_buf = unsafe { dev.alloc::<T>(batch_size * conv_dim * kernel_size) }?;
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
crate::cuda::ffi::causal_conv1d_full(
|
||||||
|
x_s.slice(x_offset..).device_ptr(x_s.stream()).0 as *const c_void,
|
||||||
|
w_s.slice(w_offset..).device_ptr(w_s.stream()).0 as *const c_void,
|
||||||
|
cs_buf.device_ptr(cs_buf.stream()).0 as *mut c_void,
|
||||||
|
output_buf.device_ptr(output_buf.stream()).0 as *mut c_void,
|
||||||
|
batch_size as i32,
|
||||||
|
conv_dim as i32,
|
||||||
|
seq_len as i32,
|
||||||
|
kernel_size as i32,
|
||||||
|
dtype_code,
|
||||||
|
stream,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone());
|
||||||
|
let output = Tensor::from((
|
||||||
|
candle::Storage::Cuda(output_storage),
|
||||||
|
(batch_size, conv_dim, seq_len),
|
||||||
|
));
|
||||||
|
|
||||||
|
let cs_storage = candle::CudaStorage::wrap_cuda_slice(cs_buf, dev.clone());
|
||||||
|
let new_conv_state = Tensor::from((
|
||||||
|
candle::Storage::Cuda(cs_storage),
|
||||||
|
(batch_size, conv_dim, kernel_size),
|
||||||
|
));
|
||||||
|
|
||||||
|
Ok((output, new_conv_state))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
match x.dtype() {
|
||||||
|
DType::F16 => cuda_fwd::<half::f16>(x, weight, conv_state, kernel_size, is_update, 0),
|
||||||
|
DType::BF16 => cuda_fwd::<half::bf16>(x, weight, conv_state, kernel_size, is_update, 1),
|
||||||
|
other => candle_core::bail!("causal_conv1d_cuda only supports f16/bf16, got {:?}", other),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
#[allow(unused)]
|
||||||
|
pub fn causal_conv1d_cuda(
|
||||||
|
_x: &Tensor,
|
||||||
|
_weight: &Tensor,
|
||||||
|
_conv_state: &Tensor,
|
||||||
|
_kernel_size: usize,
|
||||||
|
_is_update: bool,
|
||||||
|
) -> Result<(Tensor, Tensor)> {
|
||||||
|
candle_core::bail!("causal_conv1d_cuda requires the cuda feature")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// CUDA-accelerated fused GDN gating computation.
|
||||||
|
///
|
||||||
|
/// Computes: beta = sigmoid(b), g = -exp(a_log) * softplus(a + dt_bias)
|
||||||
|
///
|
||||||
|
/// b, a: [total_elements] in f16/bf16
|
||||||
|
/// a_log, dt_bias: [num_heads] in f32
|
||||||
|
///
|
||||||
|
/// Returns: (beta, g) in original dtype
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub fn fused_gdn_gating_cuda(
|
||||||
|
b: &Tensor,
|
||||||
|
a: &Tensor,
|
||||||
|
a_log: &Tensor,
|
||||||
|
dt_bias: &Tensor,
|
||||||
|
) -> Result<(Tensor, Tensor)> {
|
||||||
|
use candle::cuda_backend::cudarc::driver::DevicePtr;
|
||||||
|
use candle_core as candle;
|
||||||
|
use core::ffi::c_void;
|
||||||
|
|
||||||
|
fn cuda_fwd<
|
||||||
|
T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
|
||||||
|
>(
|
||||||
|
b: &Tensor,
|
||||||
|
a: &Tensor,
|
||||||
|
a_log: &Tensor,
|
||||||
|
dt_bias: &Tensor,
|
||||||
|
dtype_code: i32,
|
||||||
|
) -> Result<(Tensor, Tensor)> {
|
||||||
|
let total_elements = b.elem_count();
|
||||||
|
let num_heads = a_log.elem_count();
|
||||||
|
let shape = b.shape().clone();
|
||||||
|
let dev = b.device().as_cuda_device()?;
|
||||||
|
|
||||||
|
let (b_s, b_l) = b.storage_and_layout();
|
||||||
|
let b_s = match &*b_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
|
||||||
|
_ => candle::bail!("b must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let b_offset = b_l.start_offset();
|
||||||
|
|
||||||
|
let (a_s, a_l) = a.storage_and_layout();
|
||||||
|
let a_s = match &*a_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
|
||||||
|
_ => candle::bail!("a must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let a_offset = a_l.start_offset();
|
||||||
|
|
||||||
|
let (alog_s, alog_l) = a_log.storage_and_layout();
|
||||||
|
let alog_s = match &*alog_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("a_log must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let alog_offset = alog_l.start_offset();
|
||||||
|
|
||||||
|
let (dtb_s, dtb_l) = dt_bias.storage_and_layout();
|
||||||
|
let dtb_s = match &*dtb_s {
|
||||||
|
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||||
|
_ => candle::bail!("dt_bias must be a cuda tensor"),
|
||||||
|
};
|
||||||
|
let dtb_offset = dtb_l.start_offset();
|
||||||
|
|
||||||
|
let beta_buf = unsafe { dev.alloc::<T>(total_elements) }?;
|
||||||
|
let g_buf = unsafe { dev.alloc::<T>(total_elements) }?;
|
||||||
|
|
||||||
|
let stream = dev.cuda_stream().cu_stream() as i64;
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
crate::cuda::ffi::fused_gdn_gating(
|
||||||
|
b_s.slice(b_offset..).device_ptr(b_s.stream()).0 as *const c_void,
|
||||||
|
a_s.slice(a_offset..).device_ptr(a_s.stream()).0 as *const c_void,
|
||||||
|
alog_s.slice(alog_offset..).device_ptr(alog_s.stream()).0 as *const f32,
|
||||||
|
dtb_s.slice(dtb_offset..).device_ptr(dtb_s.stream()).0 as *const f32,
|
||||||
|
beta_buf.device_ptr(beta_buf.stream()).0 as *mut c_void,
|
||||||
|
g_buf.device_ptr(g_buf.stream()).0 as *mut c_void,
|
||||||
|
total_elements as i32,
|
||||||
|
num_heads as i32,
|
||||||
|
dtype_code,
|
||||||
|
stream,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let beta_storage = candle::CudaStorage::wrap_cuda_slice(beta_buf, dev.clone());
|
||||||
|
let beta = Tensor::from((candle::Storage::Cuda(beta_storage), shape.clone()));
|
||||||
|
|
||||||
|
let g_storage = candle::CudaStorage::wrap_cuda_slice(g_buf, dev.clone());
|
||||||
|
let g = Tensor::from((candle::Storage::Cuda(g_storage), shape));
|
||||||
|
|
||||||
|
Ok((beta, g))
|
||||||
|
}
|
||||||
|
|
||||||
|
match b.dtype() {
|
||||||
|
DType::F16 => cuda_fwd::<half::f16>(b, a, a_log, dt_bias, 0),
|
||||||
|
DType::BF16 => cuda_fwd::<half::bf16>(b, a, a_log, dt_bias, 1),
|
||||||
|
other => candle_core::bail!(
|
||||||
|
"fused_gdn_gating_cuda only supports f16/bf16, got {:?}",
|
||||||
|
other
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
#[allow(unused)]
|
||||||
|
pub fn fused_gdn_gating_cuda(
|
||||||
|
_b: &Tensor,
|
||||||
|
_a: &Tensor,
|
||||||
|
_a_log: &Tensor,
|
||||||
|
_dt_bias: &Tensor,
|
||||||
|
) -> Result<(Tensor, Tensor)> {
|
||||||
|
candle_core::bail!("fused_gdn_gating_cuda requires the cuda feature")
|
||||||
|
}
|
||||||
15
crates/neuron/src/cuda/mod.rs
Normal file
15
crates/neuron/src/cuda/mod.rs
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
//! CUDA kernels and their Rust wrappers.
|
||||||
|
//!
|
||||||
|
//! Currently scoped to what we need for Qwen3-Next (`qwen3_5`)
|
||||||
|
//! inference performance — the Gated DeltaNet kernels ported from
|
||||||
|
//! `EricLBuehler/mistral.rs` (MIT). Each kernel lives in a `.cu`
|
||||||
|
//! file alongside this module; `build.rs` compiles them all into a
|
||||||
|
//! static lib via `cudaforge` and links it under the `cuda` feature.
|
||||||
|
//!
|
||||||
|
//! When we absorb more upstream kernels (MoE GEMM, top-k, Mamba SSM,
|
||||||
|
//! etc.) they land here in their own `.cu` + `.rs` pairs.
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub mod ffi;
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub mod gdn;
|
||||||
275
crates/neuron/src/discovery.rs
Normal file
275
crates/neuron/src/discovery.rs
Normal file
@@ -0,0 +1,275 @@
|
|||||||
|
//! GPU discovery via nvidia-smi and system info gathering.
|
||||||
|
//!
|
||||||
|
//! Pure parsing functions are separated from command execution for testability.
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use cortex_core::discovery::{DeviceHealth, DeviceInfo, DiscoveryResponse};
|
||||||
|
|
||||||
|
const NVIDIA_SMI_DISCOVERY_QUERY: &str = "index,name,memory.total,compute_cap,driver_version";
|
||||||
|
const NVIDIA_SMI_HEALTH_QUERY: &str =
|
||||||
|
"index,memory.used,memory.free,utilization.gpu,temperature.gpu";
|
||||||
|
|
||||||
|
// ── Pure parsing functions (testable without GPU) ───────────────────
|
||||||
|
|
||||||
|
/// Parse nvidia-smi CSV output for device discovery.
|
||||||
|
///
|
||||||
|
/// Expected input format (one line per GPU):
|
||||||
|
/// ```text
|
||||||
|
/// 0, NVIDIA GeForce RTX 5090, 32614, 12.0, 570.86.16
|
||||||
|
/// 1, NVIDIA GeForce RTX 5090, 32614, 12.0, 570.86.16
|
||||||
|
/// ```
|
||||||
|
pub fn parse_gpu_info(csv_output: &str) -> Result<Vec<DeviceInfo>> {
|
||||||
|
let mut devices = Vec::new();
|
||||||
|
for line in csv_output.lines() {
|
||||||
|
let line = line.trim();
|
||||||
|
if line.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let parts: Vec<&str> = line.splitn(5, ',').map(|s| s.trim()).collect();
|
||||||
|
if parts.len() < 5 {
|
||||||
|
anyhow::bail!("malformed nvidia-smi line (expected 5 fields): {line}");
|
||||||
|
}
|
||||||
|
devices.push(DeviceInfo {
|
||||||
|
index: parts[0]
|
||||||
|
.parse()
|
||||||
|
.with_context(|| format!("invalid GPU index: {}", parts[0]))?,
|
||||||
|
name: parts[1].to_string(),
|
||||||
|
vram_total_mb: parts[2]
|
||||||
|
.parse()
|
||||||
|
.with_context(|| format!("invalid VRAM: {}", parts[2]))?,
|
||||||
|
compute_capability: parts[3].to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Ok(devices)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract the driver version from nvidia-smi discovery output.
|
||||||
|
/// Takes the driver_version field from the first GPU line.
|
||||||
|
pub fn parse_driver_version(csv_output: &str) -> Option<String> {
|
||||||
|
let line = csv_output.lines().find(|l| !l.trim().is_empty())?;
|
||||||
|
let parts: Vec<&str> = line.splitn(5, ',').map(|s| s.trim()).collect();
|
||||||
|
if parts.len() >= 5 {
|
||||||
|
Some(parts[4].to_string())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse the CUDA version from `nvcc --version` output.
|
||||||
|
///
|
||||||
|
/// Expected line: `Cuda compilation tools, release 12.8, V12.8.93`
|
||||||
|
pub fn parse_cuda_version(nvcc_output: &str) -> Option<String> {
|
||||||
|
for line in nvcc_output.lines() {
|
||||||
|
if line.contains("release") {
|
||||||
|
// Extract "12.8" from "release 12.8,"
|
||||||
|
let after_release = line.split("release").nth(1)?;
|
||||||
|
let version = after_release.trim().split(',').next()?.trim();
|
||||||
|
if !version.is_empty() {
|
||||||
|
return Some(version.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse nvidia-smi CSV output for health metrics.
|
||||||
|
///
|
||||||
|
/// Expected input format (one line per GPU):
|
||||||
|
/// ```text
|
||||||
|
/// 0, 8192, 24372, 45, 62
|
||||||
|
/// ```
|
||||||
|
pub fn parse_health_info(csv_output: &str) -> Result<Vec<DeviceHealth>> {
|
||||||
|
let mut devices = Vec::new();
|
||||||
|
for line in csv_output.lines() {
|
||||||
|
let line = line.trim();
|
||||||
|
if line.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let parts: Vec<&str> = line.splitn(5, ',').map(|s| s.trim()).collect();
|
||||||
|
if parts.len() < 5 {
|
||||||
|
anyhow::bail!("malformed nvidia-smi health line (expected 5 fields): {line}");
|
||||||
|
}
|
||||||
|
devices.push(DeviceHealth {
|
||||||
|
index: parts[0].parse().with_context(|| "invalid index")?,
|
||||||
|
vram_used_mb: parts[1].parse().with_context(|| "invalid vram_used")?,
|
||||||
|
vram_free_mb: parts[2].parse().with_context(|| "invalid vram_free")?,
|
||||||
|
utilization_pct: parts[3].parse().with_context(|| "invalid utilization")?,
|
||||||
|
temp_c: parts[4].parse().with_context(|| "invalid temp")?,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Ok(devices)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Command execution wrappers ──────────────────────────────────────
|
||||||
|
|
||||||
|
async fn run_command(cmd: &str, args: &[&str]) -> Result<String> {
|
||||||
|
let output = tokio::process::Command::new(cmd)
|
||||||
|
.args(args)
|
||||||
|
.output()
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("failed to execute {cmd}"))?;
|
||||||
|
|
||||||
|
if !output.status.success() {
|
||||||
|
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||||
|
anyhow::bail!("{cmd} failed: {stderr}");
|
||||||
|
}
|
||||||
|
Ok(String::from_utf8_lossy(&output.stdout).to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn run_command_optional(cmd: &str, args: &[&str]) -> Option<String> {
|
||||||
|
run_command(cmd, args).await.ok()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Discover the full system: hostname, OS, kernel, GPUs, CUDA version.
|
||||||
|
/// Handles nvidia-smi not found gracefully (returns empty devices).
|
||||||
|
pub async fn discover_system() -> Result<DiscoveryResponse> {
|
||||||
|
let hostname = run_command("uname", &["-n"])
|
||||||
|
.await
|
||||||
|
.unwrap_or_else(|_| "unknown".into())
|
||||||
|
.trim()
|
||||||
|
.to_string();
|
||||||
|
let os = run_command("uname", &["-s"])
|
||||||
|
.await
|
||||||
|
.unwrap_or_else(|_| "unknown".into())
|
||||||
|
.trim()
|
||||||
|
.to_string();
|
||||||
|
let kernel = run_command("uname", &["-r"])
|
||||||
|
.await
|
||||||
|
.unwrap_or_else(|_| "unknown".into())
|
||||||
|
.trim()
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
let (devices, driver_version) = match run_command_optional(
|
||||||
|
"nvidia-smi",
|
||||||
|
&[
|
||||||
|
&format!("--query-gpu={NVIDIA_SMI_DISCOVERY_QUERY}"),
|
||||||
|
"--format=csv,noheader,nounits",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Some(output) => {
|
||||||
|
let devs = parse_gpu_info(&output).unwrap_or_default();
|
||||||
|
let driver = parse_driver_version(&output);
|
||||||
|
(devs, driver)
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
tracing::info!("nvidia-smi not found — no GPU devices discovered");
|
||||||
|
(vec![], None)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let cuda_version = match run_command_optional("nvcc", &["--version"]).await {
|
||||||
|
Some(output) => parse_cuda_version(&output),
|
||||||
|
None => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(DiscoveryResponse {
|
||||||
|
hostname,
|
||||||
|
os,
|
||||||
|
kernel,
|
||||||
|
cuda_version,
|
||||||
|
driver_version,
|
||||||
|
devices,
|
||||||
|
harnesses: vec![], // populated by harness registry in Phase 8
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run nvidia-smi health query and parse the output.
|
||||||
|
pub async fn query_health() -> Result<Vec<DeviceHealth>> {
|
||||||
|
let output = run_command(
|
||||||
|
"nvidia-smi",
|
||||||
|
&[
|
||||||
|
&format!("--query-gpu={NVIDIA_SMI_HEALTH_QUERY}"),
|
||||||
|
"--format=csv,noheader,nounits",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
parse_health_info(&output)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Tests ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_gpu_info_single_gpu() {
|
||||||
|
let csv = "0, NVIDIA GeForce RTX 4090, 24564, 8.9, 570.86.16\n";
|
||||||
|
let devices = parse_gpu_info(csv).unwrap();
|
||||||
|
assert_eq!(devices.len(), 1);
|
||||||
|
assert_eq!(devices[0].index, 0);
|
||||||
|
assert_eq!(devices[0].name, "NVIDIA GeForce RTX 4090");
|
||||||
|
assert_eq!(devices[0].vram_total_mb, 24564);
|
||||||
|
assert_eq!(devices[0].compute_capability, "8.9");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_gpu_info_multi_gpu() {
|
||||||
|
let csv = "\
|
||||||
|
0, NVIDIA GeForce RTX 5090, 32614, 12.0, 570.86.16\n\
|
||||||
|
1, NVIDIA GeForce RTX 5090, 32614, 12.0, 570.86.16\n";
|
||||||
|
let devices = parse_gpu_info(csv).unwrap();
|
||||||
|
assert_eq!(devices.len(), 2);
|
||||||
|
assert_eq!(devices[0].index, 0);
|
||||||
|
assert_eq!(devices[1].index, 1);
|
||||||
|
assert_eq!(devices[0].vram_total_mb, 32614);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_gpu_info_empty() {
|
||||||
|
let devices = parse_gpu_info("").unwrap();
|
||||||
|
assert!(devices.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_gpu_info_malformed() {
|
||||||
|
let result = parse_gpu_info("garbage data");
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_driver_version() {
|
||||||
|
let csv = "0, NVIDIA GeForce RTX 4090, 24564, 8.9, 570.86.16\n";
|
||||||
|
assert_eq!(parse_driver_version(csv), Some("570.86.16".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_cuda_version() {
|
||||||
|
let nvcc = "\
|
||||||
|
nvcc: NVIDIA (R) Cuda compiler driver\n\
|
||||||
|
Copyright (c) 2005-2024 NVIDIA Corporation\n\
|
||||||
|
Built on Thu_Sep_12_02:18:05_PDT_2024\n\
|
||||||
|
Cuda compilation tools, release 12.8, V12.8.93\n";
|
||||||
|
assert_eq!(parse_cuda_version(nvcc), Some("12.8".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_cuda_version_missing() {
|
||||||
|
assert_eq!(parse_cuda_version("unrelated output"), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_health_info() {
|
||||||
|
let csv = "0, 8192, 16372, 45, 62\n";
|
||||||
|
let health = parse_health_info(csv).unwrap();
|
||||||
|
assert_eq!(health.len(), 1);
|
||||||
|
assert_eq!(health[0].index, 0);
|
||||||
|
assert_eq!(health[0].vram_used_mb, 8192);
|
||||||
|
assert_eq!(health[0].vram_free_mb, 16372);
|
||||||
|
assert_eq!(health[0].utilization_pct, 45);
|
||||||
|
assert_eq!(health[0].temp_c, 62);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_health_info_multi_gpu() {
|
||||||
|
let csv = "\
|
||||||
|
0, 8192, 24372, 45, 62\n\
|
||||||
|
1, 4096, 28468, 30, 58\n";
|
||||||
|
let health = parse_health_info(csv).unwrap();
|
||||||
|
assert_eq!(health.len(), 2);
|
||||||
|
assert_eq!(health[1].vram_used_mb, 4096);
|
||||||
|
assert_eq!(health[1].temp_c, 58);
|
||||||
|
}
|
||||||
|
}
|
||||||
23
crates/neuron/src/harness/arch/mod.rs
Normal file
23
crates/neuron/src/harness/arch/mod.rs
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
//! Custom architecture implementations.
|
||||||
|
//!
|
||||||
|
//! When candle-transformers ships a model family unchanged
|
||||||
|
//! (`models::llama`, `models::qwen3`, `models::qwen3_moe`, etc.), the
|
||||||
|
//! handler in `harness/candle.rs` just wraps the upstream type in a
|
||||||
|
//! `ModelArch` variant.
|
||||||
|
//!
|
||||||
|
//! When candle has nothing for the architecture and we have to write
|
||||||
|
//! it from scratch — Qwen3-Next / Qwen3.6 (`qwen3_5`) being the
|
||||||
|
//! motivating example — the implementation lands here, one file per
|
||||||
|
//! architecture.
|
||||||
|
//!
|
||||||
|
//! Each architecture module is expected to expose:
|
||||||
|
//! - A `Config` type deserialised from the model's `config.json`
|
||||||
|
//! (some architectures nest the real hyperparams under `text_config`,
|
||||||
|
//! in which case the module owns the unwrapping).
|
||||||
|
//! - A `ForCausalLM` struct with `new`, `forward(&mut self, x, offset)
|
||||||
|
//! -> Result<Tensor>`, and `clear_kv_cache(&mut self)`.
|
||||||
|
//!
|
||||||
|
//! TP-aware analogues live in `harness/tp/tp_<family>.rs` and follow
|
||||||
|
//! the pattern set by `tp_qwen3.rs`.
|
||||||
|
|
||||||
|
pub mod qwen3_5;
|
||||||
117
crates/neuron/src/harness/arch/qwen3_5/decoder.rs
Normal file
117
crates/neuron/src/harness/arch/qwen3_5/decoder.rs
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
//! Qwen3-Next decoder layer.
|
||||||
|
//!
|
||||||
|
//! Standard pre-norm transformer block (LN → attention → residual →
|
||||||
|
//! LN → MLP → residual) where the attention slot dispatches on the
|
||||||
|
//! per-layer `layer_types[i]` value in the config:
|
||||||
|
//!
|
||||||
|
//! - `"full_attention"` → [`Qwen3_5Attention`] (GQA causal + output
|
||||||
|
//! gate + RoPE + KV cache).
|
||||||
|
//! - `"linear_attention"` → [`GatedDeltaNet`] (recurrent delta rule +
|
||||||
|
//! causal conv + per-head state).
|
||||||
|
//!
|
||||||
|
//! In Qwen3.6-27B every 4th layer is full_attention; the rest are
|
||||||
|
//! linear_attention. `full_attention_interval` in the config is a
|
||||||
|
//! hint; `layer_types` is authoritative.
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use candle_core::{Module, Tensor};
|
||||||
|
use candle_nn::var_builder::ShardedVarBuilder;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use super::TextConfig;
|
||||||
|
use super::full_attn::Qwen3_5Attention;
|
||||||
|
use super::linear_attn::GatedDeltaNet;
|
||||||
|
use super::mlp::Qwen3_5MLP;
|
||||||
|
use super::rmsnorm::Qwen3_5RmsNorm;
|
||||||
|
use super::rope::RotaryEmbedding;
|
||||||
|
|
||||||
|
/// One of the two attention flavours sitting in a decoder layer's
|
||||||
|
/// attention slot. Full-attention layers need the rotary table and
|
||||||
|
/// take an attention mask; linear-attention layers carry their own
|
||||||
|
/// recurrent state and ignore the mask.
|
||||||
|
enum AttentionKind {
|
||||||
|
Full(Qwen3_5Attention),
|
||||||
|
Linear(GatedDeltaNet),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Qwen3_5DecoderLayer {
|
||||||
|
input_layernorm: Qwen3_5RmsNorm,
|
||||||
|
post_attention_layernorm: Qwen3_5RmsNorm,
|
||||||
|
mlp: Qwen3_5MLP,
|
||||||
|
attention: AttentionKind,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Qwen3_5DecoderLayer {
|
||||||
|
pub fn load(
|
||||||
|
cfg: &TextConfig,
|
||||||
|
rotary: Arc<RotaryEmbedding>,
|
||||||
|
layer_idx: usize,
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let layer_type = cfg
|
||||||
|
.layer_types
|
||||||
|
.get(layer_idx)
|
||||||
|
.map(String::as_str)
|
||||||
|
.ok_or_else(|| {
|
||||||
|
anyhow::anyhow!(
|
||||||
|
"layer_types[{layer_idx}] missing (have {} entries)",
|
||||||
|
cfg.layer_types.len()
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let attention = match layer_type {
|
||||||
|
"full_attention" => {
|
||||||
|
AttentionKind::Full(Qwen3_5Attention::load(cfg, rotary, &vb.pp("self_attn"))?)
|
||||||
|
}
|
||||||
|
"linear_attention" => {
|
||||||
|
AttentionKind::Linear(GatedDeltaNet::load(cfg, &vb.pp("linear_attn"))?)
|
||||||
|
}
|
||||||
|
other => anyhow::bail!(
|
||||||
|
"unknown layer_type '{other}' for layer {layer_idx} (expected \
|
||||||
|
'full_attention' or 'linear_attention')"
|
||||||
|
),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mlp = Qwen3_5MLP::load(cfg, &vb.pp("mlp"))?;
|
||||||
|
let input_layernorm =
|
||||||
|
Qwen3_5RmsNorm::load(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?;
|
||||||
|
let post_attention_layernorm = Qwen3_5RmsNorm::load(
|
||||||
|
&vb.pp("post_attention_layernorm"),
|
||||||
|
cfg.hidden_size,
|
||||||
|
cfg.rms_norm_eps,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
input_layernorm,
|
||||||
|
post_attention_layernorm,
|
||||||
|
mlp,
|
||||||
|
attention,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(
|
||||||
|
&mut self,
|
||||||
|
x: &Tensor,
|
||||||
|
attn_mask: Option<&Tensor>,
|
||||||
|
offset: usize,
|
||||||
|
) -> candle_core::Result<Tensor> {
|
||||||
|
let h = self.input_layernorm.forward(x)?;
|
||||||
|
let attn_out = match &mut self.attention {
|
||||||
|
AttentionKind::Full(attn) => attn.forward(&h, attn_mask, offset)?,
|
||||||
|
// Linear attention ignores attn_mask + offset; its causal
|
||||||
|
// structure is baked into the recurrent state lifecycle.
|
||||||
|
AttentionKind::Linear(net) => net.forward(&h)?,
|
||||||
|
};
|
||||||
|
let x = (x + attn_out)?;
|
||||||
|
let h2 = self.post_attention_layernorm.forward(&x)?;
|
||||||
|
let h2 = self.mlp.forward(&h2)?;
|
||||||
|
x + h2
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
match &mut self.attention {
|
||||||
|
AttentionKind::Full(attn) => attn.clear_kv_cache(),
|
||||||
|
AttentionKind::Linear(net) => net.clear_kv_cache(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
179
crates/neuron/src/harness/arch/qwen3_5/full_attn.rs
Normal file
179
crates/neuron/src/harness/arch/qwen3_5/full_attn.rs
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
//! Qwen3-Next's `full_attention` layer.
|
||||||
|
//!
|
||||||
|
//! Standard GQA causal attention with two Qwen3-Next-specific quirks:
|
||||||
|
//!
|
||||||
|
//! 1. **Output gate (`attn_output_gate=True`).** `q_proj` is widened
|
||||||
|
//! to `num_heads * head_dim * 2`. The second half is reshaped to
|
||||||
|
//! `(B, L, num_heads * head_dim)` and fed through a sigmoid; the
|
||||||
|
//! attention output is pointwise-multiplied by this gate before
|
||||||
|
//! `o_proj`. Effectively a per-head per-position attenuation on
|
||||||
|
//! the attention output.
|
||||||
|
//!
|
||||||
|
//! 2. **`(1 + w) * x` RmsNorm** on q and k (see `rmsnorm::Qwen3_5RmsNorm`).
|
||||||
|
//! candle_nn's RmsNorm applies `w * x`; the upstream Qwen3-Next
|
||||||
|
//! checkpoints expect the `(1 + w)` form.
|
||||||
|
//!
|
||||||
|
//! Otherwise: GQA with `num_attention_heads / num_key_value_heads`
|
||||||
|
//! repeat, q_norm + k_norm on the head dim, GLM-style rotary (see
|
||||||
|
//! `rope::RotaryEmbedding`), and the usual causal mask.
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use candle_core::{Module, Tensor};
|
||||||
|
use candle_nn::Linear;
|
||||||
|
use candle_nn::kv_cache::ConcatKvCache;
|
||||||
|
use candle_nn::var_builder::ShardedVarBuilder;
|
||||||
|
use candle_transformers::utils::repeat_kv;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use super::TextConfig;
|
||||||
|
use super::rmsnorm::Qwen3_5RmsNorm;
|
||||||
|
use super::rope::RotaryEmbedding;
|
||||||
|
|
||||||
|
pub struct Qwen3_5Attention {
|
||||||
|
q_proj: Linear,
|
||||||
|
k_proj: Linear,
|
||||||
|
v_proj: Linear,
|
||||||
|
o_proj: Linear,
|
||||||
|
q_norm: Qwen3_5RmsNorm,
|
||||||
|
k_norm: Qwen3_5RmsNorm,
|
||||||
|
num_heads: usize,
|
||||||
|
num_kv_heads: usize,
|
||||||
|
num_kv_groups: usize,
|
||||||
|
head_dim: usize,
|
||||||
|
hidden_size: usize,
|
||||||
|
rotary: Arc<RotaryEmbedding>,
|
||||||
|
kv_cache: ConcatKvCache,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Qwen3_5Attention {
|
||||||
|
pub fn load(
|
||||||
|
cfg: &TextConfig,
|
||||||
|
rotary: Arc<RotaryEmbedding>,
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let head_dim = cfg.head_dim;
|
||||||
|
let num_heads = cfg.num_attention_heads;
|
||||||
|
let num_kv_heads = cfg.num_key_value_heads;
|
||||||
|
if num_kv_heads == 0 || !num_heads.is_multiple_of(num_kv_heads) {
|
||||||
|
anyhow::bail!(
|
||||||
|
"num_attention_heads ({num_heads}) must be a positive multiple of \
|
||||||
|
num_key_value_heads ({num_kv_heads})"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let num_kv_groups = num_heads / num_kv_heads;
|
||||||
|
|
||||||
|
// q_proj is 2x wide: the extra `num_heads * head_dim` slice is
|
||||||
|
// the gate (see attn_output_gate notes above).
|
||||||
|
let q_proj = load_linear_no_bias(vb, "q_proj", cfg.hidden_size, num_heads * head_dim * 2)?;
|
||||||
|
let k_proj = load_linear_no_bias(vb, "k_proj", cfg.hidden_size, num_kv_heads * head_dim)?;
|
||||||
|
let v_proj = load_linear_no_bias(vb, "v_proj", cfg.hidden_size, num_kv_heads * head_dim)?;
|
||||||
|
let o_proj = load_linear_no_bias(vb, "o_proj", num_heads * head_dim, cfg.hidden_size)?;
|
||||||
|
|
||||||
|
let q_norm = Qwen3_5RmsNorm::load(&vb.pp("q_norm"), head_dim, cfg.rms_norm_eps)?;
|
||||||
|
let k_norm = Qwen3_5RmsNorm::load(&vb.pp("k_norm"), head_dim, cfg.rms_norm_eps)?;
|
||||||
|
|
||||||
|
let hidden_size = head_dim * num_heads;
|
||||||
|
let kv_cache = ConcatKvCache::new(2);
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
q_proj,
|
||||||
|
k_proj,
|
||||||
|
v_proj,
|
||||||
|
o_proj,
|
||||||
|
q_norm,
|
||||||
|
k_norm,
|
||||||
|
num_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
num_kv_groups,
|
||||||
|
head_dim,
|
||||||
|
hidden_size,
|
||||||
|
rotary,
|
||||||
|
kv_cache,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(
|
||||||
|
&mut self,
|
||||||
|
x: &Tensor,
|
||||||
|
attn_mask: Option<&Tensor>,
|
||||||
|
offset: usize,
|
||||||
|
) -> candle_core::Result<Tensor> {
|
||||||
|
let (b, l, _) = x.dims3()?;
|
||||||
|
|
||||||
|
// 1. q_proj — widened output, split into (query, gate).
|
||||||
|
let q_raw = self
|
||||||
|
.q_proj
|
||||||
|
.forward(x)?
|
||||||
|
.reshape((b, l, self.num_heads, self.head_dim * 2))?;
|
||||||
|
let q = q_raw.narrow(3, 0, self.head_dim)?;
|
||||||
|
let gate = q_raw.narrow(3, self.head_dim, self.head_dim)?;
|
||||||
|
// Flatten the gate's head dim back into hidden_size for the
|
||||||
|
// post-attention pointwise multiply.
|
||||||
|
let gate = gate
|
||||||
|
.contiguous()?
|
||||||
|
.reshape((b, l, self.num_heads * self.head_dim))?;
|
||||||
|
|
||||||
|
// 2. q_norm + k_norm + reshape to (B, H, L, D).
|
||||||
|
let q = self.q_norm.forward(&q.contiguous()?)?;
|
||||||
|
let q = q.transpose(1, 2)?.contiguous()?; // (B, H, L, D)
|
||||||
|
|
||||||
|
let k = self
|
||||||
|
.k_proj
|
||||||
|
.forward(x)?
|
||||||
|
.reshape((b, l, self.num_kv_heads, self.head_dim))?;
|
||||||
|
let k = self.k_norm.forward(&k.contiguous()?)?;
|
||||||
|
let k = k.transpose(1, 2)?.contiguous()?;
|
||||||
|
|
||||||
|
let v = self
|
||||||
|
.v_proj
|
||||||
|
.forward(x)?
|
||||||
|
.reshape((b, l, self.num_kv_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.contiguous()?;
|
||||||
|
|
||||||
|
// 3. RoPE on q, k.
|
||||||
|
let (q, k) = self.rotary.apply(&q, &k, offset)?;
|
||||||
|
|
||||||
|
// 4. KV cache.
|
||||||
|
let (k, v) = self.kv_cache.append(&k, &v)?;
|
||||||
|
|
||||||
|
// 5. GQA repeat (cheap shape op).
|
||||||
|
let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
|
||||||
|
let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;
|
||||||
|
|
||||||
|
// 6. Scaled dot-product + causal mask.
|
||||||
|
let scale = 1.0_f64 / (self.head_dim as f64).sqrt();
|
||||||
|
let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
|
||||||
|
if let Some(m) = attn_mask {
|
||||||
|
scores = scores.broadcast_add(m)?;
|
||||||
|
}
|
||||||
|
let probs = candle_nn::ops::softmax_last_dim(&scores)?;
|
||||||
|
let ctx = probs.matmul(&v)?; // (B, H, L, D)
|
||||||
|
|
||||||
|
// 7. Reshape back, apply the output gate, project.
|
||||||
|
let ctx = ctx
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.contiguous()?
|
||||||
|
.reshape((b, l, self.hidden_size))?;
|
||||||
|
let gate_sig = candle_nn::ops::sigmoid(&gate)?;
|
||||||
|
let gated = (ctx * gate_sig)?;
|
||||||
|
self.o_proj.forward(&gated)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
self.kv_cache.reset();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_linear_no_bias(
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
name: &str,
|
||||||
|
in_dim: usize,
|
||||||
|
out_dim: usize,
|
||||||
|
) -> Result<Linear> {
|
||||||
|
let weight = vb
|
||||||
|
.pp(name)
|
||||||
|
.get((out_dim, in_dim), "weight")
|
||||||
|
.with_context(|| format!("load '{}/{name}/weight'", vb.prefix()))?;
|
||||||
|
Ok(Linear::new(weight, None))
|
||||||
|
}
|
||||||
793
crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs
Normal file
793
crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs
Normal file
@@ -0,0 +1,793 @@
|
|||||||
|
//! Qwen3-Next's `linear_attention` layer: Gated DeltaNet.
|
||||||
|
//!
|
||||||
|
//! The recurrent linear-attention block that occupies 3 out of every 4
|
||||||
|
//! decoder layers in Qwen3.6 (`layer_types[i] == "linear_attention"`).
|
||||||
|
//! Implemented against the reference Python in
|
||||||
|
//! `huggingface/transformers/src/transformers/models/qwen3_5/modeling_qwen3_5.py`
|
||||||
|
//! (class `Qwen3_5GatedDeltaNet`).
|
||||||
|
//!
|
||||||
|
//! ## Block structure
|
||||||
|
//!
|
||||||
|
//! ```text
|
||||||
|
//! x ── in_proj_qkv ── transpose ─► (B, conv_dim, L)
|
||||||
|
//! │
|
||||||
|
//! ┌──────────────── conv_state ──┤ prepend cached state (decode)
|
||||||
|
//! ▼
|
||||||
|
//! depthwise causal Conv1d (k=4) → SiLU
|
||||||
|
//! │
|
||||||
|
//! └─ split → q (k_dim), k (k_dim), v (v_dim) ─► per-head reshape
|
||||||
|
//!
|
||||||
|
//! x ── in_proj_z ────────────────► z (gate for the output RMSNorm)
|
||||||
|
//! x ── in_proj_b ── sigmoid ─────► beta (per-head per-token update rate)
|
||||||
|
//! x ── in_proj_a ── softplus ────► g (decay; see eqn below)
|
||||||
|
//!
|
||||||
|
//! g = -exp(A_log) * softplus(a + dt_bias) # discretisation
|
||||||
|
//! beta = sigmoid(b)
|
||||||
|
//!
|
||||||
|
//! (q, k) ─── L2norm ─── delta rule loop ──── core_attn_out
|
||||||
|
//! (per-token, per-head):
|
||||||
|
//! state *= exp(g_t)
|
||||||
|
//! mem = state^T · k_t
|
||||||
|
//! delta = (v_t - mem) * beta_t
|
||||||
|
//! state += outer(k_t, delta)
|
||||||
|
//! out_t = state^T · q_t
|
||||||
|
//!
|
||||||
|
//! core_attn_out ── RMSNormGated(z) ── reshape ── out_proj ── y
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! ## State
|
||||||
|
//!
|
||||||
|
//! Two tensors persist across decode steps:
|
||||||
|
//! - `conv_state`: `(B, conv_dim, conv_kernel_size)` — left-padded
|
||||||
|
//! tail of the input to the depthwise conv, so the next causal
|
||||||
|
//! window has the right left-context.
|
||||||
|
//! - `recurrent_state`: `(B, num_v_heads, head_k_dim, head_v_dim)` —
|
||||||
|
//! the delta-rule outer-product memory.
|
||||||
|
//!
|
||||||
|
//! Both are cleared via [`GatedDeltaNet::clear_kv_cache`] at the start
|
||||||
|
//! of every new request.
|
||||||
|
//!
|
||||||
|
//! ## Performance note
|
||||||
|
//!
|
||||||
|
//! This impl is the **recurrent** delta-rule for both prefill and
|
||||||
|
//! decode — i.e. the algorithm in `torch_recurrent_gated_delta_rule`.
|
||||||
|
//! Correctness-first. The chunked algorithm (chunk_size=64) in
|
||||||
|
//! `torch_chunk_gated_delta_rule` is a perf optimisation for long
|
||||||
|
//! prefill; can be added later without changing the surface.
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use candle_core::{Module, Tensor};
|
||||||
|
use candle_nn::Linear;
|
||||||
|
use candle_nn::var_builder::ShardedVarBuilder;
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
use super::RopeParameters;
|
||||||
|
use super::TextConfig;
|
||||||
|
use super::rmsnorm::{Qwen3_5RmsNormGated, l2norm};
|
||||||
|
|
||||||
|
/// Per-rank, per-layer state for the linear-attention block.
|
||||||
|
///
|
||||||
|
/// `conv_state` is left-padded with zeros on first use; `recurrent_state`
|
||||||
|
/// is initialised lazily to zeros once we know the batch size.
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct GatedDeltaNetState {
|
||||||
|
pub conv_state: Option<Tensor>,
|
||||||
|
pub recurrent_state: Option<Tensor>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct GatedDeltaNet {
|
||||||
|
// Projections.
|
||||||
|
in_proj_qkv: Linear,
|
||||||
|
in_proj_z: Linear,
|
||||||
|
in_proj_b: Linear,
|
||||||
|
in_proj_a: Linear,
|
||||||
|
out_proj: Linear,
|
||||||
|
|
||||||
|
// Depthwise causal Conv1d weight; shape (conv_dim, 1, kernel_size).
|
||||||
|
// No bias (Python sets bias=False).
|
||||||
|
conv1d_weight: Tensor,
|
||||||
|
|
||||||
|
// Per-head discretisation params.
|
||||||
|
dt_bias: Tensor,
|
||||||
|
a_log: Tensor,
|
||||||
|
|
||||||
|
// Output norm + gate.
|
||||||
|
norm: Qwen3_5RmsNormGated,
|
||||||
|
|
||||||
|
// Shape hyperparams (cached for forward).
|
||||||
|
num_v_heads: usize,
|
||||||
|
num_k_heads: usize,
|
||||||
|
head_k_dim: usize,
|
||||||
|
head_v_dim: usize,
|
||||||
|
key_dim: usize,
|
||||||
|
value_dim: usize,
|
||||||
|
conv_dim: usize,
|
||||||
|
conv_kernel_size: usize,
|
||||||
|
|
||||||
|
// Recurrent state held inline. Each request resets via
|
||||||
|
// `clear_kv_cache`; otherwise the state persists across forwards
|
||||||
|
// and the per-token offset advances naturally.
|
||||||
|
state: GatedDeltaNetState,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GatedDeltaNet {
|
||||||
|
pub fn load(cfg: &TextConfig, vb: &ShardedVarBuilder) -> Result<Self> {
|
||||||
|
let num_v_heads = cfg.linear_num_value_heads;
|
||||||
|
let num_k_heads = cfg.linear_num_key_heads;
|
||||||
|
let head_k_dim = cfg.linear_key_head_dim;
|
||||||
|
let head_v_dim = cfg.linear_value_head_dim;
|
||||||
|
let conv_kernel_size = cfg.linear_conv_kernel_dim;
|
||||||
|
|
||||||
|
if num_v_heads == 0 || num_k_heads == 0 {
|
||||||
|
anyhow::bail!(
|
||||||
|
"Qwen3-Next linear_num_*_heads must be set; got v={num_v_heads}, k={num_k_heads}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if !num_v_heads.is_multiple_of(num_k_heads) {
|
||||||
|
anyhow::bail!(
|
||||||
|
"linear_num_value_heads ({num_v_heads}) must be a multiple of \
|
||||||
|
linear_num_key_heads ({num_k_heads}) for GQA-style head expansion"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let key_dim = head_k_dim * num_k_heads;
|
||||||
|
let value_dim = head_v_dim * num_v_heads;
|
||||||
|
let conv_dim = key_dim * 2 + value_dim;
|
||||||
|
|
||||||
|
// ----- Linear projections (all `bias=False` in the reference). -----
|
||||||
|
let in_proj_qkv = load_linear_no_bias(vb, "in_proj_qkv", cfg.hidden_size, conv_dim)?;
|
||||||
|
let in_proj_z = load_linear_no_bias(vb, "in_proj_z", cfg.hidden_size, value_dim)?;
|
||||||
|
let in_proj_b = load_linear_no_bias(vb, "in_proj_b", cfg.hidden_size, num_v_heads)?;
|
||||||
|
let in_proj_a = load_linear_no_bias(vb, "in_proj_a", cfg.hidden_size, num_v_heads)?;
|
||||||
|
let out_proj = load_linear_no_bias(vb, "out_proj", value_dim, cfg.hidden_size)?;
|
||||||
|
|
||||||
|
// ----- Conv1d weight (depthwise, bias=False). -----
|
||||||
|
let conv1d_weight = vb
|
||||||
|
.pp("conv1d")
|
||||||
|
.get((conv_dim, 1, conv_kernel_size), "weight")
|
||||||
|
.with_context(|| format!("load '{}/conv1d/weight'", vb.prefix()))?;
|
||||||
|
|
||||||
|
// ----- dt_bias + A_log: per-head 1D params. -----
|
||||||
|
let dt_bias = vb
|
||||||
|
.get(num_v_heads, "dt_bias")
|
||||||
|
.with_context(|| format!("load '{}/dt_bias'", vb.prefix()))?;
|
||||||
|
let a_log = vb
|
||||||
|
.get(num_v_heads, "A_log")
|
||||||
|
.with_context(|| format!("load '{}/A_log'", vb.prefix()))?;
|
||||||
|
|
||||||
|
// ----- Output gated RMSNorm (per-head_v_dim). -----
|
||||||
|
let norm = Qwen3_5RmsNormGated::load(&vb.pp("norm"), head_v_dim, cfg.rms_norm_eps)?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
in_proj_qkv,
|
||||||
|
in_proj_z,
|
||||||
|
in_proj_b,
|
||||||
|
in_proj_a,
|
||||||
|
out_proj,
|
||||||
|
conv1d_weight,
|
||||||
|
dt_bias,
|
||||||
|
a_log,
|
||||||
|
norm,
|
||||||
|
num_v_heads,
|
||||||
|
num_k_heads,
|
||||||
|
head_k_dim,
|
||||||
|
head_v_dim,
|
||||||
|
key_dim,
|
||||||
|
value_dim,
|
||||||
|
conv_dim,
|
||||||
|
conv_kernel_size,
|
||||||
|
state: GatedDeltaNetState::default(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
self.state = GatedDeltaNetState::default();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `x` shape: `(B, L, hidden_size)`. Returns the same shape.
|
||||||
|
pub fn forward(&mut self, x: &Tensor) -> candle_core::Result<Tensor> {
|
||||||
|
let (batch_size, seq_len, _) = x.dims3()?;
|
||||||
|
let dtype = x.dtype();
|
||||||
|
let device = x.device().clone();
|
||||||
|
|
||||||
|
// ----- Projections. -----
|
||||||
|
// mixed_qkv: (B, L, conv_dim)
|
||||||
|
let mixed_qkv = self.in_proj_qkv.forward(x)?;
|
||||||
|
// (B, conv_dim, L) for the conv1d.
|
||||||
|
let mixed_qkv_chw = mixed_qkv.transpose(1, 2)?.contiguous()?;
|
||||||
|
|
||||||
|
// z: (B, L, value_dim) → (B, L, num_v_heads, head_v_dim)
|
||||||
|
let z = self.in_proj_z.forward(x)?.reshape((
|
||||||
|
batch_size,
|
||||||
|
seq_len,
|
||||||
|
self.num_v_heads,
|
||||||
|
self.head_v_dim,
|
||||||
|
))?;
|
||||||
|
|
||||||
|
// b, a: (B, L, num_v_heads)
|
||||||
|
let b = self.in_proj_b.forward(x)?;
|
||||||
|
let a = self.in_proj_a.forward(x)?;
|
||||||
|
|
||||||
|
// ----- Depthwise causal Conv1d + SiLU (with state continuation). -----
|
||||||
|
// Dispatches to a cuda kernel that fuses conv1d + silu when
|
||||||
|
// available; falls back to candle's `conv1d` + `silu` on cpu.
|
||||||
|
let (conv_out, new_state) = run_causal_conv1d(
|
||||||
|
&mixed_qkv_chw,
|
||||||
|
&self.conv1d_weight,
|
||||||
|
self.state.conv_state.take(),
|
||||||
|
batch_size,
|
||||||
|
self.conv_dim,
|
||||||
|
seq_len,
|
||||||
|
self.conv_kernel_size,
|
||||||
|
)?;
|
||||||
|
self.state.conv_state = Some(new_state);
|
||||||
|
// Back to (B, L, conv_dim).
|
||||||
|
let mixed_qkv = conv_out.transpose(1, 2)?.contiguous()?;
|
||||||
|
|
||||||
|
// ----- Split into q, k, v. -----
|
||||||
|
let q = mixed_qkv.narrow(2, 0, self.key_dim)?;
|
||||||
|
let k = mixed_qkv.narrow(2, self.key_dim, self.key_dim)?;
|
||||||
|
let v = mixed_qkv.narrow(2, 2 * self.key_dim, self.value_dim)?;
|
||||||
|
|
||||||
|
let q = q.reshape((batch_size, seq_len, self.num_k_heads, self.head_k_dim))?;
|
||||||
|
let k = k.reshape((batch_size, seq_len, self.num_k_heads, self.head_k_dim))?;
|
||||||
|
let v = v.reshape((batch_size, seq_len, self.num_v_heads, self.head_v_dim))?;
|
||||||
|
|
||||||
|
// ----- beta + g (per-head, per-token gates). -----
|
||||||
|
// Fused on cuda; per-op Rust on cpu. Both paths produce:
|
||||||
|
// beta = sigmoid(b)
|
||||||
|
// g = -exp(A_log) * softplus(a + dt_bias)
|
||||||
|
let (beta, g) = run_fused_gating(&b, &a, &self.a_log, &self.dt_bias)?;
|
||||||
|
|
||||||
|
// ----- GQA-style key expansion if num_v_heads > num_k_heads. -----
|
||||||
|
let (q, k) = if self.num_v_heads > self.num_k_heads {
|
||||||
|
let rep = self.num_v_heads / self.num_k_heads;
|
||||||
|
(
|
||||||
|
repeat_interleave(&q, rep, 2)?,
|
||||||
|
repeat_interleave(&k, rep, 2)?,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
(q, k)
|
||||||
|
};
|
||||||
|
|
||||||
|
// ----- L2-norm on q, k (use_qk_l2norm_in_kernel=True in ref). -----
|
||||||
|
let q = l2norm(&q, 1e-6)?;
|
||||||
|
let k = l2norm(&k, 1e-6)?;
|
||||||
|
|
||||||
|
// ----- Recurrent delta rule. -----
|
||||||
|
// Inputs: q, k (B, L, H, D_k); v (B, L, H, D_v); g (B, L, H); beta (B, L, H).
|
||||||
|
// The reference transposes to (B, H, L, D) before the loop. We
|
||||||
|
// do the same — it makes per-token indexing trivial.
|
||||||
|
let q = q.transpose(1, 2)?.contiguous()?; // (B, H, L, D_k)
|
||||||
|
let k = k.transpose(1, 2)?.contiguous()?;
|
||||||
|
let v = v.transpose(1, 2)?.contiguous()?; // (B, H, L, D_v)
|
||||||
|
let g = g.transpose(1, 2)?.contiguous()?; // (B, H, L)
|
||||||
|
let beta = beta.transpose(1, 2)?.contiguous()?; // (B, H, L)
|
||||||
|
|
||||||
|
// Pre-scale q by 1/sqrt(D_k) once. Everything goes to f32 here
|
||||||
|
// since the delta rule mixes broadcast_mul ops that candle won't
|
||||||
|
// accept across mixed dtypes. On the cuda gating path both beta
|
||||||
|
// and g come back in model dtype; on the cpu path g is already
|
||||||
|
// f32 — both casts are cheap idempotent ops.
|
||||||
|
let scale = 1.0_f64 / (self.head_k_dim as f64).sqrt();
|
||||||
|
let q = (q.to_dtype(candle_core::DType::F32)? * scale)?;
|
||||||
|
let k = k.to_dtype(candle_core::DType::F32)?;
|
||||||
|
let v = v.to_dtype(candle_core::DType::F32)?;
|
||||||
|
let g = g.to_dtype(candle_core::DType::F32)?;
|
||||||
|
let beta = beta.to_dtype(candle_core::DType::F32)?;
|
||||||
|
|
||||||
|
// Initialise the recurrent state from cache or zeros.
|
||||||
|
let state_init = match self.state.recurrent_state.take() {
|
||||||
|
Some(s) => s.to_dtype(candle_core::DType::F32)?,
|
||||||
|
None => Tensor::zeros(
|
||||||
|
(
|
||||||
|
batch_size,
|
||||||
|
self.num_v_heads,
|
||||||
|
self.head_k_dim,
|
||||||
|
self.head_v_dim,
|
||||||
|
),
|
||||||
|
candle_core::DType::F32,
|
||||||
|
&device,
|
||||||
|
)?,
|
||||||
|
};
|
||||||
|
|
||||||
|
// The delta-rule body: cuda-accelerated `gated_delta_rule_recurrence`
|
||||||
|
// kernel when we have a cuda device + the kernels are linked in,
|
||||||
|
// pure-Rust per-token fallback otherwise.
|
||||||
|
let (core_attn_out, new_state) = run_delta_rule(
|
||||||
|
&q,
|
||||||
|
&k,
|
||||||
|
&v,
|
||||||
|
&g,
|
||||||
|
&beta,
|
||||||
|
state_init,
|
||||||
|
batch_size,
|
||||||
|
self.num_v_heads,
|
||||||
|
seq_len,
|
||||||
|
self.head_k_dim,
|
||||||
|
self.head_v_dim,
|
||||||
|
)?;
|
||||||
|
// Stash the updated recurrent state for the next call.
|
||||||
|
self.state.recurrent_state = Some(new_state.to_dtype(dtype)?);
|
||||||
|
|
||||||
|
// core_attn_out: (B, H, L, D_v) → (B, L, H, D_v) → (B*L*H, D_v).
|
||||||
|
let core_attn_out = core_attn_out.transpose(1, 2)?.contiguous()?; // (B, L, H, D_v)
|
||||||
|
let core_attn_out = core_attn_out.to_dtype(dtype)?;
|
||||||
|
let core_attn_flat =
|
||||||
|
core_attn_out.reshape((batch_size * seq_len * self.num_v_heads, self.head_v_dim))?;
|
||||||
|
let z_flat = z.reshape((batch_size * seq_len * self.num_v_heads, self.head_v_dim))?;
|
||||||
|
|
||||||
|
// RMSNormGated: (out * silu(z) * weight) with the norm.
|
||||||
|
let normed = self.norm.forward(&core_attn_flat, &z_flat)?;
|
||||||
|
let normed = normed.reshape((batch_size, seq_len, self.num_v_heads * self.head_v_dim))?;
|
||||||
|
|
||||||
|
// Output projection: (B, L, value_dim) → (B, L, hidden_size).
|
||||||
|
self.out_proj.forward(&normed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run the per-token delta-rule recurrence.
|
||||||
|
///
|
||||||
|
/// `q`, `k`: `(B, H, L, D_k)` (F32). `v`: `(B, H, L, D_v)`. `g`,
|
||||||
|
/// `beta`: `(B, H, L)`. `state`: `(B, H, D_k, D_v)`.
|
||||||
|
///
|
||||||
|
/// Returns `(core_attn_out: (B, H, L, D_v), state: (B, H, D_k, D_v))`,
|
||||||
|
/// both F32. Caller is responsible for cast back to model dtype.
|
||||||
|
///
|
||||||
|
/// Cuda path: dispatches to the `gated_delta_rule_recurrence` kernel
|
||||||
|
/// ported from `EricLBuehler/mistral.rs::mistralrs-core/src/cuda/gdn.cu`.
|
||||||
|
/// All five inputs must be cuda f32 tensors. The kernel is V-tiled
|
||||||
|
/// with compile-time BK; one block per (V-tile, batch*head) and one
|
||||||
|
/// thread per V-column. Each thread holds BK state floats in
|
||||||
|
/// registers — eliminates the launch-overhead floor we hit with
|
||||||
|
/// candle's per-op dispatch (was ~12s/token on Qwen3.6-27B).
|
||||||
|
///
|
||||||
|
/// CPU path: pure-Rust per-token loop. Correct, slow.
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub(crate) fn run_delta_rule(
|
||||||
|
q: &Tensor,
|
||||||
|
k: &Tensor,
|
||||||
|
v: &Tensor,
|
||||||
|
g: &Tensor,
|
||||||
|
beta: &Tensor,
|
||||||
|
state: Tensor,
|
||||||
|
batch_size: usize,
|
||||||
|
num_heads: usize,
|
||||||
|
seq_len: usize,
|
||||||
|
head_k_dim: usize,
|
||||||
|
head_v_dim: usize,
|
||||||
|
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
{
|
||||||
|
// Only dispatch to the kernel if the inputs are on a CUDA
|
||||||
|
// device — CPU tests fall back to the Rust loop below.
|
||||||
|
if q.device().is_cuda() {
|
||||||
|
return run_delta_rule_cuda(
|
||||||
|
q, k, v, g, beta, state, batch_size, num_heads, seq_len, head_k_dim, head_v_dim,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let _ = (batch_size, num_heads, head_k_dim, head_v_dim);
|
||||||
|
run_delta_rule_rust(q, k, v, g, beta, state, seq_len)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// CUDA path. Flattens (B, H, ...) → (BH, ...) at the kernel boundary
|
||||||
|
/// (the kernel uses BH = batch*heads as its outer batch axis) and
|
||||||
|
/// reshapes the kernel's outputs back to (B, H, ...) for the caller.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
fn run_delta_rule_cuda(
|
||||||
|
q: &Tensor,
|
||||||
|
k: &Tensor,
|
||||||
|
v: &Tensor,
|
||||||
|
g: &Tensor,
|
||||||
|
beta: &Tensor,
|
||||||
|
state: Tensor,
|
||||||
|
batch_size: usize,
|
||||||
|
num_heads: usize,
|
||||||
|
seq_len: usize,
|
||||||
|
head_k_dim: usize,
|
||||||
|
head_v_dim: usize,
|
||||||
|
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
|
let q_bh = q.flatten(0, 1)?.contiguous()?;
|
||||||
|
let k_bh = k.flatten(0, 1)?.contiguous()?;
|
||||||
|
let v_bh = v.flatten(0, 1)?.contiguous()?;
|
||||||
|
let g_bh = g.flatten(0, 1)?.contiguous()?;
|
||||||
|
let beta_bh = beta.flatten(0, 1)?.contiguous()?;
|
||||||
|
let mut state_bh = state.flatten(0, 1)?.contiguous()?;
|
||||||
|
// For long prefills, the chunked kernel (BT=64) processes a chunk
|
||||||
|
// of tokens at a time instead of one-by-one — same delta-rule math,
|
||||||
|
// far fewer block launches. Threshold matches mistralrs.
|
||||||
|
const CHUNK_THRESHOLD: usize = 64;
|
||||||
|
let output_bh = if seq_len >= CHUNK_THRESHOLD {
|
||||||
|
crate::cuda::gdn::chunked_gated_delta_rule_recurrence_cuda(
|
||||||
|
&q_bh,
|
||||||
|
&k_bh,
|
||||||
|
&v_bh,
|
||||||
|
&g_bh,
|
||||||
|
&beta_bh,
|
||||||
|
&mut state_bh,
|
||||||
|
)?
|
||||||
|
} else {
|
||||||
|
crate::cuda::gdn::gated_delta_rule_recurrence_cuda(
|
||||||
|
&q_bh,
|
||||||
|
&k_bh,
|
||||||
|
&v_bh,
|
||||||
|
&g_bh,
|
||||||
|
&beta_bh,
|
||||||
|
&mut state_bh,
|
||||||
|
)?
|
||||||
|
};
|
||||||
|
let core_attn_out = output_bh.reshape((batch_size, num_heads, seq_len, head_v_dim))?;
|
||||||
|
let new_state = state_bh.reshape((batch_size, num_heads, head_k_dim, head_v_dim))?;
|
||||||
|
Ok((core_attn_out, new_state))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
fn run_delta_rule_rust(
|
||||||
|
q: &Tensor,
|
||||||
|
k: &Tensor,
|
||||||
|
v: &Tensor,
|
||||||
|
g: &Tensor,
|
||||||
|
beta: &Tensor,
|
||||||
|
mut state: Tensor,
|
||||||
|
seq_len: usize,
|
||||||
|
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
|
use candle_core::IndexOp;
|
||||||
|
let mut outputs: Vec<Tensor> = Vec::with_capacity(seq_len);
|
||||||
|
for t in 0..seq_len {
|
||||||
|
let q_t = q.i((.., .., t, ..))?;
|
||||||
|
let k_t = k.i((.., .., t, ..))?;
|
||||||
|
let v_t = v.i((.., .., t, ..))?;
|
||||||
|
let g_t = g.i((.., .., t))?;
|
||||||
|
let beta_t = beta.i((.., .., t))?;
|
||||||
|
let decay = g_t
|
||||||
|
.exp()?
|
||||||
|
.unsqueeze(candle_core::D::Minus1)?
|
||||||
|
.unsqueeze(candle_core::D::Minus1)?;
|
||||||
|
state = state.broadcast_mul(&decay)?;
|
||||||
|
let k_col = k_t.unsqueeze(candle_core::D::Minus1)?;
|
||||||
|
let kv_mem = state.broadcast_mul(&k_col)?.sum(2)?;
|
||||||
|
let beta_col = beta_t.unsqueeze(candle_core::D::Minus1)?;
|
||||||
|
let delta = (v_t - kv_mem)?.broadcast_mul(&beta_col)?;
|
||||||
|
let delta_row = delta.unsqueeze(2)?;
|
||||||
|
let outer = k_col.broadcast_mul(&delta_row)?;
|
||||||
|
state = (state + outer)?;
|
||||||
|
let q_col = q_t.unsqueeze(candle_core::D::Minus1)?;
|
||||||
|
let out_t = state.broadcast_mul(&q_col)?.sum(2)?;
|
||||||
|
outputs.push(out_t.unsqueeze(2)?);
|
||||||
|
}
|
||||||
|
let core_attn_out = Tensor::cat(&outputs, 2)?; // (B, H, L, D_v)
|
||||||
|
Ok((core_attn_out, state))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Depthwise causal conv1d + SiLU, with rolling `conv_state`.
|
||||||
|
///
|
||||||
|
/// `x`: `(B, conv_dim, L)` model dtype (f16/bf16 on cuda, anything on cpu).
|
||||||
|
/// `weight`: `(conv_dim, 1, kernel_size)` model dtype.
|
||||||
|
/// `conv_state`: `Some((B, conv_dim, kernel_size))` for decode continuation,
|
||||||
|
/// or `None` for fresh prefill.
|
||||||
|
///
|
||||||
|
/// Returns `(conv_out: (B, conv_dim, L), new_conv_state: (B, conv_dim, kernel_size))`.
|
||||||
|
/// SiLU is baked in.
|
||||||
|
///
|
||||||
|
/// Cuda path: dispatches to `causal_conv1d_update` (decode, seq_len=1 with
|
||||||
|
/// existing state) or `causal_conv1d_full` (prefill / first call), both
|
||||||
|
/// ported from mistralrs `gdn.cu`. Each kernel fuses the depthwise conv
|
||||||
|
/// and SiLU activation in one launch — that's ~4× fewer cuda launches per
|
||||||
|
/// linear-attention layer than the candle `conv1d` + `silu` combo.
|
||||||
|
///
|
||||||
|
/// CPU path: the original prepend-narrow-conv1d-silu sequence.
|
||||||
|
pub(crate) fn run_causal_conv1d(
|
||||||
|
x: &Tensor,
|
||||||
|
weight: &Tensor,
|
||||||
|
conv_state: Option<Tensor>,
|
||||||
|
batch_size: usize,
|
||||||
|
conv_dim: usize,
|
||||||
|
seq_len: usize,
|
||||||
|
conv_kernel_size: usize,
|
||||||
|
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
{
|
||||||
|
if x.device().is_cuda() {
|
||||||
|
return run_causal_conv1d_cuda(
|
||||||
|
x,
|
||||||
|
weight,
|
||||||
|
conv_state,
|
||||||
|
batch_size,
|
||||||
|
conv_dim,
|
||||||
|
seq_len,
|
||||||
|
conv_kernel_size,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
run_causal_conv1d_rust(
|
||||||
|
x,
|
||||||
|
weight,
|
||||||
|
conv_state,
|
||||||
|
batch_size,
|
||||||
|
conv_dim,
|
||||||
|
seq_len,
|
||||||
|
conv_kernel_size,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
fn run_causal_conv1d_cuda(
|
||||||
|
x: &Tensor,
|
||||||
|
weight: &Tensor,
|
||||||
|
conv_state: Option<Tensor>,
|
||||||
|
batch_size: usize,
|
||||||
|
conv_dim: usize,
|
||||||
|
seq_len: usize,
|
||||||
|
conv_kernel_size: usize,
|
||||||
|
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
|
// Kernel expects weight as (conv_dim, kernel_size) — squeeze the
|
||||||
|
// depthwise channel-multiplier dim.
|
||||||
|
let w = weight.squeeze(1)?.to_dtype(x.dtype())?.contiguous()?;
|
||||||
|
|
||||||
|
// Decode path: seq_len == 1 AND we have an existing conv_state.
|
||||||
|
// Otherwise (prefill or fresh-start decode), use the full path which
|
||||||
|
// zero-pads on the left internally.
|
||||||
|
if let Some(cs) = conv_state
|
||||||
|
&& seq_len == 1
|
||||||
|
{
|
||||||
|
let cs = cs.contiguous()?;
|
||||||
|
let (output, new_conv_state) =
|
||||||
|
crate::cuda::gdn::causal_conv1d_cuda(x, &w, &cs, conv_kernel_size, true)?;
|
||||||
|
return Ok((output, new_conv_state));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prefill / fresh-start: the kernel ignores any prior conv_state and
|
||||||
|
// zero-pads. If we had a non-zero prior state and >1 input tokens
|
||||||
|
// (multi-turn continuation), we'd need to fall back to Rust. Match
|
||||||
|
// mistralrs's behaviour: fresh prefill always.
|
||||||
|
let device = x.device().clone();
|
||||||
|
let zeros_cs = Tensor::zeros((batch_size, conv_dim, conv_kernel_size), x.dtype(), &device)?;
|
||||||
|
let (output, new_conv_state) =
|
||||||
|
crate::cuda::gdn::causal_conv1d_cuda(x, &w, &zeros_cs, conv_kernel_size, false)?;
|
||||||
|
Ok((output, new_conv_state))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Fused GDN gating: computes `beta = sigmoid(b)` and
|
||||||
|
/// `g = -exp(a_log) * softplus(a + dt_bias)` together.
|
||||||
|
///
|
||||||
|
/// `b`, `a`: `(B, L, num_heads)` model dtype.
|
||||||
|
/// `a_log`, `dt_bias`: `(num_heads,)` model dtype (cast to f32 internally).
|
||||||
|
///
|
||||||
|
/// Returns `(beta, g)` both in model dtype on the cuda path, both in f32
|
||||||
|
/// on the cpu fallback. The caller casts to f32 before the delta rule.
|
||||||
|
///
|
||||||
|
/// Cuda path: dispatches to `fused_gdn_gating_cuda` — one kernel
|
||||||
|
/// replaces sigmoid + neg(exp) + softplus + broadcast_mul (≈10 candle
|
||||||
|
/// launches per layer).
|
||||||
|
pub(crate) fn run_fused_gating(
|
||||||
|
b: &Tensor,
|
||||||
|
a: &Tensor,
|
||||||
|
a_log: &Tensor,
|
||||||
|
dt_bias: &Tensor,
|
||||||
|
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
{
|
||||||
|
if b.device().is_cuda() {
|
||||||
|
let a_log_f32 = a_log.to_dtype(candle_core::DType::F32)?.contiguous()?;
|
||||||
|
let dt_bias_f32 = dt_bias.to_dtype(candle_core::DType::F32)?.contiguous()?;
|
||||||
|
return crate::cuda::gdn::fused_gdn_gating_cuda(b, a, &a_log_f32, &dt_bias_f32);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
run_fused_gating_rust(b, a, a_log, dt_bias)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_fused_gating_rust(
|
||||||
|
b: &Tensor,
|
||||||
|
a: &Tensor,
|
||||||
|
a_log: &Tensor,
|
||||||
|
dt_bias: &Tensor,
|
||||||
|
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
|
let beta = candle_nn::ops::sigmoid(b)?;
|
||||||
|
let a_log_f32 = a_log.to_dtype(candle_core::DType::F32)?;
|
||||||
|
let neg_a_exp = a_log_f32.exp()?.neg()?;
|
||||||
|
let dt_b_f32 = dt_bias.to_dtype(candle_core::DType::F32)?;
|
||||||
|
let a_f32 = a.to_dtype(candle_core::DType::F32)?;
|
||||||
|
let a_plus_dt = a_f32.broadcast_add(&dt_b_f32)?;
|
||||||
|
let softplus_val = softplus(&a_plus_dt)?;
|
||||||
|
let neg_a_exp_b = neg_a_exp.unsqueeze(0)?.unsqueeze(0)?;
|
||||||
|
let g = neg_a_exp_b.broadcast_mul(&softplus_val)?;
|
||||||
|
Ok((beta, g))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_causal_conv1d_rust(
|
||||||
|
x: &Tensor,
|
||||||
|
weight: &Tensor,
|
||||||
|
conv_state: Option<Tensor>,
|
||||||
|
batch_size: usize,
|
||||||
|
conv_dim: usize,
|
||||||
|
seq_len: usize,
|
||||||
|
conv_kernel_size: usize,
|
||||||
|
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
|
let dtype = x.dtype();
|
||||||
|
let device = x.device().clone();
|
||||||
|
|
||||||
|
let prepended = match &conv_state {
|
||||||
|
Some(prev) => Tensor::cat(&[prev, x], 2)?,
|
||||||
|
None => x.clone(),
|
||||||
|
};
|
||||||
|
let prep_len = prepended.dims()[2];
|
||||||
|
|
||||||
|
let new_state = if prep_len >= conv_kernel_size {
|
||||||
|
prepended.narrow(2, prep_len - conv_kernel_size, conv_kernel_size)?
|
||||||
|
} else {
|
||||||
|
let pad = Tensor::zeros(
|
||||||
|
(batch_size, conv_dim, conv_kernel_size - prep_len),
|
||||||
|
dtype,
|
||||||
|
&device,
|
||||||
|
)?;
|
||||||
|
Tensor::cat(&[&pad, &prepended], 2)?
|
||||||
|
};
|
||||||
|
|
||||||
|
let conv_out = prepended.conv1d(weight, conv_kernel_size - 1, 1, 1, conv_dim)?;
|
||||||
|
let conv_out = conv_out.narrow(2, 0, prep_len)?;
|
||||||
|
let conv_out = candle_nn::ops::silu(&conv_out)?;
|
||||||
|
let conv_out = conv_out.narrow(2, prep_len - seq_len, seq_len)?;
|
||||||
|
Ok((conv_out, new_state))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load a no-bias linear from the ShardedVarBuilder. Weight shape is
|
||||||
|
/// the standard `[out, in]` order.
|
||||||
|
fn load_linear_no_bias(
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
name: &str,
|
||||||
|
in_dim: usize,
|
||||||
|
out_dim: usize,
|
||||||
|
) -> Result<Linear> {
|
||||||
|
let weight = vb
|
||||||
|
.pp(name)
|
||||||
|
.get((out_dim, in_dim), "weight")
|
||||||
|
.with_context(|| format!("load '{}/{name}/weight'", vb.prefix()))?;
|
||||||
|
Ok(Linear::new(weight, None))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Numerically-stable `softplus(x) = ln(1 + exp(x))`. Matches PyTorch's
|
||||||
|
/// `F.softplus` default (beta=1, threshold=20: for large positive x,
|
||||||
|
/// returns x as-is to avoid overflow in the exp).
|
||||||
|
pub(crate) fn softplus(x: &Tensor) -> candle_core::Result<Tensor> {
|
||||||
|
let threshold = 20.0_f64;
|
||||||
|
let big = x.ge(threshold)?; // Tensor<u8> mask
|
||||||
|
let safe = x.minimum(&x.affine(0.0, 0.0)?.affine(0.0, threshold)?)?; // min(x, threshold)
|
||||||
|
let small = ((safe.exp()? + 1.0_f64)?).log()?;
|
||||||
|
// Select x where big, else small.
|
||||||
|
big.where_cond(x, &small)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `repeat_interleave` along a single dim. Candle has no built-in for
|
||||||
|
/// this; emulate with unsqueeze + expand + reshape.
|
||||||
|
pub(crate) fn repeat_interleave(
|
||||||
|
x: &Tensor,
|
||||||
|
repeats: usize,
|
||||||
|
dim: usize,
|
||||||
|
) -> candle_core::Result<Tensor> {
|
||||||
|
if repeats == 1 {
|
||||||
|
return Ok(x.clone());
|
||||||
|
}
|
||||||
|
let mut shape = x.dims().to_vec();
|
||||||
|
let orig = shape[dim];
|
||||||
|
shape.insert(dim + 1, repeats);
|
||||||
|
let mut expanded_shape = shape.clone();
|
||||||
|
expanded_shape[dim + 1] = repeats;
|
||||||
|
let x = x.unsqueeze(dim + 1)?;
|
||||||
|
let x = x.expand(expanded_shape)?;
|
||||||
|
let mut out_shape = x.dims().to_vec();
|
||||||
|
out_shape.remove(dim + 1);
|
||||||
|
out_shape[dim] = orig * repeats;
|
||||||
|
x.contiguous()?.reshape(out_shape)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use candle_core::{DType, Device};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn softplus_small_x() {
|
||||||
|
// softplus(0) = ln(2) ≈ 0.6931
|
||||||
|
let x = Tensor::new(&[0.0_f32], &Device::Cpu).unwrap();
|
||||||
|
let out: Vec<f32> = softplus(&x).unwrap().to_vec1().unwrap();
|
||||||
|
assert!((out[0] - 2.0_f32.ln()).abs() < 1e-4);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn softplus_large_x_returns_x() {
|
||||||
|
// For x = 30, softplus(x) ≈ x (the threshold branch).
|
||||||
|
let x = Tensor::new(&[30.0_f32], &Device::Cpu).unwrap();
|
||||||
|
let out: Vec<f32> = softplus(&x).unwrap().to_vec1().unwrap();
|
||||||
|
assert!((out[0] - 30.0).abs() < 1e-4);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn repeat_interleave_doubles_dim() {
|
||||||
|
let x = Tensor::new(&[[1.0_f32, 2.0], [3.0, 4.0]], &Device::Cpu).unwrap(); // shape (2, 2)
|
||||||
|
let out = repeat_interleave(&x, 2, 1).unwrap(); // each col duplicated
|
||||||
|
let v: Vec<Vec<f32>> = out.to_vec2().unwrap();
|
||||||
|
// Row 0: 1, 1, 2, 2
|
||||||
|
// Row 1: 3, 3, 4, 4
|
||||||
|
assert_eq!(v[0], vec![1.0, 1.0, 2.0, 2.0]);
|
||||||
|
assert_eq!(v[1], vec![3.0, 3.0, 4.0, 4.0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sanity: the recurrent path produces a finite tensor of the right
|
||||||
|
/// shape on tiny dimensions. Doesn't validate numerical correctness
|
||||||
|
/// against the Python reference — that would need a fixed-weight
|
||||||
|
/// fixture to compare against. Catches structural mistakes
|
||||||
|
/// (broadcasting shapes, off-by-one slices) early.
|
||||||
|
#[test]
|
||||||
|
fn forward_smoke_with_tiny_dimensions() {
|
||||||
|
let dev = Device::Cpu;
|
||||||
|
let dtype = DType::F32;
|
||||||
|
let (b, l) = (1, 3);
|
||||||
|
let cfg = TextConfig {
|
||||||
|
vocab_size: 100,
|
||||||
|
hidden_size: 16,
|
||||||
|
intermediate_size: 32,
|
||||||
|
num_hidden_layers: 1,
|
||||||
|
num_attention_heads: 4,
|
||||||
|
num_key_value_heads: 1,
|
||||||
|
head_dim: 4,
|
||||||
|
max_position_embeddings: 32,
|
||||||
|
rope_parameters: RopeParameters {
|
||||||
|
rope_theta: 10000.0,
|
||||||
|
partial_rotary_factor: 1.0,
|
||||||
|
rope_type: None,
|
||||||
|
},
|
||||||
|
rms_norm_eps: 1e-6,
|
||||||
|
tie_word_embeddings: false,
|
||||||
|
attn_output_gate: true,
|
||||||
|
layer_types: vec!["linear_attention".into()],
|
||||||
|
full_attention_interval: Some(4),
|
||||||
|
hidden_act: "silu".into(),
|
||||||
|
linear_num_value_heads: 4,
|
||||||
|
linear_num_key_heads: 2,
|
||||||
|
linear_key_head_dim: 4,
|
||||||
|
linear_value_head_dim: 4,
|
||||||
|
linear_conv_kernel_dim: 4,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Build a synthetic VarBuilder with all-zeros weights.
|
||||||
|
// Easier path: skip the load and construct GatedDeltaNet
|
||||||
|
// manually by hand-rolling the Linear/Tensor inputs.
|
||||||
|
let zeros = |shape: &[usize]| Tensor::zeros(shape, dtype, &dev).unwrap();
|
||||||
|
let key_dim = cfg.linear_key_head_dim * cfg.linear_num_key_heads;
|
||||||
|
let value_dim = cfg.linear_value_head_dim * cfg.linear_num_value_heads;
|
||||||
|
let conv_dim = key_dim * 2 + value_dim;
|
||||||
|
let mut net = GatedDeltaNet {
|
||||||
|
in_proj_qkv: Linear::new(zeros(&[conv_dim, cfg.hidden_size]), None),
|
||||||
|
in_proj_z: Linear::new(zeros(&[value_dim, cfg.hidden_size]), None),
|
||||||
|
in_proj_b: Linear::new(zeros(&[cfg.linear_num_value_heads, cfg.hidden_size]), None),
|
||||||
|
in_proj_a: Linear::new(zeros(&[cfg.linear_num_value_heads, cfg.hidden_size]), None),
|
||||||
|
out_proj: Linear::new(zeros(&[cfg.hidden_size, value_dim]), None),
|
||||||
|
conv1d_weight: zeros(&[conv_dim, 1, cfg.linear_conv_kernel_dim]),
|
||||||
|
dt_bias: zeros(&[cfg.linear_num_value_heads]),
|
||||||
|
a_log: zeros(&[cfg.linear_num_value_heads]),
|
||||||
|
norm: {
|
||||||
|
let weight = Tensor::ones(&[cfg.linear_value_head_dim], dtype, &dev).unwrap();
|
||||||
|
Qwen3_5RmsNormGated::from_weight(weight, cfg.rms_norm_eps)
|
||||||
|
},
|
||||||
|
num_v_heads: cfg.linear_num_value_heads,
|
||||||
|
num_k_heads: cfg.linear_num_key_heads,
|
||||||
|
head_k_dim: cfg.linear_key_head_dim,
|
||||||
|
head_v_dim: cfg.linear_value_head_dim,
|
||||||
|
key_dim,
|
||||||
|
value_dim,
|
||||||
|
conv_dim,
|
||||||
|
conv_kernel_size: cfg.linear_conv_kernel_dim,
|
||||||
|
state: GatedDeltaNetState::default(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let x = Tensor::ones(&[b, l, cfg.hidden_size], dtype, &dev).unwrap();
|
||||||
|
let y = net.forward(&x).unwrap();
|
||||||
|
assert_eq!(y.dims(), &[b, l, cfg.hidden_size]);
|
||||||
|
// All zero weights → output should be zero. Confirms no NaN/Inf
|
||||||
|
// poisoning from the f32 promotions.
|
||||||
|
let v: Vec<f32> = y.flatten_all().unwrap().to_vec1().unwrap();
|
||||||
|
assert!(v.iter().all(|x| x.is_finite()));
|
||||||
|
}
|
||||||
|
}
|
||||||
53
crates/neuron/src/harness/arch/qwen3_5/mlp.rs
Normal file
53
crates/neuron/src/harness/arch/qwen3_5/mlp.rs
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
//! SwiGLU MLP block for Qwen3-Next.
|
||||||
|
//!
|
||||||
|
//! Identical to plain Qwen3's MLP: `down(silu(gate(x)) * up(x))` with
|
||||||
|
//! no bias on any of the three projections.
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use candle_core::{Module, Tensor};
|
||||||
|
use candle_nn::Linear;
|
||||||
|
use candle_nn::var_builder::ShardedVarBuilder;
|
||||||
|
|
||||||
|
use super::TextConfig;
|
||||||
|
|
||||||
|
pub struct Qwen3_5MLP {
|
||||||
|
gate_proj: Linear,
|
||||||
|
up_proj: Linear,
|
||||||
|
down_proj: Linear,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Qwen3_5MLP {
|
||||||
|
pub fn load(cfg: &TextConfig, vb: &ShardedVarBuilder) -> Result<Self> {
|
||||||
|
let h = cfg.hidden_size;
|
||||||
|
let i = cfg.intermediate_size;
|
||||||
|
let gate_proj = load_linear_no_bias(vb, "gate_proj", h, i)?;
|
||||||
|
let up_proj = load_linear_no_bias(vb, "up_proj", h, i)?;
|
||||||
|
let down_proj = load_linear_no_bias(vb, "down_proj", i, h)?;
|
||||||
|
Ok(Self {
|
||||||
|
gate_proj,
|
||||||
|
up_proj,
|
||||||
|
down_proj,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for Qwen3_5MLP {
|
||||||
|
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
|
||||||
|
let lhs = candle_nn::ops::silu(&self.gate_proj.forward(x)?)?;
|
||||||
|
let rhs = self.up_proj.forward(x)?;
|
||||||
|
self.down_proj.forward(&(lhs * rhs)?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_linear_no_bias(
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
name: &str,
|
||||||
|
in_dim: usize,
|
||||||
|
out_dim: usize,
|
||||||
|
) -> Result<Linear> {
|
||||||
|
let weight = vb
|
||||||
|
.pp(name)
|
||||||
|
.get((out_dim, in_dim), "weight")
|
||||||
|
.with_context(|| format!("load '{}/{name}/weight'", vb.prefix()))?;
|
||||||
|
Ok(Linear::new(weight, None))
|
||||||
|
}
|
||||||
397
crates/neuron/src/harness/arch/qwen3_5/mod.rs
Normal file
397
crates/neuron/src/harness/arch/qwen3_5/mod.rs
Normal file
@@ -0,0 +1,397 @@
|
|||||||
|
//! Qwen3-Next (`model_type = "qwen3_5"`) architecture — Qwen3.6's
|
||||||
|
//! upstream architecture revision.
|
||||||
|
//!
|
||||||
|
//! ## Naming
|
||||||
|
//!
|
||||||
|
//! The model release this targets is `Qwen/Qwen3.6-*` but the
|
||||||
|
//! architecture name in HuggingFace's `config.json` is `qwen3_5`.
|
||||||
|
//! mistralrs calls the same architecture `qwen3_next`; that label
|
||||||
|
//! ages poorly the next time Qwen ship a new arch, so we key on the
|
||||||
|
//! canonical `qwen3_5` from the model's own config.
|
||||||
|
//!
|
||||||
|
//! ## Status
|
||||||
|
//!
|
||||||
|
//! **Single-GPU dense path is real**. Both attention flavours
|
||||||
|
//! (`full_attention` with the output-gated GQA causal attention and
|
||||||
|
//! `linear_attention` with the Gated DeltaNet recurrent block) are
|
||||||
|
//! implemented. The model loads from upstream safetensors via the
|
||||||
|
//! existing `load_arch_dense` dispatch and runs forward end to end.
|
||||||
|
//!
|
||||||
|
//! Numerical correctness vs the reference Python is **not yet
|
||||||
|
//! validated** — the structural code path is right, weight tensor
|
||||||
|
//! names match the upstream layout, shapes flow through cleanly, but
|
||||||
|
//! the Tbilisi probe (and any other downstream test) is the next
|
||||||
|
//! step. Likely places a bug would surface:
|
||||||
|
//! - Per-rank vs per-token-position offsets in the recurrent delta
|
||||||
|
//! rule (`linear_attn.rs`).
|
||||||
|
//! - Off-by-one in the conv state continuation across decode steps.
|
||||||
|
//! - RoPE phase mismatch from MRoPE simplification (we treat the
|
||||||
|
//! three position grids as collapsed, which is correct only for
|
||||||
|
//! text-only inference).
|
||||||
|
//!
|
||||||
|
//! ## Submodules
|
||||||
|
//!
|
||||||
|
//! - [`rmsnorm`] — `Qwen3_5RmsNorm` (`(1+w)*x` variant), the
|
||||||
|
//! `Qwen3_5RmsNormGated` used after the delta rule, and the
|
||||||
|
//! `l2norm` helper.
|
||||||
|
//! - [`rope`] — text-side rotary embedding (mrope simplified, GLM
|
||||||
|
//! rotate-half).
|
||||||
|
//! - [`mlp`] — SwiGLU MLP (gate/up/down, no bias).
|
||||||
|
//! - [`full_attn`] — `Qwen3_5Attention` with the output-gate
|
||||||
|
//! widening on `q_proj`.
|
||||||
|
//! - [`linear_attn`] — `GatedDeltaNet` recurrent delta-rule block
|
||||||
|
//! (causal depthwise Conv1d → silu → split → L2norm → per-token
|
||||||
|
//! delta rule → RMSNormGated → out_proj).
|
||||||
|
//! - [`decoder`] — `Qwen3_5DecoderLayer` dispatching to one of the
|
||||||
|
//! two attention flavours per layer index.
|
||||||
|
//!
|
||||||
|
//! ## Open work
|
||||||
|
//!
|
||||||
|
//! - **TP variant.** `harness/tp/tp_qwen3_5.rs` is the next step.
|
||||||
|
//! Sharding strategy diverges by layer type:
|
||||||
|
//! - Full-attention layers: column-parallel q/k/v (including the
|
||||||
|
//! gate half of `q_proj`) + row-parallel `o_proj`, mirroring
|
||||||
|
//! `tp_qwen3.rs`.
|
||||||
|
//! - Linear-attention layers: the recurrent state is per-V-head, so
|
||||||
|
//! V-head-dimension sharding works cleanly — split `num_v_heads`
|
||||||
|
//! across ranks (`num_v_heads / world_size` per rank), shard
|
||||||
|
//! `in_proj_qkv` / `in_proj_z` / `in_proj_b` / `in_proj_a` along
|
||||||
|
//! the V-head dim, and row-parallel `out_proj`. The `A_log` /
|
||||||
|
//! `dt_bias` per-head params shard with the heads.
|
||||||
|
//!
|
||||||
|
//! - **Chunked delta-rule prefill.** `linear_attn.rs` runs the
|
||||||
|
//! per-token recurrent path for prefill too — correct but O(L).
|
||||||
|
//! Porting `torch_chunk_gated_delta_rule` (chunk_size=64) speeds
|
||||||
|
//! prefill substantially with no surface change.
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use candle_core::{DType, Device, IndexOp, Module, Tensor};
|
||||||
|
use candle_nn::Embedding;
|
||||||
|
use candle_nn::Linear;
|
||||||
|
use candle_nn::var_builder::ShardedVarBuilder;
|
||||||
|
use serde::Deserialize;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
pub mod decoder;
|
||||||
|
pub mod full_attn;
|
||||||
|
pub mod linear_attn;
|
||||||
|
pub mod mlp;
|
||||||
|
pub mod rmsnorm;
|
||||||
|
pub mod rope;
|
||||||
|
|
||||||
|
use decoder::Qwen3_5DecoderLayer;
|
||||||
|
use rmsnorm::Qwen3_5RmsNorm;
|
||||||
|
use rope::RotaryEmbedding;
|
||||||
|
|
||||||
|
/// `model_type` we deserialise from `config.json`. Const so the
|
||||||
|
/// dispatch in `candle.rs::load_arch_dense` can pattern-match without
|
||||||
|
/// magic strings.
|
||||||
|
pub const MODEL_TYPE: &str = "qwen3_5";
|
||||||
|
|
||||||
|
/// Top-level shape of Qwen3-Next's `config.json`. The real
|
||||||
|
/// hyperparameters live in `text_config`; the rest is multimodal /
|
||||||
|
/// tokeniser glue we don't need for the language-model forward.
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct Config {
|
||||||
|
/// Always `"qwen3_5"` for this architecture. Kept on the struct
|
||||||
|
/// so the (eventual) dispatch / logging code can show it without
|
||||||
|
/// re-parsing the JSON.
|
||||||
|
pub model_type: String,
|
||||||
|
/// The text-side hyperparameters. Everything we actually need.
|
||||||
|
pub text_config: TextConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Inner config (the `text_config` block). Mirrors the Qwen3 layout
|
||||||
|
/// but with the extras Qwen3-Next adds (`attn_output_gate`,
|
||||||
|
/// `layer_types`, `full_attention_interval`, larger `head_dim`).
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct TextConfig {
|
||||||
|
pub vocab_size: usize,
|
||||||
|
pub hidden_size: usize,
|
||||||
|
pub intermediate_size: usize,
|
||||||
|
pub num_hidden_layers: usize,
|
||||||
|
pub num_attention_heads: usize,
|
||||||
|
pub num_key_value_heads: usize,
|
||||||
|
pub head_dim: usize,
|
||||||
|
pub max_position_embeddings: usize,
|
||||||
|
/// Nested RoPE settings. Qwen3-Next puts `rope_theta` and
|
||||||
|
/// `partial_rotary_factor` inside this block rather than at the
|
||||||
|
/// top level — important because the partial rotary means only
|
||||||
|
/// `head_dim * partial_rotary_factor` dims get RoPE applied (the
|
||||||
|
/// rest pass through unchanged).
|
||||||
|
pub rope_parameters: RopeParameters,
|
||||||
|
pub rms_norm_eps: f64,
|
||||||
|
#[serde(default)]
|
||||||
|
pub tie_word_embeddings: bool,
|
||||||
|
|
||||||
|
/// New in Qwen3-Next: a sigmoid gate multiplied into the attention
|
||||||
|
/// output before the o_proj. The Python reference applies it
|
||||||
|
/// pointwise after softmax+matmul.
|
||||||
|
#[serde(default)]
|
||||||
|
pub attn_output_gate: bool,
|
||||||
|
|
||||||
|
/// One entry per decoder layer; values are `"full_attention"` or
|
||||||
|
/// `"linear_attention"`. Length must equal `num_hidden_layers`.
|
||||||
|
/// `full_attention_interval` is a derived hint (every 4th layer
|
||||||
|
/// by default) — `layer_types` is authoritative.
|
||||||
|
#[serde(default)]
|
||||||
|
pub layer_types: Vec<String>,
|
||||||
|
|
||||||
|
/// Hint for the layer-type pattern (defaults to 4). Kept for
|
||||||
|
/// logging / validation; the forward dispatches on `layer_types`.
|
||||||
|
#[serde(default)]
|
||||||
|
pub full_attention_interval: Option<usize>,
|
||||||
|
|
||||||
|
/// Hidden activation (`"silu"` for Qwen3-Next). Used by the MLP
|
||||||
|
/// and the linear-attention conv1d.
|
||||||
|
#[serde(default = "default_hidden_act")]
|
||||||
|
pub hidden_act: String,
|
||||||
|
|
||||||
|
// --- Gated DeltaNet (linear-attention) hyperparams -----------------
|
||||||
|
/// Per-layer linear-attention V-head count (Qwen3.6-27B: 48).
|
||||||
|
/// More V-heads than K-heads is fine — query/key get
|
||||||
|
/// `repeat_interleave`'d to match before the delta rule.
|
||||||
|
#[serde(default)]
|
||||||
|
pub linear_num_value_heads: usize,
|
||||||
|
/// Per-layer linear-attention K-head count (Qwen3.6-27B: 16).
|
||||||
|
#[serde(default)]
|
||||||
|
pub linear_num_key_heads: usize,
|
||||||
|
/// Per-head key dimension for the linear-attention path
|
||||||
|
/// (Qwen3.6-27B: 128). Separate from `head_dim` which the
|
||||||
|
/// full-attention layers use.
|
||||||
|
#[serde(default)]
|
||||||
|
pub linear_key_head_dim: usize,
|
||||||
|
/// Per-head value dimension for the linear-attention path
|
||||||
|
/// (Qwen3.6-27B: 128).
|
||||||
|
#[serde(default)]
|
||||||
|
pub linear_value_head_dim: usize,
|
||||||
|
/// Causal Conv1d kernel size used before the delta rule
|
||||||
|
/// (Qwen3.6-27B: 4).
|
||||||
|
#[serde(default)]
|
||||||
|
pub linear_conv_kernel_dim: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_hidden_act() -> String {
|
||||||
|
"silu".into()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Nested `rope_parameters` block from a Qwen3-Next `config.json`.
|
||||||
|
/// `mrope_section` and `mrope_interleaved` are accepted via the
|
||||||
|
/// `#[serde(default)]` flatten-tolerance below but ignored — we treat
|
||||||
|
/// MRoPE as plain RoPE for text-only inference (the three position
|
||||||
|
/// grids carry identical ids when there's no vision input, so the
|
||||||
|
/// interleaving is a no-op).
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct RopeParameters {
|
||||||
|
/// Base for the inverse-frequency computation. Qwen3.6: 10_000_000.
|
||||||
|
#[serde(default = "default_rope_theta")]
|
||||||
|
pub rope_theta: f64,
|
||||||
|
/// Fraction of `head_dim` that gets the rotation applied. The
|
||||||
|
/// remaining `head_dim * (1 - partial_rotary_factor)` dims pass
|
||||||
|
/// through unchanged. Qwen3.6 / Qwen3.5: 0.25.
|
||||||
|
#[serde(default = "default_partial_rotary_factor")]
|
||||||
|
pub partial_rotary_factor: f32,
|
||||||
|
/// `"default"` for the standard inv_freq RoPE; other values (e.g.
|
||||||
|
/// `"linear"`, `"dynamic"`) are upstream-supported but not yet
|
||||||
|
/// implemented here.
|
||||||
|
#[serde(default)]
|
||||||
|
pub rope_type: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_rope_theta() -> f64 {
|
||||||
|
10_000.0
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_partial_rotary_factor() -> f32 {
|
||||||
|
1.0
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Qwen3-Next base transformer (embedding + decoder stack + final
|
||||||
|
/// norm). Public so a TP variant in `harness/tp/tp_qwen3_5.rs` can
|
||||||
|
/// also build on it later — for now only `Qwen3_5ForCausalLM` is the
|
||||||
|
/// loaded handle.
|
||||||
|
pub struct Qwen3_5Model {
|
||||||
|
embed_tokens: Embedding,
|
||||||
|
layers: Vec<Qwen3_5DecoderLayer>,
|
||||||
|
norm: Qwen3_5RmsNorm,
|
||||||
|
device: Device,
|
||||||
|
dtype: DType,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Qwen3_5Model {
|
||||||
|
pub fn load(cfg: &TextConfig, vb: &ShardedVarBuilder) -> Result<Self> {
|
||||||
|
let dtype = vb.dtype();
|
||||||
|
let device = vb.device().clone();
|
||||||
|
|
||||||
|
// Qwen3-Next is a multimodal architecture whose text core lives
|
||||||
|
// under `model.language_model.*` — sibling to `model.visual.*`
|
||||||
|
// (the vision tower) and to top-level `lm_head` / `mtp.*`.
|
||||||
|
// Every text-side tensor in the safetensors files is under
|
||||||
|
// this prefix; we ignore the vision and MTP weights for
|
||||||
|
// language-model inference.
|
||||||
|
let text_vb = vb.pp("model.language_model");
|
||||||
|
|
||||||
|
let embed_vb = text_vb.pp("embed_tokens");
|
||||||
|
let embed_weight = embed_vb
|
||||||
|
.get((cfg.vocab_size, cfg.hidden_size), "weight")
|
||||||
|
.with_context(|| format!("load '{}/weight'", embed_vb.prefix()))?;
|
||||||
|
let embed_tokens = Embedding::new(embed_weight, cfg.hidden_size);
|
||||||
|
|
||||||
|
let rotary = Arc::new(RotaryEmbedding::new(dtype, cfg, &device)?);
|
||||||
|
|
||||||
|
if cfg.layer_types.len() != cfg.num_hidden_layers {
|
||||||
|
anyhow::bail!(
|
||||||
|
"config.text_config.layer_types must have num_hidden_layers ({}) entries; \
|
||||||
|
got {}",
|
||||||
|
cfg.num_hidden_layers,
|
||||||
|
cfg.layer_types.len()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let vb_l = text_vb.pp("layers");
|
||||||
|
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||||
|
for i in 0..cfg.num_hidden_layers {
|
||||||
|
layers.push(Qwen3_5DecoderLayer::load(
|
||||||
|
cfg,
|
||||||
|
rotary.clone(),
|
||||||
|
i,
|
||||||
|
&vb_l.pp(i),
|
||||||
|
)?);
|
||||||
|
}
|
||||||
|
|
||||||
|
let norm = Qwen3_5RmsNorm::load(&text_vb.pp("norm"), cfg.hidden_size, cfg.rms_norm_eps)?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
embed_tokens,
|
||||||
|
layers,
|
||||||
|
norm,
|
||||||
|
device,
|
||||||
|
dtype,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn embed_weight(&self) -> &Tensor {
|
||||||
|
self.embed_tokens.embeddings()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
for l in &mut self.layers {
|
||||||
|
l.clear_kv_cache();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> candle_core::Result<Tensor> {
|
||||||
|
let minf = f32::NEG_INFINITY;
|
||||||
|
let mask: Vec<_> = (0..tgt)
|
||||||
|
.flat_map(|i| (0..(tgt + offset)).map(move |j| if j <= i + offset { 0. } else { minf }))
|
||||||
|
.collect();
|
||||||
|
Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
|
||||||
|
let (b, l) = input.dims2()?;
|
||||||
|
let mut h = self.embed_tokens.forward(input)?;
|
||||||
|
// Causal mask only needed for L > 1 prefill; full-attention
|
||||||
|
// layers consume it via broadcast_add. Linear-attention layers
|
||||||
|
// ignore the mask.
|
||||||
|
let causal = if l == 1 {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(self.causal_mask(b, l, offset)?)
|
||||||
|
};
|
||||||
|
for layer in &mut self.layers {
|
||||||
|
h = layer.forward(&h, causal.as_ref(), offset)?;
|
||||||
|
}
|
||||||
|
self.norm.forward(&h)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Qwen3_5ForCausalLM {
|
||||||
|
base: Qwen3_5Model,
|
||||||
|
lm_head: Linear,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Qwen3_5ForCausalLM {
|
||||||
|
pub fn new(config: Config, vb: ShardedVarBuilder) -> Result<Self> {
|
||||||
|
let cfg = &config.text_config;
|
||||||
|
let base = Qwen3_5Model::load(cfg, &vb)?;
|
||||||
|
let lm_head = if cfg.tie_word_embeddings {
|
||||||
|
Linear::new(base.embed_weight().clone(), None)
|
||||||
|
} else {
|
||||||
|
let weight = vb
|
||||||
|
.pp("lm_head")
|
||||||
|
.get((cfg.vocab_size, cfg.hidden_size), "weight")
|
||||||
|
.with_context(|| format!("load '{}/lm_head/weight'", vb.prefix()))?;
|
||||||
|
Linear::new(weight, None)
|
||||||
|
};
|
||||||
|
Ok(Self { base, lm_head })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `input`: token-id tensor of shape `(B, L)`. Returns logits at
|
||||||
|
/// the last position, shape `(B, 1, vocab_size)` — same contract
|
||||||
|
/// as `qwen3::ModelForCausalLM::forward` so the harness's
|
||||||
|
/// `squeeze_to_vocab` helper handles both uniformly.
|
||||||
|
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
|
||||||
|
let (_, l) = input.dims2()?;
|
||||||
|
let hidden = self.base.forward(input, offset)?;
|
||||||
|
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
self.base.clear_kv_cache();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
/// Confirms we can deserialise the real upstream config shape.
|
||||||
|
/// Sample taken from `Qwen/Qwen3.6-27B/config.json`, trimmed to
|
||||||
|
/// the fields the architecture cares about. Note `rope_theta` and
|
||||||
|
/// `partial_rotary_factor` are nested under `rope_parameters` —
|
||||||
|
/// Qwen3-Next does NOT have a top-level `rope_theta`.
|
||||||
|
#[test]
|
||||||
|
fn config_deserialises_the_real_qwen3_6_shape() {
|
||||||
|
let raw = r#"{
|
||||||
|
"architectures": ["Qwen3_5ForConditionalGeneration"],
|
||||||
|
"model_type": "qwen3_5",
|
||||||
|
"image_token_id": 248056,
|
||||||
|
"language_model_only": false,
|
||||||
|
"text_config": {
|
||||||
|
"vocab_size": 248064,
|
||||||
|
"hidden_size": 5120,
|
||||||
|
"intermediate_size": 17408,
|
||||||
|
"num_hidden_layers": 64,
|
||||||
|
"num_attention_heads": 64,
|
||||||
|
"num_key_value_heads": 8,
|
||||||
|
"head_dim": 256,
|
||||||
|
"max_position_embeddings": 32768,
|
||||||
|
"rope_parameters": {
|
||||||
|
"mrope_interleaved": true,
|
||||||
|
"mrope_section": [11, 11, 10],
|
||||||
|
"partial_rotary_factor": 0.25,
|
||||||
|
"rope_theta": 10000000,
|
||||||
|
"rope_type": "default"
|
||||||
|
},
|
||||||
|
"rms_norm_eps": 1e-6,
|
||||||
|
"tie_word_embeddings": false,
|
||||||
|
"attn_output_gate": true,
|
||||||
|
"full_attention_interval": 4,
|
||||||
|
"layer_types": [
|
||||||
|
"linear_attention", "linear_attention",
|
||||||
|
"linear_attention", "full_attention"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}"#;
|
||||||
|
let cfg: Config = serde_json::from_str(raw).expect("parse Qwen3.6 config");
|
||||||
|
assert_eq!(cfg.model_type, "qwen3_5");
|
||||||
|
assert_eq!(cfg.text_config.hidden_size, 5120);
|
||||||
|
assert_eq!(cfg.text_config.head_dim, 256);
|
||||||
|
assert!(cfg.text_config.attn_output_gate);
|
||||||
|
assert_eq!(cfg.text_config.full_attention_interval, Some(4));
|
||||||
|
assert_eq!(cfg.text_config.layer_types.len(), 4);
|
||||||
|
assert_eq!(cfg.text_config.rope_parameters.rope_theta, 10_000_000.0);
|
||||||
|
assert!((cfg.text_config.rope_parameters.partial_rotary_factor - 0.25).abs() < 1e-6);
|
||||||
|
}
|
||||||
|
}
|
||||||
161
crates/neuron/src/harness/arch/qwen3_5/rmsnorm.rs
Normal file
161
crates/neuron/src/harness/arch/qwen3_5/rmsnorm.rs
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
//! Norm primitives for Qwen3-Next.
|
||||||
|
//!
|
||||||
|
//! Two reasons we can't reuse `candle_nn::RmsNorm` directly:
|
||||||
|
//!
|
||||||
|
//! 1. **`(1.0 + weight)` scaling.** Qwen3-Next's `Qwen3_5RMSNorm`
|
||||||
|
//! initialises `weight` to zeros and applies `(1.0 + weight)` to
|
||||||
|
//! the normalised vector. `candle_nn::RmsNorm` applies `weight`
|
||||||
|
//! directly. The two are equivalent only when the operator has
|
||||||
|
//! pre-shifted the weights — the upstream checkpoints have not. See
|
||||||
|
//! `huggingface/transformers#29402` for the upstream PR that
|
||||||
|
//! introduced the `(1 + w)` form to recover from the zero-init.
|
||||||
|
//!
|
||||||
|
//! 2. **Gated variant.** The linear-attention layer post-normalises
|
||||||
|
//! its output by an RMSNorm *gated* with a per-element SiLU on
|
||||||
|
//! a sibling `z` projection — fused for numerical reasons (the
|
||||||
|
//! norm's float32 promotion has to happen before the SiLU
|
||||||
|
//! multiply). Not a single existing candle op.
|
||||||
|
//!
|
||||||
|
//! Both ops accept inputs in any compute dtype; promotion to f32 for
|
||||||
|
//! the variance calculation matches the Python reference.
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use candle_core::{D, Module, Tensor};
|
||||||
|
use candle_nn::var_builder::ShardedVarBuilder;
|
||||||
|
|
||||||
|
/// L2-normalise along the last dim with a small epsilon. Matches the
|
||||||
|
/// `l2norm` helper in `transformers/models/qwen3_5/modeling_qwen3_5.py`
|
||||||
|
/// — `x * rsqrt(sum(x*x) + eps)`. The linear-attention path uses this
|
||||||
|
/// on Q and K before the delta rule when
|
||||||
|
/// `use_qk_l2norm_in_kernel=True` (which Qwen3-Next always sets).
|
||||||
|
pub fn l2norm(x: &Tensor, eps: f32) -> candle_core::Result<Tensor> {
|
||||||
|
let dtype = x.dtype();
|
||||||
|
let x_f32 = x.to_dtype(candle_core::DType::F32)?;
|
||||||
|
let sq = x_f32.sqr()?;
|
||||||
|
let sum = sq.sum_keepdim(D::Minus1)?;
|
||||||
|
let inv = (sum + eps as f64)?.sqrt()?.recip()?;
|
||||||
|
x_f32.broadcast_mul(&inv)?.to_dtype(dtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Qwen3-Next's RMSNorm. Stores the raw weight tensor; forward applies
|
||||||
|
/// `(1.0 + weight) * x_normed`.
|
||||||
|
pub struct Qwen3_5RmsNorm {
|
||||||
|
weight: Tensor,
|
||||||
|
eps: f32,
|
||||||
|
size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Qwen3_5RmsNorm {
|
||||||
|
/// Load `weight` from the ShardedVarBuilder. `vb` should already be
|
||||||
|
/// `.pp(...)`-ed to the norm's tensor prefix.
|
||||||
|
pub fn load(vb: &ShardedVarBuilder, size: usize, eps: f64) -> Result<Self> {
|
||||||
|
let weight = vb
|
||||||
|
.get(size, "weight")
|
||||||
|
.with_context(|| format!("load '{}/weight'", vb.prefix()))?;
|
||||||
|
Ok(Self {
|
||||||
|
weight,
|
||||||
|
eps: eps as f32,
|
||||||
|
size,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn size(&self) -> usize {
|
||||||
|
self.size
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for Qwen3_5RmsNorm {
|
||||||
|
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
|
||||||
|
let dtype = x.dtype();
|
||||||
|
let x_f32 = x.to_dtype(candle_core::DType::F32)?;
|
||||||
|
let var = x_f32.sqr()?.mean_keepdim(D::Minus1)?;
|
||||||
|
let normed = x_f32.broadcast_mul(&(var + self.eps as f64)?.sqrt()?.recip()?)?;
|
||||||
|
// Promote weight to f32 and shift by 1.0 *before* multiplying.
|
||||||
|
// Doing the (1 + w) operation in fp16 lands at -inf for the
|
||||||
|
// bottom-of-range weights at load time.
|
||||||
|
let w_f32 = self.weight.to_dtype(candle_core::DType::F32)?;
|
||||||
|
let scale = (w_f32 + 1.0_f64)?;
|
||||||
|
normed.broadcast_mul(&scale)?.to_dtype(dtype)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Gated RMSNorm used at the tail of `Qwen3_5GatedDeltaNet`. Equivalent
|
||||||
|
/// to `x_normed * weight * silu(gate)` but with both the norm and the
|
||||||
|
/// gate evaluated in float32 to avoid mid-pipeline underflow.
|
||||||
|
///
|
||||||
|
/// Note: unlike `Qwen3_5RmsNorm`, this variant matches the Python
|
||||||
|
/// reference's `Qwen3_5RMSNormGated` which uses `weight` directly (not
|
||||||
|
/// `1.0 + weight`).
|
||||||
|
pub struct Qwen3_5RmsNormGated {
|
||||||
|
weight: Tensor,
|
||||||
|
eps: f32,
|
||||||
|
size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Qwen3_5RmsNormGated {
|
||||||
|
pub fn load(vb: &ShardedVarBuilder, size: usize, eps: f64) -> Result<Self> {
|
||||||
|
let weight = vb
|
||||||
|
.get(size, "weight")
|
||||||
|
.with_context(|| format!("load '{}/weight'", vb.prefix()))?;
|
||||||
|
Ok(Self {
|
||||||
|
weight,
|
||||||
|
eps: eps as f32,
|
||||||
|
size,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Direct constructor — used by unit tests that build a layer
|
||||||
|
/// without going through a VarBuilder.
|
||||||
|
#[cfg(test)]
|
||||||
|
pub(crate) fn from_weight(weight: Tensor, eps: f64) -> Self {
|
||||||
|
let size = weight.dims()[0];
|
||||||
|
Self {
|
||||||
|
weight,
|
||||||
|
eps: eps as f32,
|
||||||
|
size,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn size(&self) -> usize {
|
||||||
|
self.size
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `x` and `gate` share the same last-dim shape (`size`).
|
||||||
|
pub fn forward(&self, x: &Tensor, gate: &Tensor) -> candle_core::Result<Tensor> {
|
||||||
|
let dtype = x.dtype();
|
||||||
|
let x_f32 = x.to_dtype(candle_core::DType::F32)?;
|
||||||
|
let var = x_f32.sqr()?.mean_keepdim(D::Minus1)?;
|
||||||
|
let normed = x_f32.broadcast_mul(&(var + self.eps as f64)?.sqrt()?.recip()?)?;
|
||||||
|
let w = self.weight.to_dtype(candle_core::DType::F32)?;
|
||||||
|
let out = normed.broadcast_mul(&w)?;
|
||||||
|
// SiLU on the float32 gate, multiply back into the normed
|
||||||
|
// tensor, then cast to the model dtype.
|
||||||
|
let g = gate.to_dtype(candle_core::DType::F32)?;
|
||||||
|
let silu_gate = candle_nn::ops::silu(&g)?;
|
||||||
|
(out * silu_gate)?.to_dtype(dtype)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use candle_core::Device;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn l2norm_matches_hand_calc() {
|
||||||
|
let x = Tensor::new(&[3.0_f32, 4.0_f32], &Device::Cpu).unwrap();
|
||||||
|
let out = l2norm(&x, 1e-6).unwrap();
|
||||||
|
let v: Vec<f32> = out.to_vec1().unwrap();
|
||||||
|
// |x| = 5, so x/|x| = [0.6, 0.8] (eps is tiny).
|
||||||
|
assert!((v[0] - 0.6).abs() < 1e-4);
|
||||||
|
assert!((v[1] - 0.8).abs() < 1e-4);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn l2norm_zero_vector_is_safe_via_epsilon() {
|
||||||
|
let x = Tensor::new(&[0.0_f32, 0.0_f32], &Device::Cpu).unwrap();
|
||||||
|
let out = l2norm(&x, 1e-6).unwrap();
|
||||||
|
let v: Vec<f32> = out.to_vec1().unwrap();
|
||||||
|
assert!(v.iter().all(|x| x.is_finite()));
|
||||||
|
}
|
||||||
|
}
|
||||||
114
crates/neuron/src/harness/arch/qwen3_5/rope.rs
Normal file
114
crates/neuron/src/harness/arch/qwen3_5/rope.rs
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
//! Rotary position embedding for Qwen3-Next's full-attention layers.
|
||||||
|
//!
|
||||||
|
//! Qwen3.6 ships with MRoPE (multimodal RoPE) machinery in the
|
||||||
|
//! reference Python — three position grids interleaved per
|
||||||
|
//! `mrope_section`. For text-only inference all three grids carry the
|
||||||
|
//! same position ids and the interleave is a no-op, so this module
|
||||||
|
//! implements the plain (non-mrope) flavour: the standard inv_freq
|
||||||
|
//! cosine/sine tables driven by `rope_theta` and `head_dim`.
|
||||||
|
//!
|
||||||
|
//! Rotation flavour: **GLM-style** rotate-half (the second half of the
|
||||||
|
//! head dim is negated and swapped into the first). The reference
|
||||||
|
//! Python uses `apply_rotary_pos_emb` with `rotate_half`; candle's
|
||||||
|
//! `rope_slow` is the matching helper.
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use candle_core::{DType, Device, Tensor};
|
||||||
|
|
||||||
|
use super::TextConfig;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct RotaryEmbedding {
|
||||||
|
sin: Tensor,
|
||||||
|
cos: Tensor,
|
||||||
|
/// Number of dims at the head's leading edge that the rotation
|
||||||
|
/// covers. The remaining `head_dim - rotary_dim` dims pass through
|
||||||
|
/// unchanged. Qwen3-Next uses `partial_rotary_factor = 0.25`, so
|
||||||
|
/// for `head_dim = 256` only 64 dims rotate.
|
||||||
|
rotary_dim: usize,
|
||||||
|
head_dim: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RotaryEmbedding {
|
||||||
|
pub fn new(dtype: DType, cfg: &TextConfig, dev: &Device) -> Result<Self> {
|
||||||
|
let head_dim = cfg.head_dim;
|
||||||
|
let rope = &cfg.rope_parameters;
|
||||||
|
let rotary_dim = (head_dim as f32 * rope.partial_rotary_factor) as usize;
|
||||||
|
if !rotary_dim.is_multiple_of(2) {
|
||||||
|
anyhow::bail!(
|
||||||
|
"rotary_dim = head_dim * partial_rotary_factor = {head_dim} * {} = {rotary_dim} \
|
||||||
|
must be even (cos/sin are paired)",
|
||||||
|
rope.partial_rotary_factor
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if rotary_dim == 0 {
|
||||||
|
anyhow::bail!(
|
||||||
|
"rotary_dim = 0 (partial_rotary_factor = {} too small)",
|
||||||
|
rope.partial_rotary_factor
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let max_seq_len = cfg.max_position_embeddings;
|
||||||
|
let inv_freq: Vec<f32> = (0..rotary_dim)
|
||||||
|
.step_by(2)
|
||||||
|
.map(|i| 1f32 / rope.rope_theta.powf(i as f64 / rotary_dim as f64) as f32)
|
||||||
|
.collect();
|
||||||
|
let n = inv_freq.len();
|
||||||
|
let inv_freq = Tensor::from_vec(inv_freq, (1, n), dev)?.to_dtype(DType::F32)?;
|
||||||
|
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.reshape((max_seq_len, 1))?;
|
||||||
|
let freqs = t.matmul(&inv_freq)?;
|
||||||
|
Ok(Self {
|
||||||
|
sin: freqs.sin()?.to_dtype(dtype)?,
|
||||||
|
cos: freqs.cos()?.to_dtype(dtype)?,
|
||||||
|
rotary_dim,
|
||||||
|
head_dim,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Apply RoPE to q, k.
|
||||||
|
///
|
||||||
|
/// `q`, `k` shape: `(B, H, L, head_dim)`. `offset` is the index
|
||||||
|
/// into the cached cos/sin table — the position of the first token
|
||||||
|
/// in the current step.
|
||||||
|
///
|
||||||
|
/// When `rotary_dim < head_dim` the rotation is applied only to the
|
||||||
|
/// first `rotary_dim` dims of each head; the tail passes through
|
||||||
|
/// unchanged (matches the reference Python's
|
||||||
|
/// `apply_rotary_pos_emb` with non-trivial `partial_rotary_factor`).
|
||||||
|
pub fn apply(
|
||||||
|
&self,
|
||||||
|
q: &Tensor,
|
||||||
|
k: &Tensor,
|
||||||
|
offset: usize,
|
||||||
|
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
|
let (_, _, seq_len, head_dim_in) = q.dims4()?;
|
||||||
|
debug_assert_eq!(head_dim_in, self.head_dim, "q head_dim mismatch");
|
||||||
|
let cos = self.cos.narrow(0, offset, seq_len)?;
|
||||||
|
let sin = self.sin.narrow(0, offset, seq_len)?;
|
||||||
|
if self.rotary_dim == self.head_dim {
|
||||||
|
// Full rotation.
|
||||||
|
let q_embed = candle_nn::rotary_emb::rope_slow(&q.contiguous()?, &cos, &sin)?;
|
||||||
|
let k_embed = candle_nn::rotary_emb::rope_slow(&k.contiguous()?, &cos, &sin)?;
|
||||||
|
Ok((q_embed, k_embed))
|
||||||
|
} else {
|
||||||
|
// Partial rotation: narrow → rotate → cat the untouched tail.
|
||||||
|
let tail = self.head_dim - self.rotary_dim;
|
||||||
|
let q_rot = q
|
||||||
|
.narrow(candle_core::D::Minus1, 0, self.rotary_dim)?
|
||||||
|
.contiguous()?;
|
||||||
|
let q_pass = q.narrow(candle_core::D::Minus1, self.rotary_dim, tail)?;
|
||||||
|
let k_rot = k
|
||||||
|
.narrow(candle_core::D::Minus1, 0, self.rotary_dim)?
|
||||||
|
.contiguous()?;
|
||||||
|
let k_pass = k.narrow(candle_core::D::Minus1, self.rotary_dim, tail)?;
|
||||||
|
let q_rotated = candle_nn::rotary_emb::rope_slow(&q_rot, &cos, &sin)?;
|
||||||
|
let k_rotated = candle_nn::rotary_emb::rope_slow(&k_rot, &cos, &sin)?;
|
||||||
|
let q_embed =
|
||||||
|
Tensor::cat(&[&q_rotated, &q_pass.contiguous()?], candle_core::D::Minus1)?;
|
||||||
|
let k_embed =
|
||||||
|
Tensor::cat(&[&k_rotated, &k_pass.contiguous()?], candle_core::D::Minus1)?;
|
||||||
|
Ok((q_embed, k_embed))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
4026
crates/neuron/src/harness/candle.rs
Normal file
4026
crates/neuron/src/harness/candle.rs
Normal file
File diff suppressed because it is too large
Load Diff
392
crates/neuron/src/harness/chat_template.rs
Normal file
392
crates/neuron/src/harness/chat_template.rs
Normal file
@@ -0,0 +1,392 @@
|
|||||||
|
//! Chat-template rendering for the model-supplied Jinja templates
|
||||||
|
//! HuggingFace tokenizers ship in `tokenizer_config.json`.
|
||||||
|
//!
|
||||||
|
//! ## Background
|
||||||
|
//!
|
||||||
|
//! Every modern open-weight model bundles a `chat_template` field
|
||||||
|
//! in its `tokenizer_config.json` — a Jinja2 template string that
|
||||||
|
//! converts a sequence of `{role, content}` messages into the
|
||||||
|
//! exact prompt the model was trained on. Examples:
|
||||||
|
//!
|
||||||
|
//! - Qwen3-Coder: `<|im_start|>{role}\n{content}<|im_end|>\n…`
|
||||||
|
//! with conditional `enable_thinking` handling that injects an
|
||||||
|
//! empty `<think>\n\n</think>` block when set false.
|
||||||
|
//! - DeepSeek-R1: similar im_start framing with different special-
|
||||||
|
//! token names.
|
||||||
|
//! - Mistral / Magistral: a `[INST]` / `[/INST]` framing.
|
||||||
|
//! - Claude / Llama: another shape again.
|
||||||
|
//!
|
||||||
|
//! Rendering the model's own template is the only way to get the
|
||||||
|
//! *exact* prompt format the model was trained on plus the
|
||||||
|
//! model-specific kwargs (`enable_thinking`, `tools`, …) without
|
||||||
|
//! hardcoding per-model logic. The alternative — neuron's previous
|
||||||
|
//! `format_qwen3_prompt` — was a hardcoded Qwen3 ChatML glue that
|
||||||
|
//! ignored kwargs entirely.
|
||||||
|
//!
|
||||||
|
//! ## Scope
|
||||||
|
//!
|
||||||
|
//! This module is request-side only: it builds the prompt string
|
||||||
|
//! the tokenizer ingests before inference. The reasoning- and
|
||||||
|
//! tool-call-marker token routing (issues #6, #8) is response-side
|
||||||
|
//! and stays in `wire::openai_chat` / the streaming inference
|
||||||
|
//! loops.
|
||||||
|
//!
|
||||||
|
//! ## Fallback
|
||||||
|
//!
|
||||||
|
//! When the model's `tokenizer_config.json` is missing, doesn't
|
||||||
|
//! parse, lacks a `chat_template`, or renders an error, the caller
|
||||||
|
//! falls back to `format_qwen3_prompt`. The
|
||||||
|
//! `NEURON_USE_CHAT_TEMPLATE=false` env var is a global kill
|
||||||
|
//! switch — if a deploy goes sideways and the renderer is to
|
||||||
|
//! blame, an operator can flip the env and restart neuron without
|
||||||
|
//! shipping a new build.
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use cortex_core::openai::{ChatMessage, MessageContent};
|
||||||
|
use minijinja::Environment;
|
||||||
|
use serde_json::Value;
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
/// Environment variable that, when set to `false`/`0`/`no`,
|
||||||
|
/// forces every model to skip its `chat_template` and fall back
|
||||||
|
/// to `format_qwen3_prompt`. Default (unset) is "use chat
|
||||||
|
/// templates where available".
|
||||||
|
pub const KILL_SWITCH_ENV: &str = "NEURON_USE_CHAT_TEMPLATE";
|
||||||
|
|
||||||
|
/// Read the global kill switch. `true` means chat templates are
|
||||||
|
/// enabled; `false` forces the fallback path everywhere.
|
||||||
|
pub fn chat_templates_enabled() -> bool {
|
||||||
|
match std::env::var(KILL_SWITCH_ENV).ok().as_deref() {
|
||||||
|
Some(s) => !matches!(
|
||||||
|
s.trim().to_ascii_lowercase().as_str(),
|
||||||
|
"false" | "0" | "no" | "off"
|
||||||
|
),
|
||||||
|
None => true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convenience: probe for `tokenizer_config.json` in the same
|
||||||
|
/// directory the tokenizer was loaded from. Both files come from
|
||||||
|
/// the same HuggingFace snapshot in the hf-hub cache, so the
|
||||||
|
/// sibling path is reliable.
|
||||||
|
pub fn load_chat_template_alongside(tokenizer_json_path: &Path) -> Option<String> {
|
||||||
|
let parent = tokenizer_json_path.parent()?;
|
||||||
|
let config_path = parent.join("tokenizer_config.json");
|
||||||
|
load_chat_template_from(&config_path)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Best-effort load of `chat_template` from a HuggingFace
|
||||||
|
/// `tokenizer_config.json`. Returns `None` when the file is
|
||||||
|
/// absent, doesn't parse, or lacks the `chat_template` field —
|
||||||
|
/// in all of those cases the caller falls back to
|
||||||
|
/// `format_qwen3_prompt`. Warnings are logged so an operator can
|
||||||
|
/// see why the fallback fired.
|
||||||
|
pub fn load_chat_template_from(path: &Path) -> Option<String> {
|
||||||
|
let text = match std::fs::read_to_string(path) {
|
||||||
|
Ok(t) => t,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::debug!(
|
||||||
|
path = %path.display(),
|
||||||
|
error = %e,
|
||||||
|
"chat_template: tokenizer_config.json absent or unreadable; falling back"
|
||||||
|
);
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let value: Value = match serde_json::from_str(&text) {
|
||||||
|
Ok(v) => v,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(
|
||||||
|
path = %path.display(),
|
||||||
|
error = %e,
|
||||||
|
"chat_template: tokenizer_config.json failed to parse; falling back"
|
||||||
|
);
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
// Some tokenizer_config.json files carry `chat_template` as an
|
||||||
|
// array of `{name, template}` objects (multi-template models —
|
||||||
|
// tool-use variant, default variant). For now we pick the first
|
||||||
|
// entry; future iterations could honour a name hint.
|
||||||
|
match value.get("chat_template") {
|
||||||
|
Some(Value::String(s)) => Some(s.clone()),
|
||||||
|
Some(Value::Array(arr)) => {
|
||||||
|
for entry in arr {
|
||||||
|
if let Some(t) = entry.get("template").and_then(|v| v.as_str()) {
|
||||||
|
return Some(t.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tracing::warn!(
|
||||||
|
path = %path.display(),
|
||||||
|
"chat_template: array form had no usable template entry; falling back"
|
||||||
|
);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Render the chat template into the prompt the model expects.
|
||||||
|
///
|
||||||
|
/// `template` is the raw Jinja string from `tokenizer_config.json`.
|
||||||
|
/// `messages` is the conversation in order. `kwargs` is the
|
||||||
|
/// `chat_template_kwargs` object the client supplied on the
|
||||||
|
/// request (or `Value::Null` when absent). The function expands
|
||||||
|
/// the kwargs into the Jinja context alongside the standard
|
||||||
|
/// `messages` and `add_generation_prompt` variables HF templates
|
||||||
|
/// expect.
|
||||||
|
///
|
||||||
|
/// `tools` is the request's `tools` array (or `Value::Null`).
|
||||||
|
/// Some chat templates iterate it to emit native tool definitions
|
||||||
|
/// (Qwen3-Coder's tool-use template, Mistral's [TOOL_DEFINITIONS]
|
||||||
|
/// frame). We forward whatever the client sent without
|
||||||
|
/// interpretation.
|
||||||
|
pub fn render_chat_template(
|
||||||
|
template: &str,
|
||||||
|
messages: &[ChatMessage],
|
||||||
|
tools: &Value,
|
||||||
|
kwargs: &Value,
|
||||||
|
) -> Result<String> {
|
||||||
|
let mut env = Environment::new();
|
||||||
|
// Compile the template against a fixed name so error messages
|
||||||
|
// surface "chat_template" rather than `<template>`.
|
||||||
|
env.add_template("chat_template", template)
|
||||||
|
.context("compile chat_template")?;
|
||||||
|
let tmpl = env.get_template("chat_template").unwrap();
|
||||||
|
|
||||||
|
// Convert our internal ChatMessage shape into the
|
||||||
|
// `[{role, content}]` shape HF templates iterate. Text content
|
||||||
|
// becomes a string; Parts becomes an array of content blocks.
|
||||||
|
// The HF templates handle both shapes via `content is string`
|
||||||
|
// checks or content-array iteration.
|
||||||
|
let messages_json: Vec<Value> = messages
|
||||||
|
.iter()
|
||||||
|
.map(|m| {
|
||||||
|
let content_value = match &m.content {
|
||||||
|
MessageContent::Text(s) => Value::String(s.clone()),
|
||||||
|
MessageContent::Parts(parts) => Value::Array(parts.clone()),
|
||||||
|
};
|
||||||
|
let mut obj = serde_json::Map::new();
|
||||||
|
obj.insert("role".into(), Value::String(m.role.clone()));
|
||||||
|
obj.insert("content".into(), content_value);
|
||||||
|
// Forward extras (e.g. tool_calls on assistant turns,
|
||||||
|
// tool_call_id on tool result turns). HF templates that
|
||||||
|
// need them read e.g. `message.tool_calls`.
|
||||||
|
if let Value::Object(extras) = &m.extra {
|
||||||
|
for (k, v) in extras {
|
||||||
|
obj.insert(k.clone(), v.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Value::Object(obj)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Build the kwargs context. Add base bindings the template
|
||||||
|
// expects (`messages`, `add_generation_prompt`, `tools`) plus
|
||||||
|
// anything the caller passed in `chat_template_kwargs`. Caller
|
||||||
|
// kwargs override the defaults so `add_generation_prompt: false`
|
||||||
|
// from the request actually wins.
|
||||||
|
let mut ctx_map = serde_json::Map::new();
|
||||||
|
ctx_map.insert("messages".into(), Value::Array(messages_json));
|
||||||
|
ctx_map.insert("add_generation_prompt".into(), Value::Bool(true));
|
||||||
|
if !tools.is_null() {
|
||||||
|
ctx_map.insert("tools".into(), tools.clone());
|
||||||
|
}
|
||||||
|
if let Value::Object(kwargs_obj) = kwargs {
|
||||||
|
for (k, v) in kwargs_obj {
|
||||||
|
ctx_map.insert(k.clone(), v.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// `Template::render` takes any Serialize value; serde_json's
|
||||||
|
// `Value` implements it natively, so we pass the assembled
|
||||||
|
// context object directly without going through the
|
||||||
|
// `context!` macro (which expects minijinja-native values).
|
||||||
|
tmpl.render(Value::Object(ctx_map))
|
||||||
|
.context("render chat_template")
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
fn user_msg(text: &str) -> ChatMessage {
|
||||||
|
ChatMessage {
|
||||||
|
role: "user".into(),
|
||||||
|
content: MessageContent::Text(text.into()),
|
||||||
|
extra: Value::Object(Default::default()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn assistant_msg(text: &str) -> ChatMessage {
|
||||||
|
ChatMessage {
|
||||||
|
role: "assistant".into(),
|
||||||
|
content: MessageContent::Text(text.into()),
|
||||||
|
extra: Value::Object(Default::default()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Minimal Qwen3-style template — enough surface to confirm
|
||||||
|
/// our renderer threads role + content correctly without
|
||||||
|
/// loading a real model's tokenizer_config.json.
|
||||||
|
const QWEN3_LIKE: &str = "{%- for message in messages -%}\
|
||||||
|
<|im_start|>{{ message.role }}\n{{ message.content }}<|im_end|>\n\
|
||||||
|
{%- endfor -%}\
|
||||||
|
{%- if add_generation_prompt -%}<|im_start|>assistant\n{%- endif -%}";
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn renders_basic_conversation() {
|
||||||
|
let prompt = render_chat_template(
|
||||||
|
QWEN3_LIKE,
|
||||||
|
&[user_msg("hello"), assistant_msg("hi"), user_msg("bye")],
|
||||||
|
&Value::Null,
|
||||||
|
&Value::Null,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
// Structural assertions — the exact whitespace produced
|
||||||
|
// by a given template is a Jinja-trim concern that varies
|
||||||
|
// per real chat_template. What matters is that every
|
||||||
|
// turn's role + content thread through in order, and that
|
||||||
|
// the generation cue lands at the end.
|
||||||
|
assert!(
|
||||||
|
prompt.contains("<|im_start|>user\nhello<|im_end|>"),
|
||||||
|
"first user turn missing: {prompt}"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
prompt.contains("<|im_start|>assistant\nhi<|im_end|>"),
|
||||||
|
"assistant turn missing: {prompt}"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
prompt.contains("<|im_start|>user\nbye<|im_end|>"),
|
||||||
|
"second user turn missing: {prompt}"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
prompt.ends_with("<|im_start|>assistant")
|
||||||
|
|| prompt.ends_with("<|im_start|>assistant\n"),
|
||||||
|
"generation cue missing at end: {prompt}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn kwargs_are_threaded_into_template_context() {
|
||||||
|
// Replica of Qwen3's enable_thinking branch in
|
||||||
|
// simplified form. When the kwarg is false, the model's
|
||||||
|
// template injects an empty `<think>...</think>` block
|
||||||
|
// before the generation cue — pre-filling the model's
|
||||||
|
// reasoning slot with "no thinking" so the model emits
|
||||||
|
// the answer directly.
|
||||||
|
let template = "{%- if enable_thinking is defined and enable_thinking is false -%}\
|
||||||
|
NO_THINK\
|
||||||
|
{%- else -%}\
|
||||||
|
THINK_OK\
|
||||||
|
{%- endif -%}";
|
||||||
|
let r_disabled = render_chat_template(
|
||||||
|
template,
|
||||||
|
&[],
|
||||||
|
&Value::Null,
|
||||||
|
&json!({ "enable_thinking": false }),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(r_disabled, "NO_THINK");
|
||||||
|
let r_default = render_chat_template(template, &[], &Value::Null, &Value::Null).unwrap();
|
||||||
|
assert_eq!(r_default, "THINK_OK");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn missing_template_field_returns_none() {
|
||||||
|
let tmp = std::env::temp_dir().join("neuron-test-tokenizer-missing-field.json");
|
||||||
|
std::fs::write(&tmp, r#"{"some_other_field": 1}"#).unwrap();
|
||||||
|
assert!(load_chat_template_from(&tmp).is_none());
|
||||||
|
let _ = std::fs::remove_file(tmp);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn load_template_from_string_field() {
|
||||||
|
let tmp = std::env::temp_dir().join("neuron-test-tokenizer-string.json");
|
||||||
|
std::fs::write(
|
||||||
|
&tmp,
|
||||||
|
r#"{"chat_template": "hello {{ messages[0].content }}"}"#,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let t = load_chat_template_from(&tmp).expect("template loaded");
|
||||||
|
assert!(t.contains("messages[0].content"));
|
||||||
|
let _ = std::fs::remove_file(tmp);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn load_template_from_array_form() {
|
||||||
|
// Some HF models ship `chat_template` as `[{name, template}, ...]`.
|
||||||
|
let tmp = std::env::temp_dir().join("neuron-test-tokenizer-array.json");
|
||||||
|
std::fs::write(
|
||||||
|
&tmp,
|
||||||
|
r#"{"chat_template": [{"name": "default", "template": "ARR"}]}"#,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let t = load_chat_template_from(&tmp).expect("template loaded");
|
||||||
|
assert_eq!(t, "ARR");
|
||||||
|
let _ = std::fs::remove_file(tmp);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn missing_file_returns_none_quietly() {
|
||||||
|
let absent = std::path::PathBuf::from("/definitely/not/a/real/path.json");
|
||||||
|
assert!(load_chat_template_from(&absent).is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn unparseable_returns_none() {
|
||||||
|
let tmp = std::env::temp_dir().join("neuron-test-tokenizer-garbage.json");
|
||||||
|
std::fs::write(&tmp, b"{not valid json").unwrap();
|
||||||
|
assert!(load_chat_template_from(&tmp).is_none());
|
||||||
|
let _ = std::fs::remove_file(tmp);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn kill_switch_recognises_truthy_falsy_values() {
|
||||||
|
// Test against the actual env var so callers see the
|
||||||
|
// same behaviour as production. Serialise via a
|
||||||
|
// mutex — see path_util.rs for the pattern.
|
||||||
|
use std::sync::Mutex;
|
||||||
|
static LOCK: Mutex<()> = Mutex::new(());
|
||||||
|
let _g = LOCK.lock().unwrap();
|
||||||
|
let prior = std::env::var(KILL_SWITCH_ENV).ok();
|
||||||
|
unsafe {
|
||||||
|
std::env::remove_var(KILL_SWITCH_ENV);
|
||||||
|
}
|
||||||
|
assert!(chat_templates_enabled());
|
||||||
|
for value in ["false", "0", "no", "off", "FALSE", " no "] {
|
||||||
|
unsafe { std::env::set_var(KILL_SWITCH_ENV, value) };
|
||||||
|
assert!(!chat_templates_enabled(), "value {value:?} should disable");
|
||||||
|
}
|
||||||
|
for value in ["true", "1", "yes", ""] {
|
||||||
|
unsafe { std::env::set_var(KILL_SWITCH_ENV, value) };
|
||||||
|
assert!(chat_templates_enabled(), "value {value:?} should enable");
|
||||||
|
}
|
||||||
|
unsafe {
|
||||||
|
match prior {
|
||||||
|
Some(p) => std::env::set_var(KILL_SWITCH_ENV, p),
|
||||||
|
None => std::env::remove_var(KILL_SWITCH_ENV),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn message_extras_thread_through_for_tool_calls() {
|
||||||
|
// HF templates read assistant.tool_calls and tool
|
||||||
|
// turns' tool_call_id. Confirm our extras flatten into
|
||||||
|
// the message object the template iterates.
|
||||||
|
let mut extras = serde_json::Map::new();
|
||||||
|
extras.insert(
|
||||||
|
"tool_calls".into(),
|
||||||
|
json!([{"id": "t1", "function": {"name": "x", "arguments": "{}"}}]),
|
||||||
|
);
|
||||||
|
let msg = ChatMessage {
|
||||||
|
role: "assistant".into(),
|
||||||
|
content: MessageContent::Text(String::new()),
|
||||||
|
extra: Value::Object(extras),
|
||||||
|
};
|
||||||
|
let template = "{{ messages[0].tool_calls[0].id }}";
|
||||||
|
let rendered = render_chat_template(template, &[msg], &Value::Null, &Value::Null).unwrap();
|
||||||
|
assert_eq!(rendered, "t1");
|
||||||
|
}
|
||||||
|
}
|
||||||
810
crates/neuron/src/harness/device_worker/dispatch.rs
Normal file
810
crates/neuron/src/harness/device_worker/dispatch.rs
Normal file
@@ -0,0 +1,810 @@
|
|||||||
|
//! Synchronous dispatch loop running on the device worker thread.
|
||||||
|
//!
|
||||||
|
//! `run()` is the thread's entry point. It binds the CUDA context for
|
||||||
|
//! its device on startup, then pulls `Job`s off the channel one at a
|
||||||
|
//! time and runs the corresponding handler. The handlers are
|
||||||
|
//! synchronous by design — the only async on this thread is the
|
||||||
|
//! one-line `oneshot::Sender::send` call to ship the reply back, which
|
||||||
|
//! is non-blocking.
|
||||||
|
//!
|
||||||
|
//! Phase 2 handles QueryVram, TransferIn, DropArch, ClearKv,
|
||||||
|
//! ForwardLogits, Shutdown. Phase 3 will add the TP variants
|
||||||
|
//! (NcclInit, NcclSanity, TpLoadShard, TpForward, TpClearKv) and the
|
||||||
|
//! ARCH model state in this state slab will gain a companion
|
||||||
|
//! `tp_models: HashMap<TpHandle, Box<TpLeaderModel>>`.
|
||||||
|
|
||||||
|
use crate::harness::candle::ModelArch;
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
use crate::harness::device_worker::jobs::TpHandle;
|
||||||
|
use crate::harness::device_worker::jobs::{ArchHandle, Job};
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
use crate::harness::tp::TpLeaderModel;
|
||||||
|
use crate::harness::tp::nccl_state::NcclState;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
|
use std::sync::mpsc::Receiver;
|
||||||
|
|
||||||
|
/// Per-thread state owned by the worker. On CUDA builds the `Arc<CudaContext>`
|
||||||
|
/// is created and bound at thread startup; on CPU builds the struct
|
||||||
|
/// is mostly empty.
|
||||||
|
struct DeviceWorkerState {
|
||||||
|
#[allow(dead_code)]
|
||||||
|
device_index: u32,
|
||||||
|
/// Candle `Device` constructed at startup. Used by handlers (e.g.
|
||||||
|
/// `ForwardLogits`) to build input tensors against the right
|
||||||
|
/// device. Falls back to `Device::Cpu` if CUDA init fails.
|
||||||
|
device: candle_core::Device,
|
||||||
|
/// Boxed `ModelArch` slab. Indexed by an opaque `ArchHandle` minted
|
||||||
|
/// by `TransferIn`. The Box means the entry's address is stable
|
||||||
|
/// across HashMap rehashes (relevant only when we later hand out
|
||||||
|
/// `&mut ModelArch` references — for Phase 2 every handler runs
|
||||||
|
/// `&mut` via `get_mut`, no long-lived borrows).
|
||||||
|
models: HashMap<ArchHandle, Box<ModelArch>>,
|
||||||
|
/// Counter for minting fresh `ArchHandle`s. Each `TransferIn`
|
||||||
|
/// increments and returns the new value. Wraps at u64::MAX after
|
||||||
|
/// ~10^19 model loads — not a practical concern.
|
||||||
|
next_handle: u64,
|
||||||
|
/// Leader's NCCL state. Populated by `Job::NcclInit`; the
|
||||||
|
/// underlying `Comm`'s libnccl handle lives bound to this thread
|
||||||
|
/// for its entire lifetime. Subprocess workers maintain their own
|
||||||
|
/// `NcclState` in their own processes — that's not visible from
|
||||||
|
/// here.
|
||||||
|
#[allow(dead_code)] // Read only via methods on NcclState
|
||||||
|
nccl: NcclState,
|
||||||
|
/// TP leader model slab. Same lifecycle as `models`; separate
|
||||||
|
/// namespace so `ArchHandle` and `TpHandle` can't collide.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
tp_models: HashMap<TpHandle, Box<TpLeaderModel>>,
|
||||||
|
/// Counter for minting fresh `TpHandle`s.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
next_tp_handle: u64,
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
#[allow(dead_code)]
|
||||||
|
/// `None` only if `CudaContext::new()` failed — in that case the
|
||||||
|
/// thread still runs so the handle's lifecycle stays uniform, but
|
||||||
|
/// every job that touches CUDA falls through to a zero reply with
|
||||||
|
/// a log warning.
|
||||||
|
ctx: Option<Arc<candle_core::cuda::cudarc::driver::CudaContext>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Worker thread entry point. Runs until `Job::Shutdown` arrives or
|
||||||
|
/// the channel sender is dropped (which happens when the last
|
||||||
|
/// `DeviceWorkerHandle` `Arc` is dropped without an explicit
|
||||||
|
/// `shutdown()`).
|
||||||
|
pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool>) {
|
||||||
|
let mut state = init_state(device_index);
|
||||||
|
tracing::info!(device_index, "device worker started");
|
||||||
|
|
||||||
|
while let Ok(job) = rx.recv() {
|
||||||
|
// Shutdown is processed unconditionally so a poisoned worker
|
||||||
|
// still exits when asked. Matching by reference first so we
|
||||||
|
// can fall through to the consume-match below.
|
||||||
|
if matches!(&job, Job::Shutdown) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if poisoned.load(Ordering::Acquire) {
|
||||||
|
// Drain-only mode: reply with a poisoned error without
|
||||||
|
// touching CUDA. Phase 1/2 never set the flag from the
|
||||||
|
// dispatch loop itself (no driver errors classified yet),
|
||||||
|
// but tests use `DeviceWorkerHandle::set_poisoned()` to
|
||||||
|
// simulate this state.
|
||||||
|
drain_poisoned(job, device_index);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
match job {
|
||||||
|
Job::QueryVram { reply } => {
|
||||||
|
let result = query_vram(&state);
|
||||||
|
// If the caller dropped its receiver (request cancelled,
|
||||||
|
// gateway timed out) the send fails — fine, we just
|
||||||
|
// discard the reply.
|
||||||
|
let _ = reply.send(result);
|
||||||
|
}
|
||||||
|
Job::LoadGguf {
|
||||||
|
gguf_path,
|
||||||
|
model_id,
|
||||||
|
reply,
|
||||||
|
} => {
|
||||||
|
let result = load_gguf_inner(&state.device, &gguf_path, &model_id)
|
||||||
|
.map(|arch| insert_arch(&mut state, Box::new(arch)));
|
||||||
|
let _ = reply.send(result);
|
||||||
|
}
|
||||||
|
Job::LoadDense {
|
||||||
|
config_path,
|
||||||
|
safetensors_paths,
|
||||||
|
model_id,
|
||||||
|
reply,
|
||||||
|
} => {
|
||||||
|
let result =
|
||||||
|
load_dense_inner(&state.device, &config_path, &safetensors_paths, &model_id)
|
||||||
|
.map(|arch| insert_arch(&mut state, Box::new(arch)));
|
||||||
|
let _ = reply.send(result);
|
||||||
|
}
|
||||||
|
Job::DropArch { handle, reply } => {
|
||||||
|
let removed = state.models.remove(&handle);
|
||||||
|
let was_present = removed.is_some();
|
||||||
|
// Explicit drop on this thread — runs the Box<ModelArch>
|
||||||
|
// Drop with the CUDA context bound here, which frees
|
||||||
|
// all device tensors on the right context. The Drop is
|
||||||
|
// implicit on the `removed` value going out of scope at
|
||||||
|
// the end of the arm; calling drop() explicitly just
|
||||||
|
// makes the intent visible.
|
||||||
|
drop(removed);
|
||||||
|
tracing::debug!(
|
||||||
|
device_index,
|
||||||
|
handle = handle.0,
|
||||||
|
was_present,
|
||||||
|
slab_size = state.models.len(),
|
||||||
|
"device worker: model dropped"
|
||||||
|
);
|
||||||
|
let _ = reply.send(());
|
||||||
|
}
|
||||||
|
Job::ClearKv { handle, reply } => {
|
||||||
|
let result = match state.models.get_mut(&handle) {
|
||||||
|
Some(arch) => arch.clear_kv_cache(),
|
||||||
|
None => Err(anyhow::anyhow!("ClearKv: no model for handle {}", handle.0)),
|
||||||
|
};
|
||||||
|
if result.is_ok() {
|
||||||
|
trim_device_pool(&state);
|
||||||
|
}
|
||||||
|
let _ = reply.send(result);
|
||||||
|
}
|
||||||
|
Job::ForwardLogits {
|
||||||
|
handle,
|
||||||
|
tokens,
|
||||||
|
offset,
|
||||||
|
reply,
|
||||||
|
} => {
|
||||||
|
let result = forward_logits(&mut state, handle, &tokens, offset);
|
||||||
|
let _ = reply.send(result);
|
||||||
|
}
|
||||||
|
Job::NcclInit {
|
||||||
|
cfg,
|
||||||
|
comm_id_hex,
|
||||||
|
reply,
|
||||||
|
} => {
|
||||||
|
let resp = state.nccl.init(cfg, &comm_id_hex);
|
||||||
|
let _ = reply.send(resp);
|
||||||
|
}
|
||||||
|
Job::NcclSanity { reply } => {
|
||||||
|
let resp = state.nccl.sanity_check();
|
||||||
|
let _ = reply.send(resp);
|
||||||
|
}
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
Job::TpLoadShard {
|
||||||
|
model_id,
|
||||||
|
config_json,
|
||||||
|
safetensors_paths,
|
||||||
|
dtype,
|
||||||
|
quant,
|
||||||
|
world_size,
|
||||||
|
reply,
|
||||||
|
} => {
|
||||||
|
let result = tp_load_shard_inner(
|
||||||
|
&mut state,
|
||||||
|
&model_id,
|
||||||
|
&config_json,
|
||||||
|
&safetensors_paths,
|
||||||
|
dtype,
|
||||||
|
quant.as_deref(),
|
||||||
|
world_size,
|
||||||
|
);
|
||||||
|
let _ = reply.send(result);
|
||||||
|
}
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
Job::DropTp { handle, reply } => {
|
||||||
|
let removed = state.tp_models.remove(&handle);
|
||||||
|
let was_present = removed.is_some();
|
||||||
|
drop(removed);
|
||||||
|
tracing::debug!(
|
||||||
|
device_index,
|
||||||
|
tp_handle = handle.0,
|
||||||
|
was_present,
|
||||||
|
slab_size = state.tp_models.len(),
|
||||||
|
"device worker: TP model dropped"
|
||||||
|
);
|
||||||
|
let _ = reply.send(());
|
||||||
|
}
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
Job::TpClearKv { handle, reply } => {
|
||||||
|
let result = match state.tp_models.get_mut(&handle) {
|
||||||
|
Some(model) => {
|
||||||
|
model.clear_kv_cache();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
None => Err(anyhow::anyhow!(
|
||||||
|
"TpClearKv: no TP model for handle {}",
|
||||||
|
handle.0
|
||||||
|
)),
|
||||||
|
};
|
||||||
|
if result.is_ok() {
|
||||||
|
trim_device_pool(&state);
|
||||||
|
}
|
||||||
|
let _ = reply.send(result);
|
||||||
|
}
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
Job::TpForwardLogits {
|
||||||
|
handle,
|
||||||
|
tokens,
|
||||||
|
offset,
|
||||||
|
reply,
|
||||||
|
} => {
|
||||||
|
let result = tp_forward_logits(&mut state, handle, &tokens, offset);
|
||||||
|
let _ = reply.send(result);
|
||||||
|
}
|
||||||
|
// Handled by the matches!() check above; reaching here
|
||||||
|
// means a Shutdown slipped past which is a bug.
|
||||||
|
Job::Shutdown => unreachable!("Shutdown should break above"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
let tp_slab_size = state.tp_models.len();
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
let tp_slab_size = 0_usize;
|
||||||
|
tracing::info!(
|
||||||
|
device_index,
|
||||||
|
slab_size = state.models.len(),
|
||||||
|
tp_slab_size,
|
||||||
|
"device worker exiting; dropping remaining models"
|
||||||
|
);
|
||||||
|
// Drops every model in the slab on this thread before the function
|
||||||
|
// returns. Critical for CUDA tensors: dropping on a thread that
|
||||||
|
// doesn't have the context bound is UB. Phase 2 still runs Drop
|
||||||
|
// via the slab going out of scope, which is correct as long as no
|
||||||
|
// pre-poisoned state lurks in here — see the poisoned-mode
|
||||||
|
// semantics in mod.rs for the Phase 3+ refinement.
|
||||||
|
}
|
||||||
|
|
||||||
|
fn init_state(device_index: u32) -> DeviceWorkerState {
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
{
|
||||||
|
use candle_core::cuda::cudarc::driver::CudaContext;
|
||||||
|
// Construct a candle Device first — cudarc returns the
|
||||||
|
// primary context for this index on subsequent calls, so
|
||||||
|
// CudaContext::new and Device::new_cuda end up sharing state.
|
||||||
|
let (device, ctx) = match candle_core::Device::new_cuda(device_index as usize) {
|
||||||
|
Ok(device) => match CudaContext::new(device_index as usize) {
|
||||||
|
Ok(ctx) => {
|
||||||
|
if let Err(e) = ctx.bind_to_thread() {
|
||||||
|
tracing::warn!(
|
||||||
|
device_index,
|
||||||
|
error = ?e,
|
||||||
|
"device worker: bind_to_thread failed; \
|
||||||
|
operations will still rebind per-call"
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
tracing::info!(device_index, "device worker bound CUDA context");
|
||||||
|
}
|
||||||
|
(device, Some(ctx))
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(
|
||||||
|
device_index,
|
||||||
|
error = ?e,
|
||||||
|
"device worker: CudaContext::new failed; \
|
||||||
|
vram queries will return (0, 0), forward will error"
|
||||||
|
);
|
||||||
|
(device, None)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(
|
||||||
|
device_index,
|
||||||
|
error = %e,
|
||||||
|
"device worker: Device::new_cuda failed; falling back to CPU device"
|
||||||
|
);
|
||||||
|
(candle_core::Device::Cpu, None)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
DeviceWorkerState {
|
||||||
|
device_index,
|
||||||
|
device,
|
||||||
|
models: HashMap::new(),
|
||||||
|
next_handle: 1,
|
||||||
|
nccl: NcclState::new(),
|
||||||
|
tp_models: HashMap::new(),
|
||||||
|
next_tp_handle: 1,
|
||||||
|
ctx,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
{
|
||||||
|
DeviceWorkerState {
|
||||||
|
device_index,
|
||||||
|
device: candle_core::Device::Cpu,
|
||||||
|
models: HashMap::new(),
|
||||||
|
next_handle: 1,
|
||||||
|
nccl: NcclState::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
fn query_vram(state: &DeviceWorkerState) -> anyhow::Result<(u64, u64)> {
|
||||||
|
use candle_core::cuda::cudarc::driver::result;
|
||||||
|
if state.ctx.is_none() {
|
||||||
|
return Ok((0, 0));
|
||||||
|
}
|
||||||
|
// The context was bound in init_state. cudarc's `mem_get_info`
|
||||||
|
// reads from the current context on the calling thread; since we
|
||||||
|
// bound on startup and we never spawn child threads from this
|
||||||
|
// worker, the binding holds.
|
||||||
|
match result::mem_get_info() {
|
||||||
|
Ok((free, total)) => Ok((
|
||||||
|
(free / (1024 * 1024)) as u64,
|
||||||
|
(total / (1024 * 1024)) as u64,
|
||||||
|
)),
|
||||||
|
Err(e) => Err(anyhow::anyhow!("mem_get_info: {e:?}")),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
fn query_vram(_state: &DeviceWorkerState) -> anyhow::Result<(u64, u64)> {
|
||||||
|
Ok((0, 0))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Force cudarc's stream-ordered memory pool to release every block it
|
||||||
|
/// is holding back to the system. After `ConcatKvCache::reset()` drops
|
||||||
|
/// its tensors, the underlying `CudaSlice::drop` calls `cuMemFreeAsync`,
|
||||||
|
/// which returns the blocks to the device's default mempool but not to
|
||||||
|
/// the OS — `mem_get_info` still reports them as used. The next
|
||||||
|
/// request's prefill then sees a falsely-small free pool and either
|
||||||
|
/// OOMs or trips cuBLAS into `CUBLAS_STATUS_INTERNAL_ERROR`.
|
||||||
|
///
|
||||||
|
/// Calling `cuMemPoolTrimTo(pool, 0)` after each `clear_kv_cache`
|
||||||
|
/// returns those blocks. We synchronize first so any pending
|
||||||
|
/// `cuMemFreeAsync` operations have settled. Failures are non-fatal:
|
||||||
|
/// the pool may not exist on legacy drivers, or a transient driver
|
||||||
|
/// error may prevent the trim — neither breaks correctness, the next
|
||||||
|
/// request just sees a less-recovered free pool.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
fn trim_device_pool(state: &DeviceWorkerState) {
|
||||||
|
use candle_core::cuda::cudarc::driver::result::{device, mem_pool};
|
||||||
|
let Some(ctx) = state.ctx.as_ref() else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
let (before_free, _) = match query_vram(state) {
|
||||||
|
Ok(v) => v,
|
||||||
|
Err(_) => (0, 0),
|
||||||
|
};
|
||||||
|
if let Err(e) = ctx.synchronize() {
|
||||||
|
tracing::debug!(
|
||||||
|
device_index = state.device_index,
|
||||||
|
error = ?e,
|
||||||
|
"trim_device_pool: synchronize failed; skipping trim"
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let dev = ctx.cu_device();
|
||||||
|
let pool = match unsafe { device::get_default_mem_pool(dev) } {
|
||||||
|
Ok(p) => p,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::debug!(
|
||||||
|
device_index = state.device_index,
|
||||||
|
error = ?e,
|
||||||
|
"trim_device_pool: get_default_mem_pool failed"
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
if let Err(e) = unsafe { mem_pool::trim_to(pool, 0) } {
|
||||||
|
tracing::debug!(
|
||||||
|
device_index = state.device_index,
|
||||||
|
error = ?e,
|
||||||
|
"trim_device_pool: cuMemPoolTrimTo failed"
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let (after_free, _) = match query_vram(state) {
|
||||||
|
Ok(v) => v,
|
||||||
|
Err(_) => (0, 0),
|
||||||
|
};
|
||||||
|
let freed_mb = after_free.saturating_sub(before_free);
|
||||||
|
tracing::debug!(
|
||||||
|
device_index = state.device_index,
|
||||||
|
before_free_mb = before_free,
|
||||||
|
after_free_mb = after_free,
|
||||||
|
freed_mb,
|
||||||
|
"trim_device_pool: trimmed pool"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
fn trim_device_pool(_state: &DeviceWorkerState) {}
|
||||||
|
|
||||||
|
/// Insert a freshly-built `ModelArch` into the slab and mint a fresh
|
||||||
|
/// `ArchHandle`. Used by both `LoadGguf` and `LoadDense` dispatch
|
||||||
|
/// handlers — they differ only in *how* the arch is built; the
|
||||||
|
/// post-construction bookkeeping is identical.
|
||||||
|
fn insert_arch(state: &mut DeviceWorkerState, arch: Box<ModelArch>) -> ArchHandle {
|
||||||
|
let handle = ArchHandle(state.next_handle);
|
||||||
|
state.next_handle = state.next_handle.wrapping_add(1);
|
||||||
|
state.models.insert(handle, arch);
|
||||||
|
tracing::debug!(
|
||||||
|
device_index = state.device_index,
|
||||||
|
handle = handle.0,
|
||||||
|
slab_size = state.models.len(),
|
||||||
|
"device worker: model inserted"
|
||||||
|
);
|
||||||
|
handle
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load a GGUF (pre-quantized) model on the worker thread. Pulled
|
||||||
|
/// verbatim from the spawn_blocking closure that used to live in
|
||||||
|
/// `CandleHarness::load_arch_gguf`; the only change is that `device`
|
||||||
|
/// is now `state.device` (the worker's permanently-bound device).
|
||||||
|
fn load_gguf_inner(
|
||||||
|
device: &candle_core::Device,
|
||||||
|
gguf_path: &std::path::Path,
|
||||||
|
model_id: &str,
|
||||||
|
) -> anyhow::Result<ModelArch> {
|
||||||
|
use anyhow::Context;
|
||||||
|
use candle_core::DType;
|
||||||
|
use candle_core::quantized::gguf_file;
|
||||||
|
use candle_transformers::models::quantized_llama::ModelWeights as QuantizedLlamaWeights;
|
||||||
|
use candle_transformers::models::quantized_qwen3::ModelWeights as QuantizedQwen3Weights;
|
||||||
|
use candle_transformers::models::quantized_qwen3_moe::GGUFQWenMoE;
|
||||||
|
|
||||||
|
tracing::info!(model = %model_id, path = ?gguf_path, "loading GGUF");
|
||||||
|
let mut file = std::fs::File::open(gguf_path).context("open GGUF file")?;
|
||||||
|
let content =
|
||||||
|
gguf_file::Content::read(&mut file).map_err(|e| anyhow::anyhow!("parse GGUF: {e}"))?;
|
||||||
|
|
||||||
|
let architecture = content
|
||||||
|
.metadata
|
||||||
|
.get("general.architecture")
|
||||||
|
.and_then(|v| v.to_string().ok().cloned())
|
||||||
|
.unwrap_or_default();
|
||||||
|
tracing::info!(architecture = %architecture, "GGUF architecture");
|
||||||
|
|
||||||
|
// The `general.architecture` GGUF metadata key follows
|
||||||
|
// llama.cpp conventions (lowercase, no underscores in some
|
||||||
|
// cases) — `qwen3moe`, not `qwen3_moe`.
|
||||||
|
match architecture.as_str() {
|
||||||
|
"qwen3" => {
|
||||||
|
let weights = QuantizedQwen3Weights::from_gguf(content, &mut file, device)
|
||||||
|
.map_err(|e| anyhow::anyhow!("from_gguf qwen3: {e}"))?;
|
||||||
|
Ok(ModelArch::Qwen3Quantized(weights))
|
||||||
|
}
|
||||||
|
"qwen3moe" => {
|
||||||
|
// GGUFQWenMoE takes an explicit compute dtype alongside
|
||||||
|
// the device — F16 matches the GGUF weights' typical
|
||||||
|
// accumulation precision and gives the best tokens/sec on
|
||||||
|
// consumer cards.
|
||||||
|
let weights = GGUFQWenMoE::from_gguf(content, &mut file, device, DType::F16)
|
||||||
|
.map_err(|e| anyhow::anyhow!("from_gguf qwen3_moe: {e}"))?;
|
||||||
|
Ok(ModelArch::Qwen3MoeQuantized(weights))
|
||||||
|
}
|
||||||
|
"llama" => {
|
||||||
|
let weights = QuantizedLlamaWeights::from_gguf(content, &mut file, device)
|
||||||
|
.map_err(|e| anyhow::anyhow!("from_gguf llama: {e}"))?;
|
||||||
|
Ok(ModelArch::LlamaQuantized(weights))
|
||||||
|
}
|
||||||
|
other => anyhow::bail!(
|
||||||
|
"unsupported GGUF architecture '{other}'; quantized path supports \
|
||||||
|
qwen3, qwen3moe, llama"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load a dense safetensors model on the worker thread.
|
||||||
|
fn load_dense_inner(
|
||||||
|
device: &candle_core::Device,
|
||||||
|
config_path: &std::path::Path,
|
||||||
|
safetensors_paths: &[std::path::PathBuf],
|
||||||
|
model_id: &str,
|
||||||
|
) -> anyhow::Result<ModelArch> {
|
||||||
|
use anyhow::Context;
|
||||||
|
use candle_core::DType;
|
||||||
|
use candle_nn::VarBuilder;
|
||||||
|
use candle_transformers::models::llama as llama_dense;
|
||||||
|
use candle_transformers::models::qwen3 as qwen3_dense;
|
||||||
|
use candle_transformers::models::qwen3_moe as qwen3_moe_dense;
|
||||||
|
|
||||||
|
let cfg_text = std::fs::read_to_string(config_path).context("read config.json")?;
|
||||||
|
crate::harness::candle::check_dense_config_supported(&cfg_text, model_id)?;
|
||||||
|
// Peek at model_type to choose the family before the typed
|
||||||
|
// deserialize — each family has its own Config.
|
||||||
|
let model_type = serde_json::from_str::<serde_json::Value>(&cfg_text)
|
||||||
|
.ok()
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|v| v.get("model_type"))
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or("")
|
||||||
|
.to_string();
|
||||||
|
tracing::info!(
|
||||||
|
model = %model_id,
|
||||||
|
model_type = %model_type,
|
||||||
|
shards = safetensors_paths.len(),
|
||||||
|
"loading dense model from safetensors"
|
||||||
|
);
|
||||||
|
|
||||||
|
// bf16 is the canonical distribution dtype for Qwen3 / Llama 3 /
|
||||||
|
// Qwen3 MoE. CUDA on Ada+ has hardware bf16; Ampere has it too.
|
||||||
|
// CPU emulates.
|
||||||
|
let dtype = DType::BF16;
|
||||||
|
// SAFETY: VarBuilder::from_mmaped_safetensors mmaps the files;
|
||||||
|
// mutation by another process while we hold the mapping is UB.
|
||||||
|
// We trust the HF cache is immutable-by-design.
|
||||||
|
let vb = unsafe {
|
||||||
|
VarBuilder::from_mmaped_safetensors(safetensors_paths, dtype, device)
|
||||||
|
.context("build VarBuilder over safetensors")?
|
||||||
|
};
|
||||||
|
|
||||||
|
match model_type.as_str() {
|
||||||
|
"qwen3" => {
|
||||||
|
let cfg: qwen3_dense::Config =
|
||||||
|
serde_json::from_str(&cfg_text).context("parse Qwen3 config.json")?;
|
||||||
|
let model = qwen3_dense::ModelForCausalLM::new(&cfg, vb)
|
||||||
|
.map_err(|e| anyhow::anyhow!("build Qwen3 dense model: {e}"))?;
|
||||||
|
Ok(ModelArch::Qwen3Dense(model))
|
||||||
|
}
|
||||||
|
"qwen3_moe" => {
|
||||||
|
let cfg: qwen3_moe_dense::Config =
|
||||||
|
serde_json::from_str(&cfg_text).context("parse Qwen3 MoE config.json")?;
|
||||||
|
let model = qwen3_moe_dense::ModelForCausalLM::new(&cfg, vb)
|
||||||
|
.map_err(|e| anyhow::anyhow!("build Qwen3 MoE dense model: {e}"))?;
|
||||||
|
Ok(ModelArch::Qwen3MoeDense(model))
|
||||||
|
}
|
||||||
|
"llama" => {
|
||||||
|
let cfg: llama_dense::LlamaConfig =
|
||||||
|
serde_json::from_str(&cfg_text).context("parse Llama config.json")?;
|
||||||
|
let config = cfg.into_config(false);
|
||||||
|
let cache = llama_dense::Cache::new(true, dtype, &config, device)
|
||||||
|
.context("build Llama Cache")?;
|
||||||
|
let model = llama_dense::Llama::load(vb, &config)
|
||||||
|
.map_err(|e| anyhow::anyhow!("build Llama dense model: {e}"))?;
|
||||||
|
Ok(ModelArch::LlamaDense(Box::new(
|
||||||
|
crate::harness::candle::LlamaDense::from_parts(
|
||||||
|
model,
|
||||||
|
cache,
|
||||||
|
config,
|
||||||
|
dtype,
|
||||||
|
device.clone(),
|
||||||
|
),
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
"qwen3_5" => {
|
||||||
|
let cfg: crate::harness::arch::qwen3_5::Config = serde_json::from_str(&cfg_text)
|
||||||
|
.context("parse Qwen3-Next (qwen3_5) config.json")?;
|
||||||
|
let sharded_vb = unsafe {
|
||||||
|
candle_nn::var_builder::ShardedSafeTensors::var_builder(
|
||||||
|
safetensors_paths,
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
|
)
|
||||||
|
.context("build ShardedVarBuilder for Qwen3-Next")?
|
||||||
|
};
|
||||||
|
let model = crate::harness::arch::qwen3_5::Qwen3_5ForCausalLM::new(cfg, sharded_vb)
|
||||||
|
.context("build Qwen3-Next dense model")?;
|
||||||
|
Ok(ModelArch::Qwen3_5Dense(model))
|
||||||
|
}
|
||||||
|
other => anyhow::bail!(
|
||||||
|
"unrouted supported model_type '{other}' — \
|
||||||
|
DENSE_SUPPORTED_MODEL_TYPES and load_dense_inner \
|
||||||
|
must stay in sync"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load the leader's TP shard on the worker thread. Reads the Comm
|
||||||
|
/// directly from `state.nccl`; no cross-thread Arc<Comm> transfer.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
fn tp_load_shard_inner(
|
||||||
|
state: &mut DeviceWorkerState,
|
||||||
|
model_id: &str,
|
||||||
|
config_json: &str,
|
||||||
|
safetensors_paths: &[std::path::PathBuf],
|
||||||
|
dtype: candle_core::DType,
|
||||||
|
quant: Option<&str>,
|
||||||
|
world_size: u32,
|
||||||
|
) -> anyhow::Result<TpHandle> {
|
||||||
|
use anyhow::Context;
|
||||||
|
use candle_nn::var_builder::ShardedSafeTensors;
|
||||||
|
|
||||||
|
let comm = state.nccl.comm().ok_or_else(|| {
|
||||||
|
anyhow::anyhow!("TpLoadShard: NcclState has no Comm; call NcclInit first")
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let model_type = serde_json::from_str::<serde_json::Value>(config_json)
|
||||||
|
.ok()
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|v| v.get("model_type"))
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or("")
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
// SAFETY: same invariant as the single-GPU dense path — the HF
|
||||||
|
// cache files are treated as immutable while the mmap is held.
|
||||||
|
let vb = unsafe {
|
||||||
|
ShardedSafeTensors::var_builder(safetensors_paths, dtype, &state.device)
|
||||||
|
.context("build ShardedVarBuilder over safetensors")?
|
||||||
|
};
|
||||||
|
let mmap = unsafe {
|
||||||
|
candle_core::safetensors::MmapedSafetensors::multi(safetensors_paths)
|
||||||
|
.context("build MmapedSafetensors for leader load")?
|
||||||
|
};
|
||||||
|
|
||||||
|
let loaded = match model_type.as_str() {
|
||||||
|
"qwen3" => {
|
||||||
|
let cfg: crate::harness::tp::tp_qwen3::Config = serde_json::from_str(config_json)
|
||||||
|
.context("parse Qwen3 Config JSON for leader load")?;
|
||||||
|
TpLeaderModel::Qwen3(crate::harness::tp::tp_qwen3::TpQwen3ForCausalLM::load(
|
||||||
|
&cfg, &vb, 0, world_size, comm,
|
||||||
|
)?)
|
||||||
|
}
|
||||||
|
"qwen3_5" => {
|
||||||
|
let cfg: crate::harness::tp::tp_qwen3_5::Config = serde_json::from_str(config_json)
|
||||||
|
.context("parse Qwen3-Next Config JSON for leader load")?;
|
||||||
|
let quant_dtype = crate::harness::tp::worker::parse_quant_string(quant)?;
|
||||||
|
TpLeaderModel::Qwen3_5(crate::harness::tp::tp_qwen3_5::TpQwen3_5ForCausalLM::load(
|
||||||
|
cfg,
|
||||||
|
&vb,
|
||||||
|
&mmap,
|
||||||
|
0,
|
||||||
|
world_size,
|
||||||
|
comm,
|
||||||
|
quant_dtype,
|
||||||
|
)?)
|
||||||
|
}
|
||||||
|
other => anyhow::bail!(
|
||||||
|
"TP dispatch: unsupported model_type '{other}' on leader (supported: qwen3, qwen3_5)"
|
||||||
|
),
|
||||||
|
};
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
rank = 0,
|
||||||
|
model = %model_id,
|
||||||
|
model_type = %model_type,
|
||||||
|
"loaded TP shard (leader)"
|
||||||
|
);
|
||||||
|
|
||||||
|
let handle = TpHandle(state.next_tp_handle);
|
||||||
|
state.next_tp_handle = state.next_tp_handle.wrapping_add(1);
|
||||||
|
state.tp_models.insert(handle, Box::new(loaded));
|
||||||
|
tracing::debug!(
|
||||||
|
device_index = state.device_index,
|
||||||
|
tp_handle = handle.0,
|
||||||
|
slab_size = state.tp_models.len(),
|
||||||
|
"device worker: TP model inserted"
|
||||||
|
);
|
||||||
|
Ok(handle)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// TP-equivalent of [`forward_logits`]: looks up the leader's
|
||||||
|
/// [`TpLeaderModel`] in the slab, runs its forward, copies the
|
||||||
|
/// `[vocab]` logits to a CPU `Vec<f32>`. The leader's `Arc<Comm>`
|
||||||
|
/// clones embedded in the TP layers' AllReduce ops fire from this
|
||||||
|
/// thread — same thread that bound the CUDA context and that holds
|
||||||
|
/// the `Comm` in `state.nccl`.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
fn tp_forward_logits(
|
||||||
|
state: &mut DeviceWorkerState,
|
||||||
|
handle: TpHandle,
|
||||||
|
tokens: &[u32],
|
||||||
|
offset: usize,
|
||||||
|
) -> anyhow::Result<Vec<f32>> {
|
||||||
|
use candle_core::{DType, Tensor};
|
||||||
|
|
||||||
|
let input = Tensor::new(tokens, &state.device)?.unsqueeze(0)?;
|
||||||
|
|
||||||
|
let model = state
|
||||||
|
.tp_models
|
||||||
|
.get_mut(&handle)
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("TpForwardLogits: no model for handle {}", handle.0))?;
|
||||||
|
|
||||||
|
let logits = model.forward(&input, offset)?;
|
||||||
|
// ForCausalLM forward returns [B, 1, V] after the trailing
|
||||||
|
// .i((.., l - 1.., ..))?.apply(lm_head); squeeze both leading
|
||||||
|
// singleton dims to a rank-1 [V] tensor for sampling.
|
||||||
|
let logits = logits.squeeze(0)?.squeeze(0)?;
|
||||||
|
let logits = logits.to_dtype(DType::F32)?.flatten_all()?;
|
||||||
|
let values = logits.to_vec1::<f32>()?;
|
||||||
|
Ok(values)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Forward step + copy the `[vocab]` logits to a CPU `Vec<f32>` ready
|
||||||
|
/// for sampling on the async caller. The model's `device()` (CUDA or
|
||||||
|
/// CPU) determines where the kernel runs; this fn doesn't care.
|
||||||
|
///
|
||||||
|
/// On CUDA, the `to_dtype(F32).flatten_all().to_vec1::<f32>()` chain
|
||||||
|
/// triggers the device → host copy. The copy runs synchronously on
|
||||||
|
/// this worker thread; the bound context owns the source allocation
|
||||||
|
/// so the transfer is straightforward.
|
||||||
|
fn forward_logits(
|
||||||
|
state: &mut DeviceWorkerState,
|
||||||
|
handle: ArchHandle,
|
||||||
|
tokens: &[u32],
|
||||||
|
offset: usize,
|
||||||
|
) -> anyhow::Result<Vec<f32>> {
|
||||||
|
use candle_core::{DType, Tensor};
|
||||||
|
|
||||||
|
// Build the input tensor on the worker's own device. cudarc's
|
||||||
|
// primary-context model means `Device::new_cuda(idx)` shares state
|
||||||
|
// with the `CudaContext` we bound at startup, so this is the same
|
||||||
|
// device the ModelArch was loaded against.
|
||||||
|
let input = Tensor::new(tokens, &state.device)?.unsqueeze(0)?;
|
||||||
|
|
||||||
|
let arch = state
|
||||||
|
.models
|
||||||
|
.get_mut(&handle)
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("ForwardLogits: no model for handle {}", handle.0))?;
|
||||||
|
|
||||||
|
let logits = arch.forward(&input, offset)?;
|
||||||
|
// Copy to CPU f32. logits is already `[vocab]` (squeeze_to_vocab
|
||||||
|
// inside ModelArch::forward). The to_dtype handles bf16/f16 →
|
||||||
|
// f32 promotion for the sampler.
|
||||||
|
let logits = logits.to_dtype(DType::F32)?.flatten_all()?;
|
||||||
|
let values = logits.to_vec1::<f32>()?;
|
||||||
|
Ok(values)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Reply to a job with the poisoned-worker error. Used when the worker
|
||||||
|
/// has flipped into drain-only mode after a CUDA driver error.
|
||||||
|
///
|
||||||
|
/// `Job::Shutdown` is filtered before reaching this fn so the match
|
||||||
|
/// only needs the data-carrying variants. As phases 2–4 add more
|
||||||
|
/// variants the match here grows; every variant must reply with the
|
||||||
|
/// poisoned error so callers never hang waiting for a worker that's
|
||||||
|
/// no longer running CUDA.
|
||||||
|
fn drain_poisoned(job: Job, device_index: u32) {
|
||||||
|
let err = || anyhow::anyhow!("device worker for device {device_index} is poisoned");
|
||||||
|
match job {
|
||||||
|
Job::QueryVram { reply } => {
|
||||||
|
let _ = reply.send(Err(err()));
|
||||||
|
}
|
||||||
|
Job::LoadGguf { reply, .. } => {
|
||||||
|
let _ = reply.send(Err(err()));
|
||||||
|
}
|
||||||
|
Job::LoadDense { reply, .. } => {
|
||||||
|
let _ = reply.send(Err(err()));
|
||||||
|
}
|
||||||
|
Job::DropArch { reply, .. } => {
|
||||||
|
// Drop reply is `()` — no error path. Send the unit so the
|
||||||
|
// caller's await resolves; the model handle is leaked in
|
||||||
|
// the worker's slab, but the whole slab gets `mem::forget`
|
||||||
|
// on shutdown anyway per the poisoned-thread design.
|
||||||
|
let _ = reply.send(());
|
||||||
|
}
|
||||||
|
Job::ClearKv { reply, .. } => {
|
||||||
|
let _ = reply.send(Err(err()));
|
||||||
|
}
|
||||||
|
Job::ForwardLogits { reply, .. } => {
|
||||||
|
let _ = reply.send(Err(err()));
|
||||||
|
}
|
||||||
|
Job::NcclInit { reply, .. } => {
|
||||||
|
let _ = reply.send(crate::harness::tp::rpc::WorkerResponse::Error {
|
||||||
|
kind: "device_worker_poisoned".into(),
|
||||||
|
message: format!("device worker {device_index} poisoned"),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Job::NcclSanity { reply } => {
|
||||||
|
let _ = reply.send(crate::harness::tp::rpc::WorkerResponse::Error {
|
||||||
|
kind: "device_worker_poisoned".into(),
|
||||||
|
message: format!("device worker {device_index} poisoned"),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
Job::TpLoadShard { reply, .. } => {
|
||||||
|
let _ = reply.send(Err(err()));
|
||||||
|
}
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
Job::DropTp { reply, .. } => {
|
||||||
|
let _ = reply.send(());
|
||||||
|
}
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
Job::TpClearKv { reply, .. } => {
|
||||||
|
let _ = reply.send(Err(err()));
|
||||||
|
}
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
Job::TpForwardLogits { reply, .. } => {
|
||||||
|
let _ = reply.send(Err(err()));
|
||||||
|
}
|
||||||
|
Job::Shutdown => {
|
||||||
|
// Filtered by the matches!() guard in run(); reaching
|
||||||
|
// here would be a logic error.
|
||||||
|
unreachable!("Shutdown is filtered before drain_poisoned");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
169
crates/neuron/src/harness/device_worker/jobs.rs
Normal file
169
crates/neuron/src/harness/device_worker/jobs.rs
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
//! Job variants accepted by the per-device worker thread.
|
||||||
|
//!
|
||||||
|
//! Each variant carries the inputs the synchronous dispatch handler
|
||||||
|
//! needs plus a `tokio::sync::oneshot::Sender` for the reply. The
|
||||||
|
//! async-side `DeviceWorkerHandle` constructs a job, sends it down the
|
||||||
|
//! `std::sync::mpsc` channel, and `await`s the oneshot for the reply.
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
use tokio::sync::oneshot;
|
||||||
|
|
||||||
|
/// Opaque handle to a `ModelArch` stored in the worker thread's state
|
||||||
|
/// slab. Cheap to copy; `Send + Sync` so it crosses task boundaries
|
||||||
|
/// freely. The actual `Box<ModelArch>` it points to is owned by the
|
||||||
|
/// worker thread for the duration of the handle's lifetime — the only
|
||||||
|
/// way to drop the model is to send `Job::DropArch { handle }` so the
|
||||||
|
/// `Drop` impl runs on the thread with the bound CUDA context (the
|
||||||
|
/// invariant the whole refactor exists to guarantee).
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||||
|
pub struct ArchHandle(pub u64);
|
||||||
|
|
||||||
|
/// Opaque handle to a `TpLeaderModel` stored in the worker thread's
|
||||||
|
/// state slab. Same shape as [`ArchHandle`] but in a separate
|
||||||
|
/// namespace so the two slabs can coexist without ambiguity. Phase 3
|
||||||
|
/// introduces it; Phase 4 may unify the two slabs after the TP forward
|
||||||
|
/// path proves out.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||||
|
pub struct TpHandle(pub u64);
|
||||||
|
|
||||||
|
/// One unit of work for the device worker.
|
||||||
|
///
|
||||||
|
/// Phase 1 had only `QueryVram` and `Shutdown`. Phase 2 adds the
|
||||||
|
/// single-GPU inference primitives: transfer-in a freshly-loaded
|
||||||
|
/// `ModelArch`, drop it, clear its KV cache, and run one forward step
|
||||||
|
/// returning CPU-side logits ready for sampling on the async caller.
|
||||||
|
///
|
||||||
|
/// Sampling stays on the async side intentionally. The worker copies
|
||||||
|
/// logits to CPU (`Vec<f32>`) before reply, so the device-resident
|
||||||
|
/// tensor never escapes the worker thread and the async caller's
|
||||||
|
/// `LogitsProcessor::sample` runs entirely on the CPU candle backend
|
||||||
|
/// — no incidental context binding on a tokio worker thread.
|
||||||
|
pub enum Job {
|
||||||
|
/// Query free / total VRAM on the device. Returns
|
||||||
|
/// `(free_mb, total_mb)`. CPU builds and contexts that failed to
|
||||||
|
/// initialise reply with `(0, 0)` — matches today's
|
||||||
|
/// `device_vram_mb` sentinel so the log field values don't change.
|
||||||
|
QueryVram {
|
||||||
|
reply: oneshot::Sender<Result<(u64, u64)>>,
|
||||||
|
},
|
||||||
|
/// Load a GGUF (pre-quantized) single-GPU model on the worker
|
||||||
|
/// thread. The dispatch handler opens the GGUF file, parses
|
||||||
|
/// metadata, dispatches on `general.architecture`, and inserts
|
||||||
|
/// the resulting `ModelArch` into the slab. Returns the fresh
|
||||||
|
/// `ArchHandle`.
|
||||||
|
LoadGguf {
|
||||||
|
gguf_path: PathBuf,
|
||||||
|
model_id: String,
|
||||||
|
reply: oneshot::Sender<Result<ArchHandle>>,
|
||||||
|
},
|
||||||
|
/// Load a dense safetensors single-GPU model on the worker
|
||||||
|
/// thread. The dispatch handler reads `config.json`, dispatches on
|
||||||
|
/// `model_type`, builds a `VarBuilder` over the mmap'd
|
||||||
|
/// safetensors, and inserts the resulting `ModelArch`.
|
||||||
|
LoadDense {
|
||||||
|
config_path: PathBuf,
|
||||||
|
safetensors_paths: Vec<PathBuf>,
|
||||||
|
model_id: String,
|
||||||
|
reply: oneshot::Sender<Result<ArchHandle>>,
|
||||||
|
},
|
||||||
|
/// Remove the model from the slab and drop it. The `Drop` runs on
|
||||||
|
/// the worker thread so CUDA tensors release their memory on the
|
||||||
|
/// same context that allocated them.
|
||||||
|
DropArch {
|
||||||
|
handle: ArchHandle,
|
||||||
|
reply: oneshot::Sender<()>,
|
||||||
|
},
|
||||||
|
/// Reset the KV cache for this model. Called at the start of every
|
||||||
|
/// chat completion so a new request doesn't attend over the
|
||||||
|
/// previous one's tokens.
|
||||||
|
ClearKv {
|
||||||
|
handle: ArchHandle,
|
||||||
|
reply: oneshot::Sender<Result<()>>,
|
||||||
|
},
|
||||||
|
/// Run one forward step and copy the resulting `[vocab]` logits to
|
||||||
|
/// CPU. The caller takes the returned `Vec<f32>`, wraps it in a
|
||||||
|
/// CPU `Tensor`, and runs `apply_repeat_penalty` + sampling
|
||||||
|
/// without touching the device context. `offset` is the KV-cache
|
||||||
|
/// position before this step (0 for prefill, `prompt_len + i` for
|
||||||
|
/// the i-th decode step).
|
||||||
|
ForwardLogits {
|
||||||
|
handle: ArchHandle,
|
||||||
|
tokens: Vec<u32>,
|
||||||
|
offset: usize,
|
||||||
|
reply: oneshot::Sender<Result<Vec<f32>>>,
|
||||||
|
},
|
||||||
|
/// Initialize the leader's NCCL communicator. The worker's
|
||||||
|
/// `NcclState` mints the `Comm` here so its underlying
|
||||||
|
/// `ncclComm_t` and `CudaContext` live on the same thread as
|
||||||
|
/// every later `Comm::all_reduce` call. Reply is the worker
|
||||||
|
/// response shape used by the subprocess workers (`InitOk` on
|
||||||
|
/// success, `Error` on failure) so the calling
|
||||||
|
/// `WorkerPool::init_nccl` orchestration stays uniform.
|
||||||
|
///
|
||||||
|
/// Available on both cuda and no-cuda builds — the dispatch
|
||||||
|
/// handler calls `NcclState::init` which has a no-cuda stub that
|
||||||
|
/// replies with `cuda_feature_not_enabled`. Keeping the Job
|
||||||
|
/// variant ungated lets `WorkerPool::init_nccl` stay uniform.
|
||||||
|
NcclInit {
|
||||||
|
cfg: crate::harness::tp::worker::WorkerConfig,
|
||||||
|
comm_id_hex: String,
|
||||||
|
reply: oneshot::Sender<crate::harness::tp::rpc::WorkerResponse>,
|
||||||
|
},
|
||||||
|
/// Run NCCL's all_reduce sanity check on the leader's rank 0.
|
||||||
|
/// Same response shape as `NcclInit`; also available on both
|
||||||
|
/// builds via the no-cuda `NcclState::sanity_check` stub.
|
||||||
|
NcclSanity {
|
||||||
|
reply: oneshot::Sender<crate::harness::tp::rpc::WorkerResponse>,
|
||||||
|
},
|
||||||
|
/// Load the leader's TP shard on the worker thread. The dispatch
|
||||||
|
/// handler reads `state.nccl.comm()` directly (no cross-thread
|
||||||
|
/// `Arc<Comm>` transfer, no `SendComm` wrapper) and builds the
|
||||||
|
/// `TpLeaderModel` against that Comm. The model's embedded
|
||||||
|
/// `Arc<Comm>` clones, `CudaContext`, and all per-rank CUDA
|
||||||
|
/// tensors live on this thread for the model's lifetime.
|
||||||
|
/// Inserts into the TP slab and returns the fresh `TpHandle`.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
TpLoadShard {
|
||||||
|
model_id: String,
|
||||||
|
config_json: String,
|
||||||
|
safetensors_paths: Vec<PathBuf>,
|
||||||
|
dtype: candle_core::DType,
|
||||||
|
quant: Option<String>,
|
||||||
|
world_size: u32,
|
||||||
|
reply: oneshot::Sender<Result<TpHandle>>,
|
||||||
|
},
|
||||||
|
/// Drop the TP leader model on the worker thread. CUDA tensors
|
||||||
|
/// and `Arc<Comm>` clones held inside the model release on the
|
||||||
|
/// thread that allocated them.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
DropTp {
|
||||||
|
handle: TpHandle,
|
||||||
|
reply: oneshot::Sender<()>,
|
||||||
|
},
|
||||||
|
/// Reset the leader's KV cache for a TP model. Mirrors `ClearKv`
|
||||||
|
/// for single-GPU.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
TpClearKv {
|
||||||
|
handle: TpHandle,
|
||||||
|
reply: oneshot::Sender<Result<()>>,
|
||||||
|
},
|
||||||
|
/// Run one TP forward step on the leader's shard. Returns CPU-
|
||||||
|
/// side logits as a `Vec<f32>` so the async caller can sample
|
||||||
|
/// without holding a device tensor. The caller is also
|
||||||
|
/// responsible for fan-out to subprocess ranks and drain — only
|
||||||
|
/// the leader's forward moves into the worker thread.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
TpForwardLogits {
|
||||||
|
handle: TpHandle,
|
||||||
|
tokens: Vec<u32>,
|
||||||
|
offset: usize,
|
||||||
|
reply: oneshot::Sender<Result<Vec<f32>>>,
|
||||||
|
},
|
||||||
|
/// Tell the worker to break its dispatch loop and exit. Any jobs
|
||||||
|
/// queued after this in the channel reply `Err` to their oneshot
|
||||||
|
/// senders (the senders are dropped on the worker's exit, which
|
||||||
|
/// the async-side `Receiver::await` maps to `WorkerError::Gone`).
|
||||||
|
Shutdown,
|
||||||
|
}
|
||||||
592
crates/neuron/src/harness/device_worker/mod.rs
Normal file
592
crates/neuron/src/harness/device_worker/mod.rs
Normal file
@@ -0,0 +1,592 @@
|
|||||||
|
//! Per-device CUDA worker thread.
|
||||||
|
//!
|
||||||
|
//! One dedicated OS thread per CUDA device the leader uses. The thread
|
||||||
|
//! binds the device's `CudaContext` once at startup and owns it for the
|
||||||
|
//! daemon's lifetime; all GPU operations and VRAM queries for that
|
||||||
|
//! device route through a `std::sync::mpsc` channel into this thread.
|
||||||
|
//! Tensors never escape the thread alive — replies cross the channel
|
||||||
|
//! as plain values (`u32` tokens, `(u64, u64)` mb numbers, `()`).
|
||||||
|
//!
|
||||||
|
//! Rationale, in order of weight:
|
||||||
|
//!
|
||||||
|
//! 1. **Context locality.** cudarc binds the CUDA context per OS thread
|
||||||
|
//! via `cuCtxSetCurrent`. With `tokio::task::spawn_blocking`, the
|
||||||
|
//! blocking thread chosen is arbitrary, so the context gets bound
|
||||||
|
//! onto a different thread each time and `device_vram_mb()` from an
|
||||||
|
//! async task binds it again on the *caller's* thread as a side
|
||||||
|
//! effect. Pinning the context to one named thread ends that.
|
||||||
|
//!
|
||||||
|
//! 2. **Drop safety.** `cudarc::driver::CudaContext`, every `CudaSlice`
|
||||||
|
//! inside a `Tensor`, and every `cudarc::nccl::Comm` call `cuMemFree`
|
||||||
|
//! / `cuCtxDestroy` / `ncclCommDestroy` during `Drop`. These must
|
||||||
|
//! run with the right context current. Owning everything in this
|
||||||
|
//! thread's state slab and dropping it via `Job::DropArch` /
|
||||||
|
//! `Job::Shutdown` is the only safe pattern.
|
||||||
|
//!
|
||||||
|
//! 3. **Poisoning blast radius.** When a CUDA driver error (illegal
|
||||||
|
//! address, OOM cascade) makes the context unrecoverable, today the
|
||||||
|
//! spawn_blocking thread carrying that bad state simply returns to
|
||||||
|
//! tokio's pool — invisible. With the per-device thread, the
|
||||||
|
//! poisoned flag lives on the thread itself; subsequent
|
||||||
|
//! `submit()` calls fast-reject at the channel boundary with a
|
||||||
|
//! clear "device worker is poisoned" error before any further CUDA
|
||||||
|
//! work is attempted.
|
||||||
|
//!
|
||||||
|
//! The TP worker subprocesses (`harness/tp/worker.rs`) are already this
|
||||||
|
//! pattern, just out-of-process. The in-process variant uses the same
|
||||||
|
//! discipline for rank 0.
|
||||||
|
//!
|
||||||
|
//! Phase 1 of the refactor exposes only `Job::QueryVram` + `Job::Shutdown`.
|
||||||
|
//! Forward, kv-cache clear, model load, and NCCL bring-up move in later
|
||||||
|
//! phases. See `/home/grenade/.claude/plans/plan-the-per-device-worker-abstract-micali.md`.
|
||||||
|
|
||||||
|
pub mod dispatch;
|
||||||
|
pub mod jobs;
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
|
use std::sync::mpsc::{self, Sender};
|
||||||
|
use std::thread::JoinHandle;
|
||||||
|
use tokio::sync::oneshot;
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub use jobs::TpHandle;
|
||||||
|
pub use jobs::{ArchHandle, Job};
|
||||||
|
|
||||||
|
/// Errors returned by `DeviceWorkerHandle` submit methods.
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum WorkerError {
|
||||||
|
/// The worker's CUDA context was poisoned by an earlier driver
|
||||||
|
/// error. The thread is still alive (dropping it would re-touch
|
||||||
|
/// the broken context); it returns this error for every job
|
||||||
|
/// submitted until the daemon is restarted.
|
||||||
|
#[error(
|
||||||
|
"device worker for device {device_index} is poisoned \
|
||||||
|
(a prior CUDA driver error left the context unrecoverable); \
|
||||||
|
restart the daemon to recover"
|
||||||
|
)]
|
||||||
|
Poisoned { device_index: u32 },
|
||||||
|
/// The worker thread has exited (`Job::Shutdown` was processed or
|
||||||
|
/// the thread panicked). Subsequent `submit()` calls fail here
|
||||||
|
/// rather than blocking forever.
|
||||||
|
#[error("device worker for device {device_index} is no longer running")]
|
||||||
|
Gone { device_index: u32 },
|
||||||
|
/// The dispatched job returned an `Err`. Forwarded verbatim.
|
||||||
|
#[error(transparent)]
|
||||||
|
Job(#[from] anyhow::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Shared handle to a per-device CUDA worker thread.
|
||||||
|
///
|
||||||
|
/// Cloning the `Arc` lets multiple `LoadedModel`s (and `TpLoadedModel`s)
|
||||||
|
/// share the same worker — there's one worker per CUDA device index,
|
||||||
|
/// not one per model.
|
||||||
|
pub struct DeviceWorkerHandle {
|
||||||
|
device_index: u32,
|
||||||
|
tx: Sender<Job>,
|
||||||
|
poisoned: Arc<AtomicBool>,
|
||||||
|
/// `Mutex<Option<JoinHandle>>` so `shutdown()` can take the handle
|
||||||
|
/// out without `&mut self` and so the inevitable `Drop` after
|
||||||
|
/// `shutdown()` doesn't double-join. The mutex is uncontended in
|
||||||
|
/// practice: only one caller ever takes the handle.
|
||||||
|
join: std::sync::Mutex<Option<JoinHandle<()>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DeviceWorkerHandle {
|
||||||
|
/// Spawn a new worker for the given CUDA device index.
|
||||||
|
///
|
||||||
|
/// The thread is named `cuda-dev-N` so it shows up legibly in
|
||||||
|
/// `top -H`, `pidstat -t`, and gdb backtraces. On CUDA builds, the
|
||||||
|
/// thread binds `CudaContext::new(N)` on startup; on CPU builds
|
||||||
|
/// (`--no-default-features`) the thread runs without a context and
|
||||||
|
/// every job that touches CUDA falls through to a zero return.
|
||||||
|
pub fn spawn(device_index: u32) -> anyhow::Result<Arc<Self>> {
|
||||||
|
let (tx, rx) = mpsc::channel::<Job>();
|
||||||
|
let poisoned = Arc::new(AtomicBool::new(false));
|
||||||
|
let poisoned_for_thread = Arc::clone(&poisoned);
|
||||||
|
let join = std::thread::Builder::new()
|
||||||
|
.name(format!("cuda-dev-{device_index}"))
|
||||||
|
.spawn(move || {
|
||||||
|
dispatch::run(device_index, rx, poisoned_for_thread);
|
||||||
|
})?;
|
||||||
|
Ok(Arc::new(Self {
|
||||||
|
device_index,
|
||||||
|
tx,
|
||||||
|
poisoned,
|
||||||
|
join: std::sync::Mutex::new(Some(join)),
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn device_index(&self) -> u32 {
|
||||||
|
self.device_index
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_poisoned(&self) -> bool {
|
||||||
|
self.poisoned.load(Ordering::Acquire)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Mark the worker's context as poisoned. Future `submit()` calls
|
||||||
|
/// short-circuit to `WorkerError::Poisoned` before sending. The
|
||||||
|
/// dispatch loop also flips into drain-only mode when it sees this
|
||||||
|
/// flag, so any jobs already in flight on the channel reply with
|
||||||
|
/// the same error without touching CUDA.
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub(crate) fn set_poisoned(&self) {
|
||||||
|
self.poisoned.store(true, Ordering::Release);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send `Job::QueryVram`, await the worker's reply.
|
||||||
|
///
|
||||||
|
/// Returns `Ok((free_mb, total_mb))` on success, `Ok((0, 0))` on
|
||||||
|
/// CPU builds or when the device lacks a bound context, or an
|
||||||
|
/// error if the worker is poisoned, gone, or the query itself
|
||||||
|
/// failed inside cudarc.
|
||||||
|
pub async fn query_vram(&self) -> Result<(u64, u64), WorkerError> {
|
||||||
|
if self.poisoned.load(Ordering::Acquire) {
|
||||||
|
return Err(WorkerError::Poisoned {
|
||||||
|
device_index: self.device_index,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
|
self.tx
|
||||||
|
.send(Job::QueryVram { reply: reply_tx })
|
||||||
|
.map_err(|_| WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
})?;
|
||||||
|
match reply_rx.await {
|
||||||
|
Ok(result) => result.map_err(WorkerError::from),
|
||||||
|
Err(_) => Err(WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load a GGUF (pre-quantized) single-GPU model on the worker
|
||||||
|
/// thread. The hf-hub resolution happens on the async caller; the
|
||||||
|
/// resolved local `gguf_path` plus the spec's model_id are sent
|
||||||
|
/// into the worker which opens, parses, and constructs the
|
||||||
|
/// `ModelArch` on the right thread.
|
||||||
|
pub async fn load_gguf(
|
||||||
|
&self,
|
||||||
|
gguf_path: std::path::PathBuf,
|
||||||
|
model_id: String,
|
||||||
|
) -> Result<ArchHandle, WorkerError> {
|
||||||
|
if self.poisoned.load(Ordering::Acquire) {
|
||||||
|
return Err(WorkerError::Poisoned {
|
||||||
|
device_index: self.device_index,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
|
self.tx
|
||||||
|
.send(Job::LoadGguf {
|
||||||
|
gguf_path,
|
||||||
|
model_id,
|
||||||
|
reply: reply_tx,
|
||||||
|
})
|
||||||
|
.map_err(|_| WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
})?;
|
||||||
|
match reply_rx.await {
|
||||||
|
Ok(result) => result.map_err(WorkerError::from),
|
||||||
|
Err(_) => Err(WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load a dense safetensors single-GPU model on the worker thread.
|
||||||
|
pub async fn load_dense(
|
||||||
|
&self,
|
||||||
|
config_path: std::path::PathBuf,
|
||||||
|
safetensors_paths: Vec<std::path::PathBuf>,
|
||||||
|
model_id: String,
|
||||||
|
) -> Result<ArchHandle, WorkerError> {
|
||||||
|
if self.poisoned.load(Ordering::Acquire) {
|
||||||
|
return Err(WorkerError::Poisoned {
|
||||||
|
device_index: self.device_index,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
|
self.tx
|
||||||
|
.send(Job::LoadDense {
|
||||||
|
config_path,
|
||||||
|
safetensors_paths,
|
||||||
|
model_id,
|
||||||
|
reply: reply_tx,
|
||||||
|
})
|
||||||
|
.map_err(|_| WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
})?;
|
||||||
|
match reply_rx.await {
|
||||||
|
Ok(result) => result.map_err(WorkerError::from),
|
||||||
|
Err(_) => Err(WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Tell the worker to drop the `ModelArch` for `handle` on the
|
||||||
|
/// worker thread (so CUDA tensors release on the right context).
|
||||||
|
/// Returns `Ok(())` even if the handle wasn't in the slab — Drop
|
||||||
|
/// is idempotent. Reports `Gone` if the worker isn't running.
|
||||||
|
pub async fn drop_arch(&self, handle: ArchHandle) -> Result<(), WorkerError> {
|
||||||
|
// Poisoning doesn't block DropArch — even on a poisoned
|
||||||
|
// context we want callers to unblock and proceed with the
|
||||||
|
// unload bookkeeping. The dispatch handler under poison just
|
||||||
|
// replies `()` without touching the model (the actual Drop
|
||||||
|
// happens via mem::forget at thread exit per the poison
|
||||||
|
// protocol).
|
||||||
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
|
self.tx
|
||||||
|
.send(Job::DropArch {
|
||||||
|
handle,
|
||||||
|
reply: reply_tx,
|
||||||
|
})
|
||||||
|
.map_err(|_| WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
})?;
|
||||||
|
match reply_rx.await {
|
||||||
|
Ok(()) => Ok(()),
|
||||||
|
Err(_) => Err(WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Reset the KV cache for the model at `handle`. Called at the
|
||||||
|
/// start of every chat completion so the new prompt doesn't
|
||||||
|
/// attend over the previous request's tokens.
|
||||||
|
pub async fn clear_kv_cache(&self, handle: ArchHandle) -> Result<(), WorkerError> {
|
||||||
|
if self.poisoned.load(Ordering::Acquire) {
|
||||||
|
return Err(WorkerError::Poisoned {
|
||||||
|
device_index: self.device_index,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
|
self.tx
|
||||||
|
.send(Job::ClearKv {
|
||||||
|
handle,
|
||||||
|
reply: reply_tx,
|
||||||
|
})
|
||||||
|
.map_err(|_| WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
})?;
|
||||||
|
match reply_rx.await {
|
||||||
|
Ok(result) => result.map_err(WorkerError::from),
|
||||||
|
Err(_) => Err(WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run one forward step and return the resulting `[vocab]` logits
|
||||||
|
/// as a CPU-side `Vec<f32>`. The caller then samples on a CPU
|
||||||
|
/// candle Tensor without ever binding the device context on its
|
||||||
|
/// tokio thread.
|
||||||
|
pub async fn forward_logits(
|
||||||
|
&self,
|
||||||
|
handle: ArchHandle,
|
||||||
|
tokens: Vec<u32>,
|
||||||
|
offset: usize,
|
||||||
|
) -> Result<Vec<f32>, WorkerError> {
|
||||||
|
if self.poisoned.load(Ordering::Acquire) {
|
||||||
|
return Err(WorkerError::Poisoned {
|
||||||
|
device_index: self.device_index,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
|
self.tx
|
||||||
|
.send(Job::ForwardLogits {
|
||||||
|
handle,
|
||||||
|
tokens,
|
||||||
|
offset,
|
||||||
|
reply: reply_tx,
|
||||||
|
})
|
||||||
|
.map_err(|_| WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
})?;
|
||||||
|
match reply_rx.await {
|
||||||
|
Ok(result) => result.map_err(WorkerError::from),
|
||||||
|
Err(_) => Err(WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Initialise the leader's NCCL communicator. The reply uses
|
||||||
|
/// `WorkerResponse` (same shape subprocess workers use over stdio
|
||||||
|
/// RPC) so `WorkerPool::init_nccl`'s aggregation treats leader +
|
||||||
|
/// subprocess responses uniformly. Available on no-cuda builds
|
||||||
|
/// too — the dispatch handler calls the no-cuda `NcclState::init`
|
||||||
|
/// stub which replies `cuda_feature_not_enabled`.
|
||||||
|
pub async fn nccl_init(
|
||||||
|
&self,
|
||||||
|
cfg: crate::harness::tp::worker::WorkerConfig,
|
||||||
|
comm_id_hex: String,
|
||||||
|
) -> Result<crate::harness::tp::rpc::WorkerResponse, WorkerError> {
|
||||||
|
if self.poisoned.load(Ordering::Acquire) {
|
||||||
|
return Err(WorkerError::Poisoned {
|
||||||
|
device_index: self.device_index,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
|
self.tx
|
||||||
|
.send(Job::NcclInit {
|
||||||
|
cfg,
|
||||||
|
comm_id_hex,
|
||||||
|
reply: reply_tx,
|
||||||
|
})
|
||||||
|
.map_err(|_| WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
})?;
|
||||||
|
reply_rx.await.map_err(|_| WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run an NCCL sanity all_reduce on the leader's rank 0.
|
||||||
|
/// Available on no-cuda builds; replies with an error response.
|
||||||
|
pub async fn nccl_sanity(
|
||||||
|
&self,
|
||||||
|
) -> Result<crate::harness::tp::rpc::WorkerResponse, WorkerError> {
|
||||||
|
if self.poisoned.load(Ordering::Acquire) {
|
||||||
|
return Err(WorkerError::Poisoned {
|
||||||
|
device_index: self.device_index,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
|
self.tx
|
||||||
|
.send(Job::NcclSanity { reply: reply_tx })
|
||||||
|
.map_err(|_| WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
})?;
|
||||||
|
reply_rx.await.map_err(|_| WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load the leader's TP shard on the worker thread. The dispatch
|
||||||
|
/// handler reads its own `NcclState`'s `Arc<Comm>` directly — no
|
||||||
|
/// cross-thread Comm transfer — and builds the `TpLeaderModel`
|
||||||
|
/// against it. Phase 4 replaces the Phase 3 Clone/TransferIn
|
||||||
|
/// bridge with this single Job.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub async fn tp_load_shard(
|
||||||
|
&self,
|
||||||
|
model_id: String,
|
||||||
|
config_json: String,
|
||||||
|
safetensors_paths: Vec<std::path::PathBuf>,
|
||||||
|
dtype: candle_core::DType,
|
||||||
|
quant: Option<String>,
|
||||||
|
world_size: u32,
|
||||||
|
) -> Result<TpHandle, WorkerError> {
|
||||||
|
if self.poisoned.load(Ordering::Acquire) {
|
||||||
|
return Err(WorkerError::Poisoned {
|
||||||
|
device_index: self.device_index,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
|
self.tx
|
||||||
|
.send(Job::TpLoadShard {
|
||||||
|
model_id,
|
||||||
|
config_json,
|
||||||
|
safetensors_paths,
|
||||||
|
dtype,
|
||||||
|
quant,
|
||||||
|
world_size,
|
||||||
|
reply: reply_tx,
|
||||||
|
})
|
||||||
|
.map_err(|_| WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
})?;
|
||||||
|
match reply_rx.await {
|
||||||
|
Ok(result) => result.map_err(WorkerError::from),
|
||||||
|
Err(_) => Err(WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Drop the TP model at `handle` on the worker thread.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub async fn drop_tp(&self, handle: TpHandle) -> Result<(), WorkerError> {
|
||||||
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
|
self.tx
|
||||||
|
.send(Job::DropTp {
|
||||||
|
handle,
|
||||||
|
reply: reply_tx,
|
||||||
|
})
|
||||||
|
.map_err(|_| WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
})?;
|
||||||
|
match reply_rx.await {
|
||||||
|
Ok(()) => Ok(()),
|
||||||
|
Err(_) => Err(WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Reset the leader's KV cache for a TP model.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub async fn tp_clear_kv(&self, handle: TpHandle) -> Result<(), WorkerError> {
|
||||||
|
if self.poisoned.load(Ordering::Acquire) {
|
||||||
|
return Err(WorkerError::Poisoned {
|
||||||
|
device_index: self.device_index,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
|
self.tx
|
||||||
|
.send(Job::TpClearKv {
|
||||||
|
handle,
|
||||||
|
reply: reply_tx,
|
||||||
|
})
|
||||||
|
.map_err(|_| WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
})?;
|
||||||
|
match reply_rx.await {
|
||||||
|
Ok(result) => result.map_err(WorkerError::from),
|
||||||
|
Err(_) => Err(WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run one TP forward step on the leader's shard. Returns CPU-side
|
||||||
|
/// logits as `Vec<f32>` ready for sampling. The caller is
|
||||||
|
/// responsible for fan-out / drain of the subprocess workers
|
||||||
|
/// concurrently with this call.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub async fn tp_forward_logits(
|
||||||
|
&self,
|
||||||
|
handle: TpHandle,
|
||||||
|
tokens: Vec<u32>,
|
||||||
|
offset: usize,
|
||||||
|
) -> Result<Vec<f32>, WorkerError> {
|
||||||
|
if self.poisoned.load(Ordering::Acquire) {
|
||||||
|
return Err(WorkerError::Poisoned {
|
||||||
|
device_index: self.device_index,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
|
self.tx
|
||||||
|
.send(Job::TpForwardLogits {
|
||||||
|
handle,
|
||||||
|
tokens,
|
||||||
|
offset,
|
||||||
|
reply: reply_tx,
|
||||||
|
})
|
||||||
|
.map_err(|_| WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
})?;
|
||||||
|
match reply_rx.await {
|
||||||
|
Ok(result) => result.map_err(WorkerError::from),
|
||||||
|
Err(_) => Err(WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send `Job::Shutdown` and join the thread. Idempotent — calling
|
||||||
|
/// twice is a no-op the second time.
|
||||||
|
pub fn shutdown(&self) -> anyhow::Result<()> {
|
||||||
|
// Best-effort send: if the channel is already closed (thread
|
||||||
|
// exited after a prior shutdown or panic) the send fails and
|
||||||
|
// we fall through to the join which returns the panic, if any.
|
||||||
|
let _ = self.tx.send(Job::Shutdown);
|
||||||
|
let join = self.join.lock().unwrap().take();
|
||||||
|
if let Some(j) = join {
|
||||||
|
j.join()
|
||||||
|
.map_err(|_| anyhow::anyhow!("worker thread panicked during shutdown"))?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for DeviceWorkerHandle {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
// Best-effort: send Shutdown so the thread breaks its loop
|
||||||
|
// and exits. We do NOT join here — Drop may run on a tokio
|
||||||
|
// worker thread, and joining a thread that's still processing
|
||||||
|
// the last job would block the runtime. The OS reaps the
|
||||||
|
// thread on detach.
|
||||||
|
let _ = self.tx.send(Job::Shutdown);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn spawn_query_vram_shutdown() {
|
||||||
|
let handle = DeviceWorkerHandle::spawn(0).expect("spawn ok");
|
||||||
|
// CPU build (the only one CI runs) returns (0, 0) by design;
|
||||||
|
// a CUDA build with a real device would return real values.
|
||||||
|
let result = handle.query_vram().await.expect("query ok");
|
||||||
|
// We assert >= 0 — the field width matters more than the value.
|
||||||
|
let _ = result.0;
|
||||||
|
let _ = result.1;
|
||||||
|
handle.shutdown().expect("shutdown ok");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn thread_is_named_correctly() {
|
||||||
|
// The thread name lets `top -H` / pidstat / gdb show
|
||||||
|
// `cuda-dev-N` instead of an opaque tokio worker name. Verify
|
||||||
|
// by spawning and reading proc-self thread comms — but on
|
||||||
|
// platforms without /proc, just confirm we don't crash.
|
||||||
|
let handle = DeviceWorkerHandle::spawn(7).expect("spawn ok");
|
||||||
|
// Round-trip a job to ensure the thread is alive and processing.
|
||||||
|
handle.query_vram().await.expect("query ok");
|
||||||
|
handle.shutdown().expect("shutdown ok");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn submit_after_shutdown_returns_gone() {
|
||||||
|
let handle = DeviceWorkerHandle::spawn(0).expect("spawn ok");
|
||||||
|
handle.shutdown().expect("shutdown ok");
|
||||||
|
// Channel closed; submit should map to Gone rather than block.
|
||||||
|
let result = handle.query_vram().await;
|
||||||
|
match result {
|
||||||
|
Err(WorkerError::Gone { device_index: 0 }) => {}
|
||||||
|
other => panic!("expected Gone, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn poisoned_flag_short_circuits_submit() {
|
||||||
|
let handle = DeviceWorkerHandle::spawn(0).expect("spawn ok");
|
||||||
|
handle.set_poisoned();
|
||||||
|
let result = handle.query_vram().await;
|
||||||
|
match result {
|
||||||
|
Err(WorkerError::Poisoned { device_index: 0 }) => {}
|
||||||
|
other => panic!("expected Poisoned, got {other:?}"),
|
||||||
|
}
|
||||||
|
// The channel is still alive; shutdown should still succeed.
|
||||||
|
handle.shutdown().expect("shutdown ok");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn shutdown_drains_pending_jobs() {
|
||||||
|
let handle = DeviceWorkerHandle::spawn(0).expect("spawn ok");
|
||||||
|
// Submit many concurrent jobs; they should all complete even
|
||||||
|
// though a Shutdown is racing them.
|
||||||
|
let mut futures = Vec::new();
|
||||||
|
for _ in 0..16 {
|
||||||
|
let h = Arc::clone(&handle);
|
||||||
|
futures.push(tokio::spawn(async move { h.query_vram().await }));
|
||||||
|
}
|
||||||
|
// Small yield to give the senders a chance to actually send
|
||||||
|
// before we issue the shutdown; not strictly necessary because
|
||||||
|
// the channel is FIFO, but makes the test's intent clearer.
|
||||||
|
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||||
|
handle.shutdown().expect("shutdown ok");
|
||||||
|
for f in futures {
|
||||||
|
// Each query should have completed (Ok or Gone, never panic).
|
||||||
|
let _ = f.await.expect("task did not panic");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
130
crates/neuron/src/harness/mod.rs
Normal file
130
crates/neuron/src/harness/mod.rs
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
//! Harness registry — maps harness names to trait implementations.
|
||||||
|
|
||||||
|
pub mod arch;
|
||||||
|
pub mod candle;
|
||||||
|
pub mod chat_template;
|
||||||
|
pub mod device_worker;
|
||||||
|
pub mod preflight;
|
||||||
|
pub mod tp;
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use cortex_core::harness::{Harness, HarnessConfig, ModelInfo, ModelSpec};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
/// Registry of available harness implementations.
|
||||||
|
///
|
||||||
|
/// Holds an `Arc<dyn Harness>` per harness for generic lifecycle dispatch
|
||||||
|
/// (load/unload/list_models). When a candle harness is registered, a typed
|
||||||
|
/// `Arc<CandleHarness>` is also cached so inference routes can bypass the
|
||||||
|
/// dyn-Trait dispatch and reach harness-specific methods (chat completion,
|
||||||
|
/// streaming, etc.).
|
||||||
|
pub struct HarnessRegistry {
|
||||||
|
harnesses: HashMap<String, Arc<dyn Harness>>,
|
||||||
|
candle: Option<Arc<candle::CandleHarness>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for HarnessRegistry {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HarnessRegistry {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
harnesses: HashMap::new(),
|
||||||
|
candle: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn register(&mut self, harness: Arc<dyn Harness>) {
|
||||||
|
self.harnesses.insert(harness.name().to_string(), harness);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// List all registered harness names.
|
||||||
|
pub fn names(&self) -> Vec<String> {
|
||||||
|
self.harnesses.keys().cloned().collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Typed handle to the candle harness, if registered. Used by inference
|
||||||
|
/// routes that need methods beyond the `Harness` trait surface.
|
||||||
|
pub fn candle(&self) -> Option<Arc<candle::CandleHarness>> {
|
||||||
|
self.candle.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// List models from all registered harnesses.
|
||||||
|
pub async fn list_all_models(&self) -> Result<Vec<ModelInfo>> {
|
||||||
|
let mut all = Vec::new();
|
||||||
|
for harness in self.harnesses.values() {
|
||||||
|
match harness.list_models().await {
|
||||||
|
Ok(models) => all.extend(models),
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(harness = harness.name(), error = %e, "failed to list models");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(all)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load a model on the specified harness.
|
||||||
|
pub async fn load_model(&self, spec: &ModelSpec) -> Result<()> {
|
||||||
|
let harness = self
|
||||||
|
.harnesses
|
||||||
|
.get(&spec.harness)
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("unknown harness: {}", spec.harness))?;
|
||||||
|
harness.load_model(spec).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Unload a model. Tries each harness until one claims it.
|
||||||
|
pub async fn unload_model(&self, model_id: &str) -> Result<()> {
|
||||||
|
for harness in self.harnesses.values() {
|
||||||
|
match harness.list_models().await {
|
||||||
|
Ok(models) if models.iter().any(|m| m.id == model_id) => {
|
||||||
|
return harness.unload_model(model_id).await;
|
||||||
|
}
|
||||||
|
_ => continue,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
anyhow::bail!("model '{model_id}' not found on any harness")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the inference endpoint for a model.
|
||||||
|
pub async fn inference_endpoint(&self, model_id: &str) -> Option<String> {
|
||||||
|
for harness in self.harnesses.values() {
|
||||||
|
if let Some(url) = harness.inference_endpoint(model_id).await {
|
||||||
|
return Some(url);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build a registry from harness configs.
|
||||||
|
///
|
||||||
|
/// `bind_url` is the URL where this neuron serves inference (its own
|
||||||
|
/// listen address). In-process harnesses (currently the only kind)
|
||||||
|
/// return this URL from `inference_endpoint`.
|
||||||
|
pub fn from_configs(
|
||||||
|
configs: &[HarnessConfig],
|
||||||
|
bind_url: &str,
|
||||||
|
settings: &crate::config::HarnessSettings,
|
||||||
|
) -> Self {
|
||||||
|
let mut registry = Self::new();
|
||||||
|
for config in configs {
|
||||||
|
match config.name.as_str() {
|
||||||
|
"candle" => {
|
||||||
|
let harness = Arc::new(candle::CandleHarness::new(
|
||||||
|
bind_url.to_string(),
|
||||||
|
settings.candle.hf_cache.clone(),
|
||||||
|
));
|
||||||
|
registry.candle = Some(Arc::clone(&harness));
|
||||||
|
registry.harnesses.insert("candle".into(), harness);
|
||||||
|
}
|
||||||
|
other => {
|
||||||
|
tracing::warn!(harness = other, "unknown harness type, skipping");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
registry
|
||||||
|
}
|
||||||
|
}
|
||||||
575
crates/neuron/src/harness/preflight.rs
Normal file
575
crates/neuron/src/harness/preflight.rs
Normal file
@@ -0,0 +1,575 @@
|
|||||||
|
//! Placement feasibility check that runs before any device allocation,
|
||||||
|
//! NCCL handshake, or weight download.
|
||||||
|
//!
|
||||||
|
//! The loader path in `candle.rs` historically discovers an
|
||||||
|
//! incompatibility *after* it has already started fetching files —
|
||||||
|
//! "fetch config.json from HauhauCS/...: 404 Not Found" surfaces hours
|
||||||
|
//! after operators set `tensor_parallel = 2` on a GGUF-only repo, with
|
||||||
|
//! no hint about what's actually wrong. Preflight closes that gap:
|
||||||
|
//!
|
||||||
|
//! 1. one `repo.info()` round-trip (siblings listing, no blob fetch)
|
||||||
|
//! 2. classify the repo: GGUF-only, dense safetensors, mixed, empty
|
||||||
|
//! 3. apply the feasibility table against the requested
|
||||||
|
//! `ModelSpec` (tp_size, quant)
|
||||||
|
//! 4. return a structured `PreflightError` the API layer can map to
|
||||||
|
//! 422 + JSON, or `Ok(PlacementPlan)` carrying the decisions the
|
||||||
|
//! downstream load path needs (which GGUF file to fetch, etc.).
|
||||||
|
//!
|
||||||
|
//! Phase 2 of plan-source-aware-loader-preflight. The Phase 1 scheme
|
||||||
|
//! work — `ModelSourceId` and per-scheme `SourceConfig` — is a
|
||||||
|
//! separate PR; preflight runs against the single configured
|
||||||
|
//! HuggingFace source for now and the scheme threading drops in
|
||||||
|
//! cleanly when Phase 1 lands.
|
||||||
|
|
||||||
|
use cortex_core::harness::ModelSpec;
|
||||||
|
use hf_hub::api::tokio::Api;
|
||||||
|
use serde::Serialize;
|
||||||
|
|
||||||
|
/// What the repo's siblings listing tells us about how to load it.
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
|
||||||
|
#[serde(tag = "kind", rename_all = "snake_case")]
|
||||||
|
pub enum SourceFormat {
|
||||||
|
/// Only GGUF files present. Single-GPU load path. `quants` is the
|
||||||
|
/// lowercased filename list so the operator can be told what's
|
||||||
|
/// actually available when their `quant=` choice doesn't match.
|
||||||
|
Gguf { quants: Vec<String> },
|
||||||
|
/// Dense safetensors (single-file or sharded via index.json).
|
||||||
|
/// Goes through `load_arch_dense` on single-GPU, or `load_tp` (with
|
||||||
|
/// optional in-situ quantization) when `tensor_parallel > 1`.
|
||||||
|
DenseSafetensors { sharded: bool },
|
||||||
|
/// Both safetensors and GGUF present — prefer the dense path
|
||||||
|
/// because it composes with TP and ISQ. We surface the GGUF
|
||||||
|
/// filenames anyway so operators with a strong preference can
|
||||||
|
/// see they exist.
|
||||||
|
Mixed { gguf_quants: Vec<String> },
|
||||||
|
/// No recognised weight files. Either a tokenizer-only repo
|
||||||
|
/// (e.g. some base-model repos that only host `tokenizer.json` and
|
||||||
|
/// expect the operator to use a `-GGUF` sibling repo) or a
|
||||||
|
/// genuinely empty entry.
|
||||||
|
Empty,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Output of `preflight` for a load that can proceed. Carries the
|
||||||
|
/// decisions downstream resolve_* paths would otherwise re-derive.
|
||||||
|
#[derive(Debug, Clone, Serialize)]
|
||||||
|
pub struct PlacementPlan {
|
||||||
|
pub model_id: String,
|
||||||
|
pub format: SourceFormat,
|
||||||
|
pub tp_size: u32,
|
||||||
|
/// Filename of the GGUF to fetch, populated when `format` is
|
||||||
|
/// `Gguf` and a single-GPU load was requested. None for the
|
||||||
|
/// dense/TP path.
|
||||||
|
pub picked_quant_file: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Structured failure modes. Each variant carries the fields the API
|
||||||
|
/// layer needs to produce an actionable 422 body.
|
||||||
|
#[derive(Debug, Clone, Serialize, thiserror::Error)]
|
||||||
|
#[serde(tag = "kind", rename_all = "snake_case")]
|
||||||
|
pub enum PreflightError {
|
||||||
|
/// `repo.info()` failed. Captures the underlying cause as a string
|
||||||
|
/// so the operator log shows whether it's auth, 404, or transport.
|
||||||
|
#[error("failed to fetch repo info for '{model_id}': {cause}")]
|
||||||
|
RepoFetchFailed { model_id: String, cause: String },
|
||||||
|
|
||||||
|
/// The repo exists but has no recognised weight files.
|
||||||
|
#[error(
|
||||||
|
"repo '{model_id}' has no recognised weight files (no .gguf, no .safetensors); \
|
||||||
|
a tokenizer-only repo cannot be loaded directly"
|
||||||
|
)]
|
||||||
|
EmptyRepo { model_id: String },
|
||||||
|
|
||||||
|
/// Operator asked for `tensor_parallel > 1` on a GGUF-only repo.
|
||||||
|
/// The TP path requires safetensors+config for in-situ
|
||||||
|
/// quantization; GGUF-TP isn't implemented (see CLAUDE.md).
|
||||||
|
#[error(
|
||||||
|
"cannot load '{model_id}' with tensor_parallel={tp_size}: repo is GGUF-only \
|
||||||
|
({} .gguf files); TP requires dense safetensors. {suggestion}",
|
||||||
|
gguf_quants.len()
|
||||||
|
)]
|
||||||
|
TpRequiresSafetensors {
|
||||||
|
model_id: String,
|
||||||
|
tp_size: u32,
|
||||||
|
gguf_quants: Vec<String>,
|
||||||
|
suggestion: String,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Operator asked for a GGUF quant whose substring doesn't match
|
||||||
|
/// any filename in the repo. `nearest` is a best-effort Levenshtein
|
||||||
|
/// suggestion against the available quant names.
|
||||||
|
#[error(
|
||||||
|
"no GGUF file in '{model_id}' matches quant '{requested}'; \
|
||||||
|
available: {available:?}{}",
|
||||||
|
nearest.as_ref().map(|n| format!("; did you mean '{n}'?")).unwrap_or_default()
|
||||||
|
)]
|
||||||
|
QuantNotFound {
|
||||||
|
model_id: String,
|
||||||
|
requested: String,
|
||||||
|
available: Vec<String>,
|
||||||
|
nearest: Option<String>,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run the placement check.
|
||||||
|
///
|
||||||
|
/// One network round-trip (`repo.info()`); no blob fetches. Returns
|
||||||
|
/// `Ok(PlacementPlan)` when the requested combination is feasible, or
|
||||||
|
/// a structured `PreflightError` describing what's wrong.
|
||||||
|
pub async fn preflight(api: &Api, spec: &ModelSpec) -> Result<PlacementPlan, PreflightError> {
|
||||||
|
let repo = api.model(spec.model_id.clone());
|
||||||
|
let info = repo
|
||||||
|
.info()
|
||||||
|
.await
|
||||||
|
.map_err(|e| PreflightError::RepoFetchFailed {
|
||||||
|
model_id: spec.model_id.clone(),
|
||||||
|
cause: format!("{e}"),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let filenames: Vec<&str> = info.siblings.iter().map(|s| s.rfilename.as_str()).collect();
|
||||||
|
let format = classify(&filenames);
|
||||||
|
let tp_size = spec.tensor_parallel.unwrap_or(1);
|
||||||
|
|
||||||
|
match (&format, tp_size, spec.quant.as_deref()) {
|
||||||
|
// No weights at all — nothing to do.
|
||||||
|
(SourceFormat::Empty, _, _) => Err(PreflightError::EmptyRepo {
|
||||||
|
model_id: spec.model_id.clone(),
|
||||||
|
}),
|
||||||
|
|
||||||
|
// GGUF-only + TP: not supported. Today's HauhauCS failure.
|
||||||
|
(SourceFormat::Gguf { quants }, tp, _) if tp > 1 => {
|
||||||
|
Err(PreflightError::TpRequiresSafetensors {
|
||||||
|
model_id: spec.model_id.clone(),
|
||||||
|
tp_size: tp,
|
||||||
|
gguf_quants: quants.clone(),
|
||||||
|
suggestion: format!(
|
||||||
|
"Set tensor_parallel=1 and pick a quant from {quants:?}, \
|
||||||
|
or use a dense safetensors release of this model."
|
||||||
|
),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GGUF-only + single-GPU: pick the file that matches the
|
||||||
|
// operator's quant. Empty quant matches the first GGUF.
|
||||||
|
(SourceFormat::Gguf { quants }, _, requested) => {
|
||||||
|
let picked = pick_gguf_file(&filenames, requested.unwrap_or(""));
|
||||||
|
match picked {
|
||||||
|
Some(fname) => Ok(PlacementPlan {
|
||||||
|
model_id: spec.model_id.clone(),
|
||||||
|
format: format.clone(),
|
||||||
|
tp_size,
|
||||||
|
picked_quant_file: Some(fname),
|
||||||
|
}),
|
||||||
|
None => Err(PreflightError::QuantNotFound {
|
||||||
|
model_id: spec.model_id.clone(),
|
||||||
|
requested: requested.unwrap_or("").to_string(),
|
||||||
|
available: quants.clone(),
|
||||||
|
nearest: nearest_quant(requested.unwrap_or(""), quants),
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dense or mixed: dense path handles both single-GPU and TP.
|
||||||
|
// The architecture compatibility check stays where it is —
|
||||||
|
// `check_dense_config_supported` runs once `config.json` is
|
||||||
|
// on disk, since it needs the parsed JSON.
|
||||||
|
(SourceFormat::DenseSafetensors { .. } | SourceFormat::Mixed { .. }, _, _) => {
|
||||||
|
Ok(PlacementPlan {
|
||||||
|
model_id: spec.model_id.clone(),
|
||||||
|
format: format.clone(),
|
||||||
|
tp_size,
|
||||||
|
picked_quant_file: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Classify a siblings file list into a `SourceFormat`. Pulled out so
|
||||||
|
/// the unit tests can exercise it against fixture JSON without
|
||||||
|
/// spinning up an Api.
|
||||||
|
pub fn classify(filenames: &[&str]) -> SourceFormat {
|
||||||
|
let mut gguf_quants: Vec<String> = filenames
|
||||||
|
.iter()
|
||||||
|
.filter(|f| f.to_lowercase().ends_with(".gguf"))
|
||||||
|
.map(|f| f.to_lowercase())
|
||||||
|
.collect();
|
||||||
|
gguf_quants.sort();
|
||||||
|
gguf_quants.dedup();
|
||||||
|
|
||||||
|
let has_safetensors = filenames.iter().any(|f| f.ends_with(".safetensors"));
|
||||||
|
let sharded = filenames
|
||||||
|
.iter()
|
||||||
|
.any(|f| f.ends_with("model.safetensors.index.json"));
|
||||||
|
|
||||||
|
match (has_safetensors, gguf_quants.is_empty()) {
|
||||||
|
(true, true) => SourceFormat::DenseSafetensors { sharded },
|
||||||
|
(true, false) => SourceFormat::Mixed { gguf_quants },
|
||||||
|
(false, false) => SourceFormat::Gguf {
|
||||||
|
quants: gguf_quants,
|
||||||
|
},
|
||||||
|
(false, true) => SourceFormat::Empty,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Mirror of the quant-matching logic in `candle.rs::resolve_files` so
|
||||||
|
/// preflight picks the same file the downstream loader would. Empty
|
||||||
|
/// quant returns the first `.gguf` (any quant). Lowercased substring
|
||||||
|
/// match otherwise.
|
||||||
|
fn pick_gguf_file(filenames: &[&str], quant_lc: &str) -> Option<String> {
|
||||||
|
filenames
|
||||||
|
.iter()
|
||||||
|
.filter(|f| f.to_lowercase().ends_with(".gguf"))
|
||||||
|
.find(|f| quant_lc.is_empty() || f.to_lowercase().contains(quant_lc))
|
||||||
|
.map(|f| f.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Best-effort suggestion when the operator's quant name doesn't
|
||||||
|
/// substring-match any filename. Extracts the quant-ish token from
|
||||||
|
/// each `.gguf` filename and picks the one with the smallest
|
||||||
|
/// Levenshtein distance to the requested string. Returns None when
|
||||||
|
/// the input is empty or no candidates exist.
|
||||||
|
fn nearest_quant(requested: &str, candidates: &[String]) -> Option<String> {
|
||||||
|
if requested.is_empty() || candidates.is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
// Pull the "Q6_K_P"/"IQ4_XS"-ish token out of each filename for a
|
||||||
|
// fairer comparison. Filenames look like
|
||||||
|
// `Qwen3.6-27B-Uncensored-HauhauCS-Aggressive-Q6_K_P.gguf`, so the
|
||||||
|
// quant is the last `-`-separated segment before the extension,
|
||||||
|
// lowercased.
|
||||||
|
let tokens: Vec<(String, String)> = candidates
|
||||||
|
.iter()
|
||||||
|
.map(|f| (extract_quant_token(f), f.clone()))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let req_lc = requested.to_lowercase();
|
||||||
|
tokens
|
||||||
|
.into_iter()
|
||||||
|
.min_by_key(|(token, _)| levenshtein(&req_lc, token))
|
||||||
|
.map(|(token, _)| token)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_quant_token(filename: &str) -> String {
|
||||||
|
let stem = filename
|
||||||
|
.rsplit_once('.')
|
||||||
|
.map(|(s, _)| s)
|
||||||
|
.unwrap_or(filename);
|
||||||
|
let token = stem.rsplit('-').next().unwrap_or(stem);
|
||||||
|
token.to_lowercase()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Iterative Levenshtein. Small inputs (quant names are <=12 chars),
|
||||||
|
/// no need for the `levenshtein` crate.
|
||||||
|
fn levenshtein(a: &str, b: &str) -> usize {
|
||||||
|
let a: Vec<char> = a.chars().collect();
|
||||||
|
let b: Vec<char> = b.chars().collect();
|
||||||
|
let (m, n) = (a.len(), b.len());
|
||||||
|
if m == 0 {
|
||||||
|
return n;
|
||||||
|
}
|
||||||
|
if n == 0 {
|
||||||
|
return m;
|
||||||
|
}
|
||||||
|
let mut prev: Vec<usize> = (0..=n).collect();
|
||||||
|
let mut curr = vec![0usize; n + 1];
|
||||||
|
for i in 1..=m {
|
||||||
|
curr[0] = i;
|
||||||
|
for j in 1..=n {
|
||||||
|
let cost = if a[i - 1] == b[j - 1] { 0 } else { 1 };
|
||||||
|
curr[j] = (prev[j] + 1).min(curr[j - 1] + 1).min(prev[j - 1] + cost);
|
||||||
|
}
|
||||||
|
std::mem::swap(&mut prev, &mut curr);
|
||||||
|
}
|
||||||
|
prev[n]
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
fn spec(model_id: &str, tp: Option<u32>, quant: Option<&str>) -> ModelSpec {
|
||||||
|
ModelSpec {
|
||||||
|
model_id: model_id.into(),
|
||||||
|
harness: "candle".into(),
|
||||||
|
quant: quant.map(String::from),
|
||||||
|
tensor_parallel: tp,
|
||||||
|
devices: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn classify_gguf_only() {
|
||||||
|
let files = [
|
||||||
|
"README.md",
|
||||||
|
".gitattributes",
|
||||||
|
"Qwen3.6-27B-Q6_K_P.gguf",
|
||||||
|
"Qwen3.6-27B-Q4_K_P.gguf",
|
||||||
|
];
|
||||||
|
match classify(&files) {
|
||||||
|
SourceFormat::Gguf { quants } => {
|
||||||
|
assert_eq!(quants.len(), 2);
|
||||||
|
assert!(quants.iter().any(|q| q.contains("q6_k_p")));
|
||||||
|
}
|
||||||
|
other => panic!("expected Gguf, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn classify_dense_sharded() {
|
||||||
|
let files = [
|
||||||
|
"config.json",
|
||||||
|
"tokenizer.json",
|
||||||
|
"model.safetensors.index.json",
|
||||||
|
"model-00001-of-00002.safetensors",
|
||||||
|
"model-00002-of-00002.safetensors",
|
||||||
|
];
|
||||||
|
assert_eq!(
|
||||||
|
classify(&files),
|
||||||
|
SourceFormat::DenseSafetensors { sharded: true }
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn classify_dense_single_file() {
|
||||||
|
let files = ["config.json", "tokenizer.json", "model.safetensors"];
|
||||||
|
assert_eq!(
|
||||||
|
classify(&files),
|
||||||
|
SourceFormat::DenseSafetensors { sharded: false }
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn classify_mixed() {
|
||||||
|
let files = [
|
||||||
|
"config.json",
|
||||||
|
"tokenizer.json",
|
||||||
|
"model.safetensors",
|
||||||
|
"model-Q4_K_M.gguf",
|
||||||
|
];
|
||||||
|
match classify(&files) {
|
||||||
|
SourceFormat::Mixed { gguf_quants } => {
|
||||||
|
assert_eq!(gguf_quants, vec!["model-q4_k_m.gguf"]);
|
||||||
|
}
|
||||||
|
other => panic!("expected Mixed, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn classify_empty() {
|
||||||
|
let files = ["README.md", "tokenizer.json"];
|
||||||
|
assert_eq!(classify(&files), SourceFormat::Empty);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn pick_gguf_substring_match() {
|
||||||
|
let files = ["model-Q4_K_M.gguf", "model-Q6_K.gguf", "model-Q8_0.gguf"];
|
||||||
|
assert_eq!(
|
||||||
|
pick_gguf_file(&files, "q6_k"),
|
||||||
|
Some("model-Q6_K.gguf".into())
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn pick_gguf_empty_returns_first() {
|
||||||
|
let files = ["model-Q4_K_M.gguf", "model-Q6_K.gguf"];
|
||||||
|
assert_eq!(pick_gguf_file(&files, ""), Some("model-Q4_K_M.gguf".into()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn pick_gguf_no_match() {
|
||||||
|
let files = ["model-Q4_K_M.gguf", "model-Q6_K.gguf"];
|
||||||
|
assert_eq!(pick_gguf_file(&files, "iq2_xs"), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn nearest_quant_suggests_close_match() {
|
||||||
|
// Today's HauhauCS scenario: operator wrote "q6k", actual
|
||||||
|
// filename token is "q6_k_p". Should suggest the latter.
|
||||||
|
let candidates = vec![
|
||||||
|
"qwen-q4_k_p.gguf".to_string(),
|
||||||
|
"qwen-q5_k_p.gguf".to_string(),
|
||||||
|
"qwen-q6_k_p.gguf".to_string(),
|
||||||
|
"qwen-q8_k_p.gguf".to_string(),
|
||||||
|
];
|
||||||
|
assert_eq!(nearest_quant("q6k", &candidates), Some("q6_k_p".into()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn nearest_quant_empty_input() {
|
||||||
|
assert_eq!(nearest_quant("", &[]), None);
|
||||||
|
assert_eq!(nearest_quant("q6k", &[]), None);
|
||||||
|
assert_eq!(nearest_quant("", &["model-q4.gguf".into()]), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_quant_handles_typical_filenames() {
|
||||||
|
assert_eq!(extract_quant_token("Qwen3.6-27B-Q6_K_P.gguf"), "q6_k_p");
|
||||||
|
assert_eq!(extract_quant_token("model-IQ4_XS.gguf"), "iq4_xs");
|
||||||
|
assert_eq!(extract_quant_token("simple.gguf"), "simple");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn levenshtein_basics() {
|
||||||
|
assert_eq!(levenshtein("", ""), 0);
|
||||||
|
assert_eq!(levenshtein("abc", ""), 3);
|
||||||
|
assert_eq!(levenshtein("", "abc"), 3);
|
||||||
|
assert_eq!(levenshtein("kitten", "sitting"), 3);
|
||||||
|
assert_eq!(levenshtein("q6k", "q6_k_p"), 3);
|
||||||
|
assert_eq!(levenshtein("q6k", "q4_k_p"), 4);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Higher-level preflight tests below exercise the full feasibility
|
||||||
|
// table via a thin wrapper that bypasses the network — we hand it
|
||||||
|
// a pre-built `SourceFormat` and request shape, then drive the
|
||||||
|
// same decision logic. The end-to-end test with a mock HTTP
|
||||||
|
// server lives in tests/preflight.rs (integration).
|
||||||
|
|
||||||
|
/// Mirror of the `match` in `preflight()` but takes a classified
|
||||||
|
/// `SourceFormat` directly. Lets us unit-test the feasibility
|
||||||
|
/// table without making the API trait object-safe / boxable.
|
||||||
|
fn decide(
|
||||||
|
spec: &ModelSpec,
|
||||||
|
format: &SourceFormat,
|
||||||
|
filenames: &[&str],
|
||||||
|
) -> Result<PlacementPlan, PreflightError> {
|
||||||
|
let tp_size = spec.tensor_parallel.unwrap_or(1);
|
||||||
|
match (format, tp_size, spec.quant.as_deref()) {
|
||||||
|
(SourceFormat::Empty, _, _) => Err(PreflightError::EmptyRepo {
|
||||||
|
model_id: spec.model_id.clone(),
|
||||||
|
}),
|
||||||
|
(SourceFormat::Gguf { quants }, tp, _) if tp > 1 => {
|
||||||
|
Err(PreflightError::TpRequiresSafetensors {
|
||||||
|
model_id: spec.model_id.clone(),
|
||||||
|
tp_size: tp,
|
||||||
|
gguf_quants: quants.clone(),
|
||||||
|
suggestion: format!(
|
||||||
|
"Set tensor_parallel=1 and pick a quant from {quants:?}, \
|
||||||
|
or use a dense safetensors release of this model."
|
||||||
|
),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
(SourceFormat::Gguf { quants }, _, requested) => {
|
||||||
|
let picked = pick_gguf_file(filenames, requested.unwrap_or(""));
|
||||||
|
match picked {
|
||||||
|
Some(fname) => Ok(PlacementPlan {
|
||||||
|
model_id: spec.model_id.clone(),
|
||||||
|
format: format.clone(),
|
||||||
|
tp_size,
|
||||||
|
picked_quant_file: Some(fname),
|
||||||
|
}),
|
||||||
|
None => Err(PreflightError::QuantNotFound {
|
||||||
|
model_id: spec.model_id.clone(),
|
||||||
|
requested: requested.unwrap_or("").to_string(),
|
||||||
|
available: quants.clone(),
|
||||||
|
nearest: nearest_quant(requested.unwrap_or(""), quants),
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(SourceFormat::DenseSafetensors { .. } | SourceFormat::Mixed { .. }, _, _) => {
|
||||||
|
Ok(PlacementPlan {
|
||||||
|
model_id: spec.model_id.clone(),
|
||||||
|
format: format.clone(),
|
||||||
|
tp_size,
|
||||||
|
picked_quant_file: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn feasibility_gguf_tp_rejected() {
|
||||||
|
let files = ["Qwen-Q6_K_P.gguf", "Qwen-Q4_K_P.gguf"];
|
||||||
|
let fmt = classify(&files);
|
||||||
|
let s = spec("HauhauCS/Qwen3.6", Some(2), Some("q6k"));
|
||||||
|
match decide(&s, &fmt, &files).unwrap_err() {
|
||||||
|
PreflightError::TpRequiresSafetensors {
|
||||||
|
model_id,
|
||||||
|
tp_size,
|
||||||
|
gguf_quants,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
assert_eq!(model_id, "HauhauCS/Qwen3.6");
|
||||||
|
assert_eq!(tp_size, 2);
|
||||||
|
assert_eq!(gguf_quants.len(), 2);
|
||||||
|
}
|
||||||
|
other => panic!("expected TpRequiresSafetensors, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn feasibility_gguf_single_gpu_bad_quant() {
|
||||||
|
let files = [
|
||||||
|
"Qwen-Q4_K_P.gguf",
|
||||||
|
"Qwen-Q5_K_P.gguf",
|
||||||
|
"Qwen-Q6_K_P.gguf",
|
||||||
|
"Qwen-Q8_K_P.gguf",
|
||||||
|
];
|
||||||
|
let fmt = classify(&files);
|
||||||
|
let s = spec("HauhauCS/Qwen3.6", Some(1), Some("q6k"));
|
||||||
|
match decide(&s, &fmt, &files).unwrap_err() {
|
||||||
|
PreflightError::QuantNotFound {
|
||||||
|
requested,
|
||||||
|
nearest,
|
||||||
|
available,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
assert_eq!(requested, "q6k");
|
||||||
|
assert_eq!(nearest.as_deref(), Some("q6_k_p"));
|
||||||
|
assert_eq!(available.len(), 4);
|
||||||
|
}
|
||||||
|
other => panic!("expected QuantNotFound, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn feasibility_gguf_single_gpu_good_quant() {
|
||||||
|
let files = ["Qwen-Q4_K_M.gguf", "Qwen-Q6_K.gguf"];
|
||||||
|
let fmt = classify(&files);
|
||||||
|
let s = spec("Qwen/Q-GGUF", Some(1), Some("q6_k"));
|
||||||
|
let plan = decide(&s, &fmt, &files).unwrap();
|
||||||
|
assert_eq!(plan.picked_quant_file.as_deref(), Some("Qwen-Q6_K.gguf"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn feasibility_dense_tp_ok() {
|
||||||
|
let files = [
|
||||||
|
"config.json",
|
||||||
|
"tokenizer.json",
|
||||||
|
"model.safetensors.index.json",
|
||||||
|
"model-00001-of-00002.safetensors",
|
||||||
|
];
|
||||||
|
let fmt = classify(&files);
|
||||||
|
let s = spec("Qwen/Q3-30B", Some(2), Some("q5k"));
|
||||||
|
let plan = decide(&s, &fmt, &files).unwrap();
|
||||||
|
assert_eq!(plan.tp_size, 2);
|
||||||
|
assert!(plan.picked_quant_file.is_none());
|
||||||
|
assert!(matches!(
|
||||||
|
plan.format,
|
||||||
|
SourceFormat::DenseSafetensors { sharded: true }
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn feasibility_empty_rejected() {
|
||||||
|
let files = ["README.md", "tokenizer.json"];
|
||||||
|
let fmt = classify(&files);
|
||||||
|
let s = spec("Empty/Repo", Some(1), None);
|
||||||
|
match decide(&s, &fmt, &files).unwrap_err() {
|
||||||
|
PreflightError::EmptyRepo { model_id } => assert_eq!(model_id, "Empty/Repo"),
|
||||||
|
other => panic!("expected EmptyRepo, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn error_serialization_carries_kind_field() {
|
||||||
|
let err = PreflightError::TpRequiresSafetensors {
|
||||||
|
model_id: "x/y".into(),
|
||||||
|
tp_size: 2,
|
||||||
|
gguf_quants: vec!["q6_k_p".into()],
|
||||||
|
suggestion: "...".into(),
|
||||||
|
};
|
||||||
|
let v: serde_json::Value = serde_json::to_value(&err).unwrap();
|
||||||
|
assert_eq!(v["kind"], "tp_requires_safetensors");
|
||||||
|
assert_eq!(v["model_id"], "x/y");
|
||||||
|
assert_eq!(v["tp_size"], 2);
|
||||||
|
}
|
||||||
|
}
|
||||||
119
crates/neuron/src/harness/tp/all_reduce.rs
Normal file
119
crates/neuron/src/harness/tp/all_reduce.rs
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
//! `AllReduce` as a candle `CustomOp1` — the bridge between candle's
|
||||||
|
//! `Tensor` graph and `cudarc::nccl::Comm::all_reduce`.
|
||||||
|
//!
|
||||||
|
//! Ported from the canonical
|
||||||
|
//! `candle-examples/examples/llama_multiprocess/model.rs` pattern.
|
||||||
|
//! Row-parallel layers apply this op after their local matmul to sum
|
||||||
|
//! partial outputs across NCCL ranks.
|
||||||
|
//!
|
||||||
|
//! Available only under `--features cuda`; on CPU builds this module
|
||||||
|
//! is empty and row-parallel layers degenerate to local matmul only
|
||||||
|
//! (useful for compile-checking the model code; correctness requires
|
||||||
|
//! cuda).
|
||||||
|
//!
|
||||||
|
//! Thread-safety caveat: NCCL communicators are technically only
|
||||||
|
//! safe to use from a single thread at a time
|
||||||
|
//! (https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html).
|
||||||
|
//! We hold the `AllReduce` behind an `Arc<Comm>` and only issue ops
|
||||||
|
//! against it from the dedicated `spawn_blocking` thread the inference
|
||||||
|
//! pipeline already uses for candle's forward passes.
|
||||||
|
|
||||||
|
#![cfg(feature = "cuda")]
|
||||||
|
|
||||||
|
use candle_core::backend::BackendStorage;
|
||||||
|
use candle_core::{CpuStorage, CudaStorage, CustomOp1, DType, Layout, Result, Shape};
|
||||||
|
use cudarc::nccl::{Comm, ReduceOp};
|
||||||
|
use half::{bf16, f16};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
/// Wraps an NCCL `Comm` so it can be plugged into a candle forward
|
||||||
|
/// graph as a custom op. Each row-parallel layer holds one of these.
|
||||||
|
pub struct AllReduce {
|
||||||
|
comm: Arc<Comm>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// SAFETY: `Comm` contains a raw `ncclComm_t` pointer; NCCL's docs note
|
||||||
|
// that issuing ops against one comm from multiple threads concurrently
|
||||||
|
// is unsafe. We serialise via the single spawn_blocking thread that
|
||||||
|
// drives the model's forward pass. The Send/Sync impl is necessary
|
||||||
|
// because candle's CustomOp1 trait bounds require it; the correctness
|
||||||
|
// invariant is enforced at the call site, not the type level.
|
||||||
|
unsafe impl Send for AllReduce {}
|
||||||
|
unsafe impl Sync for AllReduce {}
|
||||||
|
|
||||||
|
impl AllReduce {
|
||||||
|
pub fn new(comm: Arc<Comm>) -> Self {
|
||||||
|
Self { comm }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn comm(&self) -> &Arc<Comm> {
|
||||||
|
&self.comm
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CustomOp1 for AllReduce {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
"neuron.tp.all_reduce"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
|
||||||
|
candle_core::bail!("AllReduce custom-op invoked on CPU storage; TP requires CUDA")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cuda_fwd(&self, s: &CudaStorage, l: &Layout) -> Result<(CudaStorage, Shape)> {
|
||||||
|
// Reject non-contiguous inputs explicitly — copying them
|
||||||
|
// server-side would mask shape bugs (a TP layer feeding a
|
||||||
|
// strided activation into all_reduce is almost certainly a
|
||||||
|
// model construction error).
|
||||||
|
fn require_contiguous<T: cudarc::driver::DeviceRepr>(
|
||||||
|
slice: &cudarc::driver::CudaSlice<T>,
|
||||||
|
l: &Layout,
|
||||||
|
) -> Result<()> {
|
||||||
|
match l.contiguous_offsets() {
|
||||||
|
Some((0, n)) if n == slice.len() => Ok(()),
|
||||||
|
_ => candle_core::bail!(
|
||||||
|
"AllReduce input is non-contiguous: layout={:?}, slice_len={}",
|
||||||
|
l,
|
||||||
|
slice.len()
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let elem_count = l.shape().elem_count();
|
||||||
|
let dev = s.device().clone();
|
||||||
|
|
||||||
|
let out = match s.dtype() {
|
||||||
|
DType::BF16 => {
|
||||||
|
let src = s.as_cuda_slice::<bf16>()?;
|
||||||
|
require_contiguous(src, l)?;
|
||||||
|
let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }?;
|
||||||
|
self.comm
|
||||||
|
.all_reduce(src, &mut dst, &ReduceOp::Sum)
|
||||||
|
.map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce bf16: {e:?}")))?;
|
||||||
|
CudaStorage::wrap_cuda_slice(dst, dev)
|
||||||
|
}
|
||||||
|
DType::F16 => {
|
||||||
|
let src = s.as_cuda_slice::<f16>()?;
|
||||||
|
require_contiguous(src, l)?;
|
||||||
|
let mut dst = unsafe { dev.alloc::<f16>(elem_count) }?;
|
||||||
|
self.comm
|
||||||
|
.all_reduce(src, &mut dst, &ReduceOp::Sum)
|
||||||
|
.map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce f16: {e:?}")))?;
|
||||||
|
CudaStorage::wrap_cuda_slice(dst, dev)
|
||||||
|
}
|
||||||
|
DType::F32 => {
|
||||||
|
let src = s.as_cuda_slice::<f32>()?;
|
||||||
|
require_contiguous(src, l)?;
|
||||||
|
let mut dst = unsafe { dev.alloc::<f32>(elem_count) }?;
|
||||||
|
self.comm
|
||||||
|
.all_reduce(src, &mut dst, &ReduceOp::Sum)
|
||||||
|
.map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce f32: {e:?}")))?;
|
||||||
|
CudaStorage::wrap_cuda_slice(dst, dev)
|
||||||
|
}
|
||||||
|
dtype => candle_core::bail!(
|
||||||
|
"AllReduce: unsupported dtype {dtype:?}; TP path expects bf16/f16/f32"
|
||||||
|
),
|
||||||
|
};
|
||||||
|
Ok((out, l.shape().clone()))
|
||||||
|
}
|
||||||
|
}
|
||||||
213
crates/neuron/src/harness/tp/fused_load.rs
Normal file
213
crates/neuron/src/harness/tp/fused_load.rs
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
//! Direct safetensors readers for fused-region weight tensors.
|
||||||
|
//!
|
||||||
|
//! Qwen3-Next's `in_proj_qkv` and `conv1d` weights are *fused* —
|
||||||
|
//! three regions stored sequentially along dim 0 (`[key_q, key_k,
|
||||||
|
//! value]`). The per-rank shard for each region has unequal size
|
||||||
|
//! (`key_dim/ws` vs `value_dim/ws`), so candle's `ShardedSafeTensors`
|
||||||
|
//! built-in `Shard { dim, rank, world_size }` (uniform split) doesn't
|
||||||
|
//! map to the right slices.
|
||||||
|
//!
|
||||||
|
//! The previous approach loaded the full fused tensor onto the device,
|
||||||
|
//! `narrow`ed the three regions, and `Tensor::cat(...).contiguous()`'d
|
||||||
|
//! the per-rank slice. That left ~100 MB of transient device memory
|
||||||
|
//! per linear-attention layer — 48 layers × 100 MB = ~4.8 GB of
|
||||||
|
//! allocator pressure during load, enough to trigger fragmentation
|
||||||
|
//! OOM on tight-VRAM consumer GPUs.
|
||||||
|
//!
|
||||||
|
//! This module reads the three per-rank byte ranges *directly from
|
||||||
|
//! the safetensors mmap* (host-side), concatenates them into a single
|
||||||
|
//! contiguous byte buffer, and uploads as one device allocation. No
|
||||||
|
//! full-tensor device materialisation.
|
||||||
|
|
||||||
|
use anyhow::{Context, Result, bail};
|
||||||
|
use candle_core::safetensors::MmapedSafetensors;
|
||||||
|
use candle_core::{DType, Device, Tensor};
|
||||||
|
|
||||||
|
/// Read a 2D fused-QKV tensor `[conv_dim, hidden_size]` and return
|
||||||
|
/// this rank's per-region slice as a `[per_rank_conv_dim, hidden_size]`
|
||||||
|
/// device tensor.
|
||||||
|
///
|
||||||
|
/// `tensor_name` must be the fully-qualified safetensors key (e.g.
|
||||||
|
/// `"model.language_model.layers.5.linear_attn.in_proj_qkv.weight"`).
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub fn load_fused_qkv_2d(
|
||||||
|
mmap: &MmapedSafetensors,
|
||||||
|
tensor_name: &str,
|
||||||
|
hidden_size: usize,
|
||||||
|
key_dim: usize,
|
||||||
|
value_dim: usize,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
target_dtype: DType,
|
||||||
|
device: &Device,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let ws = world_size as usize;
|
||||||
|
let r = rank as usize;
|
||||||
|
if !key_dim.is_multiple_of(ws) || !value_dim.is_multiple_of(ws) {
|
||||||
|
bail!(
|
||||||
|
"fused qkv shard: key_dim ({key_dim}) and value_dim ({value_dim}) \
|
||||||
|
must each be divisible by world_size ({ws})"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let per_rank_key = key_dim / ws;
|
||||||
|
let per_rank_value = value_dim / ws;
|
||||||
|
let per_rank_conv_dim = per_rank_key * 2 + per_rank_value;
|
||||||
|
|
||||||
|
let view = mmap
|
||||||
|
.get(tensor_name)
|
||||||
|
.with_context(|| format!("mmap.get('{tensor_name}') for fused qkv 2D"))?;
|
||||||
|
let view_dtype: DType = view
|
||||||
|
.dtype()
|
||||||
|
.try_into()
|
||||||
|
.with_context(|| format!("safetensors dtype unsupported for '{tensor_name}'"))?;
|
||||||
|
|
||||||
|
let shape = view.shape();
|
||||||
|
if shape.len() != 2 {
|
||||||
|
bail!(
|
||||||
|
"fused qkv tensor '{tensor_name}' has shape {shape:?}, expected 2D \
|
||||||
|
[conv_dim, hidden_size]"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let conv_dim = key_dim * 2 + value_dim;
|
||||||
|
if shape[0] != conv_dim || shape[1] != hidden_size {
|
||||||
|
bail!(
|
||||||
|
"fused qkv tensor '{tensor_name}' shape {shape:?} \
|
||||||
|
doesn't match expected [{conv_dim}, {hidden_size}]"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let q_bytes = slice_dim0_bytes(&view, r * per_rank_key, per_rank_key, tensor_name, "q")?;
|
||||||
|
let k_bytes = slice_dim0_bytes(
|
||||||
|
&view,
|
||||||
|
key_dim + r * per_rank_key,
|
||||||
|
per_rank_key,
|
||||||
|
tensor_name,
|
||||||
|
"k",
|
||||||
|
)?;
|
||||||
|
let v_bytes = slice_dim0_bytes(
|
||||||
|
&view,
|
||||||
|
2 * key_dim + r * per_rank_value,
|
||||||
|
per_rank_value,
|
||||||
|
tensor_name,
|
||||||
|
"v",
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let mut bytes = Vec::with_capacity(q_bytes.len() + k_bytes.len() + v_bytes.len());
|
||||||
|
bytes.extend_from_slice(&q_bytes);
|
||||||
|
bytes.extend_from_slice(&k_bytes);
|
||||||
|
bytes.extend_from_slice(&v_bytes);
|
||||||
|
|
||||||
|
let tensor = Tensor::from_raw_buffer(
|
||||||
|
&bytes,
|
||||||
|
view_dtype,
|
||||||
|
&[per_rank_conv_dim, hidden_size],
|
||||||
|
device,
|
||||||
|
)
|
||||||
|
.with_context(|| format!("Tensor::from_raw_buffer for per-rank fused qkv '{tensor_name}'"))?;
|
||||||
|
tensor
|
||||||
|
.to_dtype(target_dtype)
|
||||||
|
.with_context(|| format!("cast '{tensor_name}' to {target_dtype:?}"))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Read a 3D fused-QKV tensor `[conv_dim, 1, kernel_size]` (the
|
||||||
|
/// depthwise conv1d weight) and return this rank's per-region slice
|
||||||
|
/// as a `[per_rank_conv_dim, 1, kernel_size]` device tensor.
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub fn load_fused_qkv_3d(
|
||||||
|
mmap: &MmapedSafetensors,
|
||||||
|
tensor_name: &str,
|
||||||
|
mid: usize,
|
||||||
|
kernel_size: usize,
|
||||||
|
key_dim: usize,
|
||||||
|
value_dim: usize,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
target_dtype: DType,
|
||||||
|
device: &Device,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let ws = world_size as usize;
|
||||||
|
let r = rank as usize;
|
||||||
|
if !key_dim.is_multiple_of(ws) || !value_dim.is_multiple_of(ws) {
|
||||||
|
bail!(
|
||||||
|
"fused conv shard: key_dim ({key_dim}) and value_dim ({value_dim}) \
|
||||||
|
must each be divisible by world_size ({ws})"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let per_rank_key = key_dim / ws;
|
||||||
|
let per_rank_value = value_dim / ws;
|
||||||
|
let per_rank_conv_dim = per_rank_key * 2 + per_rank_value;
|
||||||
|
|
||||||
|
let view = mmap
|
||||||
|
.get(tensor_name)
|
||||||
|
.with_context(|| format!("mmap.get('{tensor_name}') for fused qkv 3D"))?;
|
||||||
|
let view_dtype: DType = view
|
||||||
|
.dtype()
|
||||||
|
.try_into()
|
||||||
|
.with_context(|| format!("safetensors dtype unsupported for '{tensor_name}'"))?;
|
||||||
|
|
||||||
|
let shape = view.shape();
|
||||||
|
if shape.len() != 3 {
|
||||||
|
bail!(
|
||||||
|
"fused conv tensor '{tensor_name}' has shape {shape:?}, expected 3D \
|
||||||
|
[conv_dim, mid, kernel_size]"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let conv_dim = key_dim * 2 + value_dim;
|
||||||
|
if shape[0] != conv_dim || shape[1] != mid || shape[2] != kernel_size {
|
||||||
|
bail!(
|
||||||
|
"fused conv tensor '{tensor_name}' shape {shape:?} \
|
||||||
|
doesn't match expected [{conv_dim}, {mid}, {kernel_size}]"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let q_bytes = slice_dim0_bytes(&view, r * per_rank_key, per_rank_key, tensor_name, "q")?;
|
||||||
|
let k_bytes = slice_dim0_bytes(
|
||||||
|
&view,
|
||||||
|
key_dim + r * per_rank_key,
|
||||||
|
per_rank_key,
|
||||||
|
tensor_name,
|
||||||
|
"k",
|
||||||
|
)?;
|
||||||
|
let v_bytes = slice_dim0_bytes(
|
||||||
|
&view,
|
||||||
|
2 * key_dim + r * per_rank_value,
|
||||||
|
per_rank_value,
|
||||||
|
tensor_name,
|
||||||
|
"v",
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let mut bytes = Vec::with_capacity(q_bytes.len() + k_bytes.len() + v_bytes.len());
|
||||||
|
bytes.extend_from_slice(&q_bytes);
|
||||||
|
bytes.extend_from_slice(&k_bytes);
|
||||||
|
bytes.extend_from_slice(&v_bytes);
|
||||||
|
|
||||||
|
let tensor = Tensor::from_raw_buffer(
|
||||||
|
&bytes,
|
||||||
|
view_dtype,
|
||||||
|
&[per_rank_conv_dim, mid, kernel_size],
|
||||||
|
device,
|
||||||
|
)
|
||||||
|
.with_context(|| format!("Tensor::from_raw_buffer for per-rank fused conv '{tensor_name}'"))?;
|
||||||
|
tensor
|
||||||
|
.to_dtype(target_dtype)
|
||||||
|
.with_context(|| format!("cast '{tensor_name}' to {target_dtype:?}"))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Read `len` consecutive rows along dim 0 starting at `start` from
|
||||||
|
/// the safetensors view, returning the raw bytes. Wraps the same
|
||||||
|
/// `view.slice(start..stop)` machinery that candle's
|
||||||
|
/// `ShardedSafeTensors::get` uses internally.
|
||||||
|
fn slice_dim0_bytes(
|
||||||
|
view: &safetensors::tensor::TensorView<'_>,
|
||||||
|
start: usize,
|
||||||
|
len: usize,
|
||||||
|
tensor_name: &str,
|
||||||
|
region: &str,
|
||||||
|
) -> Result<Vec<u8>> {
|
||||||
|
use safetensors::slice::IndexOp;
|
||||||
|
let stop = start + len;
|
||||||
|
let iter = view.slice(start..stop).map_err(|e| {
|
||||||
|
anyhow::anyhow!("slice '{tensor_name}' region {region} ({start}..{stop}): {e:?}")
|
||||||
|
})?;
|
||||||
|
Ok(iter.into_iter().flatten().copied().collect())
|
||||||
|
}
|
||||||
795
crates/neuron/src/harness/tp/mod.rs
Normal file
795
crates/neuron/src/harness/tp/mod.rs
Normal file
@@ -0,0 +1,795 @@
|
|||||||
|
//! Tensor-parallel inference plumbing.
|
||||||
|
//!
|
||||||
|
//! The leader process (the neuron daemon proper) drives one
|
||||||
|
//! subprocess per non-zero NCCL rank — `tokio::process::Command` on
|
||||||
|
//! `/proc/self/exe --worker --rank N --tp-size N --cuda-device N` —
|
||||||
|
//! and talks to each over a newline-delimited JSON RPC channel on
|
||||||
|
//! the worker's stdin/stdout (see `rpc.rs`).
|
||||||
|
//!
|
||||||
|
//! Sub-staging:
|
||||||
|
//!
|
||||||
|
//! - **7a-i (this commit):** process lifecycle. `WorkerPool::spawn`
|
||||||
|
//! forks N workers; `ping` round-trips every worker to confirm
|
||||||
|
//! they're alive; `shutdown` cleanly drains and reaps. `Init` /
|
||||||
|
//! `NcclSanityCheck` are stubbed.
|
||||||
|
//! - **7a-ii:** real NCCL `Comm` setup via `Init`, sanity check via
|
||||||
|
//! `NcclSanityCheck`. CUDA-gated.
|
||||||
|
//! - **7b:** TP-aware Qwen3 inference dispatched through the pool.
|
||||||
|
//! - **7c:** crash detection, streaming SSE, graceful unload.
|
||||||
|
|
||||||
|
pub mod all_reduce;
|
||||||
|
pub mod fused_load;
|
||||||
|
pub mod nccl_state;
|
||||||
|
pub mod rpc;
|
||||||
|
pub mod tp_linear;
|
||||||
|
pub mod tp_qwen3;
|
||||||
|
pub mod tp_qwen3_5;
|
||||||
|
pub mod worker;
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
use std::process::Stdio;
|
||||||
|
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, Lines};
|
||||||
|
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
|
||||||
|
|
||||||
|
use rpc::{WorkerRequest, WorkerResponse};
|
||||||
|
|
||||||
|
/// Leader-side handle for any TP-loaded model. The pool's
|
||||||
|
/// `load_dense_shard` dispatches on `config.json#/model_type` to build
|
||||||
|
/// the right variant; downstream callers (the harness's
|
||||||
|
/// `chat_completion_tp` path, `generate_step`, `clear_kv_cache`,
|
||||||
|
/// `unload_model`) all hold this enum and let the variant dispatch
|
||||||
|
/// determine the concrete forward.
|
||||||
|
///
|
||||||
|
/// Variants gated on `cuda` because the underlying TP models hold
|
||||||
|
/// `Arc<cudarc::nccl::Comm>` references — irrelevant on CPU builds.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub enum TpLeaderModel {
|
||||||
|
Qwen3(tp_qwen3::TpQwen3ForCausalLM),
|
||||||
|
Qwen3_5(tp_qwen3_5::TpQwen3_5ForCausalLM),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
impl TpLeaderModel {
|
||||||
|
pub fn forward(
|
||||||
|
&mut self,
|
||||||
|
input: &candle_core::Tensor,
|
||||||
|
offset: usize,
|
||||||
|
) -> candle_core::Result<candle_core::Tensor> {
|
||||||
|
match self {
|
||||||
|
TpLeaderModel::Qwen3(m) => m.forward(input, offset),
|
||||||
|
TpLeaderModel::Qwen3_5(m) => m.forward(input, offset),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
match self {
|
||||||
|
TpLeaderModel::Qwen3(m) => m.clear_kv_cache(),
|
||||||
|
TpLeaderModel::Qwen3_5(m) => m.clear_kv_cache(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn device(&self) -> &candle_core::Device {
|
||||||
|
match self {
|
||||||
|
TpLeaderModel::Qwen3(m) => m.device(),
|
||||||
|
TpLeaderModel::Qwen3_5(m) => m.device(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// One worker subprocess plus its bidirectional stdio handles.
|
||||||
|
struct Worker {
|
||||||
|
rank: u32,
|
||||||
|
/// Captured so the leader can log "spawned rank N on device M" and
|
||||||
|
/// future stages can re-issue Init after a CUDA reset. Unused in
|
||||||
|
/// the Stage 7a-i RPC paths themselves.
|
||||||
|
#[allow(dead_code)]
|
||||||
|
cuda_device: u32,
|
||||||
|
child: Child,
|
||||||
|
stdin: ChildStdin,
|
||||||
|
stdout: Lines<BufReader<ChildStdout>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Worker {
|
||||||
|
/// Send a request and wait for the response. Used for sequenced
|
||||||
|
/// ops like `Ping` / `Shutdown` where the caller doesn't need to
|
||||||
|
/// overlap the worker's execution with the leader's.
|
||||||
|
async fn request(&mut self, req: &WorkerRequest) -> Result<WorkerResponse> {
|
||||||
|
self.send_only(req).await?;
|
||||||
|
self.recv_only().await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Write a request without awaiting its response. Pair with
|
||||||
|
/// `recv_only` from the caller when leader and worker need to do
|
||||||
|
/// work concurrently — e.g. during `Init`, where the leader
|
||||||
|
/// itself calls `Comm::from_rank` on rank 0 in parallel with the
|
||||||
|
/// workers, then collects `InitOk` after NCCL completes.
|
||||||
|
async fn send_only(&mut self, req: &WorkerRequest) -> Result<()> {
|
||||||
|
let mut line = serde_json::to_string(req).context("serialise WorkerRequest")?;
|
||||||
|
line.push('\n');
|
||||||
|
self.stdin
|
||||||
|
.write_all(line.as_bytes())
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("write request to rank {}", self.rank))?;
|
||||||
|
self.stdin
|
||||||
|
.flush()
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("flush stdin to rank {}", self.rank))?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn recv_only(&mut self) -> Result<WorkerResponse> {
|
||||||
|
let reply = self
|
||||||
|
.stdout
|
||||||
|
.next_line()
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("read reply from rank {}", self.rank))?
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("rank {} stdout closed before reply", self.rank))?;
|
||||||
|
serde_json::from_str(&reply)
|
||||||
|
.with_context(|| format!("parse reply from rank {}: {reply:?}", self.rank))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Drain one response from every worker, classifying each via the
|
||||||
|
/// supplied checker. Always reads from every worker — even if some
|
||||||
|
/// fail — so the next call's recv doesn't pick up stale responses
|
||||||
|
/// from this one (pipe-poisoning was the cause of the
|
||||||
|
/// "ClearKvCache: expected KvCacheCleared, got GenerateStepOk" class
|
||||||
|
/// of bugs).
|
||||||
|
///
|
||||||
|
/// Returns a vector of `rank N: detail` strings for any worker that
|
||||||
|
/// errored, expected-mismatched, or failed to respond. Caller decides
|
||||||
|
/// how to combine these with the leader's outcome.
|
||||||
|
async fn drain_workers(
|
||||||
|
workers: &mut [Worker],
|
||||||
|
mut check: impl FnMut(WorkerResponse) -> std::result::Result<(), String>,
|
||||||
|
) -> Vec<String> {
|
||||||
|
let mut errs = Vec::new();
|
||||||
|
for w in workers {
|
||||||
|
match w.recv_only().await {
|
||||||
|
Ok(resp) => {
|
||||||
|
if let Err(detail) = check(resp) {
|
||||||
|
errs.push(format!("rank {} {detail}", w.rank));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => errs.push(format!("rank {} recv: {e:#}", w.rank)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
errs
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Combine a leader's `Result<Result<T>>` (the typical
|
||||||
|
/// `spawn_blocking → JoinHandle<Result<T>>` shape) with the worker
|
||||||
|
/// drain results into a single `Result<T>`. Leader failures take
|
||||||
|
/// precedence in the error message but worker errors get appended so
|
||||||
|
/// the operator sees both halves.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
fn combine_leader_workers<T>(
|
||||||
|
leader: Result<Result<T>>,
|
||||||
|
worker_errors: Vec<String>,
|
||||||
|
op: &str,
|
||||||
|
) -> Result<T> {
|
||||||
|
match leader {
|
||||||
|
Ok(Ok(value)) => {
|
||||||
|
if worker_errors.is_empty() {
|
||||||
|
Ok(value)
|
||||||
|
} else {
|
||||||
|
anyhow::bail!(
|
||||||
|
"{op}: leader succeeded but workers failed: {}",
|
||||||
|
worker_errors.join("; ")
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(Err(e)) => {
|
||||||
|
if worker_errors.is_empty() {
|
||||||
|
Err(e.context(format!("{op}: leader forward failed")))
|
||||||
|
} else {
|
||||||
|
Err(e.context(format!(
|
||||||
|
"{op}: leader forward failed and workers also failed: {}",
|
||||||
|
worker_errors.join("; ")
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(panic_err) => {
|
||||||
|
if worker_errors.is_empty() {
|
||||||
|
Err(panic_err)
|
||||||
|
} else {
|
||||||
|
Err(panic_err.context(format!(
|
||||||
|
"{op}: leader task panicked and workers failed: {}",
|
||||||
|
worker_errors.join("; ")
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A live pool of worker subprocesses. Owns the `Child` handles so
|
||||||
|
/// dropping the pool kills the children; explicit `shutdown()` is
|
||||||
|
/// the graceful path.
|
||||||
|
pub struct WorkerPool {
|
||||||
|
world_size: u32,
|
||||||
|
workers: Vec<Worker>,
|
||||||
|
/// Path to the neuron binary used to launch workers.
|
||||||
|
#[allow(dead_code)]
|
||||||
|
exe: PathBuf,
|
||||||
|
/// The leader's per-device CUDA worker thread. Phase 3 moved the
|
||||||
|
/// leader's `NcclState` (rank-0 NCCL Comm) into this thread, so
|
||||||
|
/// every NCCL op (init, sanity, all_reduce inside forward) issues
|
||||||
|
/// from one OS thread for the daemon's lifetime. The handle is
|
||||||
|
/// also used by `load_dense_shard` to clone the leader's
|
||||||
|
/// `Arc<Comm>` for the row-parallel layers' AllReduce ops; in
|
||||||
|
/// Phase 4 the load itself moves onto the worker and that bridge
|
||||||
|
/// goes away.
|
||||||
|
pub(crate) leader_worker: std::sync::Arc<super::device_worker::DeviceWorkerHandle>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WorkerPool {
|
||||||
|
/// Spawn `world_size - 1` worker subprocesses. Rank 0 is the
|
||||||
|
/// leader (in-process) and is *not* spawned here — the leader
|
||||||
|
/// holds rank 0's NCCL Comm and shard in its own address space.
|
||||||
|
///
|
||||||
|
/// `binary` is the path to the neuron executable to run for each
|
||||||
|
/// worker (production passes `/proc/self/exe`; tests pass the
|
||||||
|
/// sibling-binary path from `env!("CARGO_BIN_EXE_neuron")`).
|
||||||
|
/// `cuda_devices` is one entry per rank including rank 0. Worker
|
||||||
|
/// `i` (rank `i`) gets `cuda_devices[i]` as its `--cuda-device`.
|
||||||
|
pub async fn spawn(
|
||||||
|
binary: &Path,
|
||||||
|
world_size: u32,
|
||||||
|
cuda_devices: &[u32],
|
||||||
|
leader_worker: std::sync::Arc<super::device_worker::DeviceWorkerHandle>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
if world_size < 2 {
|
||||||
|
anyhow::bail!(
|
||||||
|
"WorkerPool::spawn called with world_size={world_size}; \
|
||||||
|
use the single-process path for world_size < 2"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if cuda_devices.len() as u32 != world_size {
|
||||||
|
anyhow::bail!(
|
||||||
|
"expected {world_size} cuda_devices entries, got {}",
|
||||||
|
cuda_devices.len()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let exe = binary.to_path_buf();
|
||||||
|
|
||||||
|
let mut workers = Vec::with_capacity(world_size as usize - 1);
|
||||||
|
// Rank 0 stays in-process. Spawn ranks 1..world_size.
|
||||||
|
for rank in 1..world_size {
|
||||||
|
let cuda_device = cuda_devices[rank as usize];
|
||||||
|
let mut cmd = Command::new(&exe);
|
||||||
|
cmd.arg("--worker")
|
||||||
|
.arg("--rank")
|
||||||
|
.arg(rank.to_string())
|
||||||
|
.arg("--tp-size")
|
||||||
|
.arg(world_size.to_string())
|
||||||
|
.arg("--cuda-device")
|
||||||
|
.arg(cuda_device.to_string())
|
||||||
|
.stdin(Stdio::piped())
|
||||||
|
.stdout(Stdio::piped())
|
||||||
|
// Inherit stderr so worker tracing surfaces alongside
|
||||||
|
// the leader's journalctl stream.
|
||||||
|
.stderr(Stdio::inherit())
|
||||||
|
.kill_on_drop(true);
|
||||||
|
|
||||||
|
let mut child = cmd
|
||||||
|
.spawn()
|
||||||
|
.with_context(|| format!("spawn worker rank {rank}"))?;
|
||||||
|
let stdin = child
|
||||||
|
.stdin
|
||||||
|
.take()
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("rank {rank}: no stdin handle"))?;
|
||||||
|
let stdout = child
|
||||||
|
.stdout
|
||||||
|
.take()
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("rank {rank}: no stdout handle"))?;
|
||||||
|
let stdout = BufReader::new(stdout).lines();
|
||||||
|
|
||||||
|
workers.push(Worker {
|
||||||
|
rank,
|
||||||
|
cuda_device,
|
||||||
|
child,
|
||||||
|
stdin,
|
||||||
|
stdout,
|
||||||
|
});
|
||||||
|
tracing::info!(rank, cuda_device, "spawned tp worker");
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
world_size,
|
||||||
|
workers,
|
||||||
|
exe,
|
||||||
|
leader_worker,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Establish the NCCL communicator across the leader (rank 0) and
|
||||||
|
/// every worker subprocess. Rendezvous is via a freshly-generated
|
||||||
|
/// `Id` broadcast over the RPC stream; the actual handshake blocks
|
||||||
|
/// inside `Comm::from_rank` until all `world_size` ranks check in.
|
||||||
|
///
|
||||||
|
/// `leader_cuda_device` is the CUDA device the leader binds rank 0
|
||||||
|
/// to — typically the first entry of the `cuda_devices` slice
|
||||||
|
/// originally passed to `spawn()`.
|
||||||
|
///
|
||||||
|
/// On the non-cuda build this immediately fails because the leader
|
||||||
|
/// can't generate an `Id` without libnccl. The same call works in
|
||||||
|
/// the worker path (returning a no-cuda error response) so the
|
||||||
|
/// failure surface is uniform.
|
||||||
|
pub async fn init_nccl(&mut self, leader_cuda_device: u32) -> Result<()> {
|
||||||
|
let comm_id = nccl_state::generate_comm_id_hex()
|
||||||
|
.map_err(|m| anyhow::anyhow!("generate NCCL id: {m}"))?;
|
||||||
|
|
||||||
|
// 1. Write Init to every worker's stdin without awaiting the
|
||||||
|
// response. Workers will parse and call Comm::from_rank
|
||||||
|
// concurrently with the leader below.
|
||||||
|
for w in &mut self.workers {
|
||||||
|
let req = WorkerRequest::Init {
|
||||||
|
comm_id: comm_id.clone(),
|
||||||
|
};
|
||||||
|
w.send_only(&req).await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Leader rank 0 calls Comm::from_rank on its own device.
|
||||||
|
// Phase 3 moved this from spawn_blocking onto the leader's
|
||||||
|
// device worker thread (`Job::NcclInit`); the underlying
|
||||||
|
// `Comm` now lives on the same OS thread for its entire
|
||||||
|
// lifetime, including every later `Comm::all_reduce` issued
|
||||||
|
// by the row-parallel layers during forward.
|
||||||
|
//
|
||||||
|
// NCCL's init blocks until every rank has called in — the
|
||||||
|
// subprocess workers above and the leader's device worker
|
||||||
|
// here. The Job's reply unblocks when the leader's
|
||||||
|
// Comm::from_rank returns.
|
||||||
|
let leader_cfg = worker::WorkerConfig {
|
||||||
|
rank: 0,
|
||||||
|
world_size: self.world_size,
|
||||||
|
cuda_device: leader_cuda_device,
|
||||||
|
};
|
||||||
|
let leader_resp = self
|
||||||
|
.leader_worker
|
||||||
|
.nccl_init(leader_cfg, comm_id.clone())
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow::anyhow!("leader NCCL init via device worker: {e}"))?;
|
||||||
|
match leader_resp {
|
||||||
|
rpc::WorkerResponse::InitOk => {}
|
||||||
|
rpc::WorkerResponse::Error { kind, message } => {
|
||||||
|
anyhow::bail!("leader rank 0 init failed [{kind}]: {message}");
|
||||||
|
}
|
||||||
|
other => anyhow::bail!("leader rank 0 init: unexpected {other:?}"),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Read InitOk from each worker. By now every worker has
|
||||||
|
// completed its Comm::from_rank call (NCCL released them
|
||||||
|
// when the leader joined the handshake) and is writing its
|
||||||
|
// response.
|
||||||
|
for w in &mut self.workers {
|
||||||
|
let resp = w.recv_only().await?;
|
||||||
|
match &resp {
|
||||||
|
rpc::WorkerResponse::InitOk => {}
|
||||||
|
rpc::WorkerResponse::Error { kind, message } => {
|
||||||
|
anyhow::bail!("worker rank {} init failed [{kind}]: {message}", w.rank);
|
||||||
|
}
|
||||||
|
other => anyhow::bail!(
|
||||||
|
"worker rank {} init: expected InitOk, got {other:?}",
|
||||||
|
w.rank
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tracing::info!(
|
||||||
|
world_size = self.world_size,
|
||||||
|
"NCCL communicator established across all ranks"
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate the NCCL communicator: every rank `all_reduce`s a
|
||||||
|
/// sentinel `1u32` with `ReduceOp::Sum`; the expected total is
|
||||||
|
/// `world_size`. Confirms the handshake is live, not just
|
||||||
|
/// configured.
|
||||||
|
///
|
||||||
|
/// Must be called after `init_nccl()`; before that the leader has
|
||||||
|
/// no Comm and the workers reply with `nccl_not_initialised`.
|
||||||
|
pub async fn nccl_sanity_check(&mut self) -> Result<()> {
|
||||||
|
// 1. Trigger the all_reduce on every worker (write-only).
|
||||||
|
for w in &mut self.workers {
|
||||||
|
w.send_only(&WorkerRequest::NcclSanityCheck).await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Leader's own all_reduce, on its device worker thread.
|
||||||
|
// NCCL operations block until every rank participates;
|
||||||
|
// Job::NcclSanity returns once the leader's side completes
|
||||||
|
// (which happens when every subprocess worker reaches its
|
||||||
|
// all_reduce call too).
|
||||||
|
let leader_resp = self
|
||||||
|
.leader_worker
|
||||||
|
.nccl_sanity()
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow::anyhow!("leader NCCL sanity via device worker: {e}"))?;
|
||||||
|
|
||||||
|
let expected = self.world_size;
|
||||||
|
let leader_sum = match leader_resp {
|
||||||
|
rpc::WorkerResponse::NcclSanityResult { observed_sum } => observed_sum,
|
||||||
|
rpc::WorkerResponse::Error { kind, message } => {
|
||||||
|
anyhow::bail!("leader rank 0 sanity failed [{kind}]: {message}");
|
||||||
|
}
|
||||||
|
other => anyhow::bail!("leader rank 0 sanity: unexpected {other:?}"),
|
||||||
|
};
|
||||||
|
if leader_sum != expected {
|
||||||
|
anyhow::bail!("leader observed_sum={leader_sum}, expected {expected}");
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Read sanity result from each worker. All must match
|
||||||
|
// world_size — anything else means the collective didn't
|
||||||
|
// complete consistently across ranks.
|
||||||
|
for w in &mut self.workers {
|
||||||
|
let resp = w.recv_only().await?;
|
||||||
|
match resp {
|
||||||
|
rpc::WorkerResponse::NcclSanityResult { observed_sum }
|
||||||
|
if observed_sum == expected => {}
|
||||||
|
rpc::WorkerResponse::NcclSanityResult { observed_sum } => {
|
||||||
|
anyhow::bail!(
|
||||||
|
"worker rank {} observed_sum={observed_sum}, expected {expected}",
|
||||||
|
w.rank
|
||||||
|
);
|
||||||
|
}
|
||||||
|
rpc::WorkerResponse::Error { kind, message } => {
|
||||||
|
anyhow::bail!("worker rank {} sanity failed [{kind}]: {message}", w.rank);
|
||||||
|
}
|
||||||
|
other => anyhow::bail!("worker rank {} sanity: unexpected {other:?}", w.rank),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tracing::info!(
|
||||||
|
world_size = expected,
|
||||||
|
"NCCL sanity check OK across all ranks"
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Ping every worker and return their Pong payloads in rank order.
|
||||||
|
/// Useful right after `spawn` to confirm the lifecycle plumbing is
|
||||||
|
/// intact before kicking off any heavier work.
|
||||||
|
pub async fn ping_all(&mut self) -> Result<Vec<WorkerResponse>> {
|
||||||
|
let mut out = Vec::with_capacity(self.workers.len());
|
||||||
|
for w in &mut self.workers {
|
||||||
|
let resp = w.request(&WorkerRequest::Ping).await?;
|
||||||
|
match &resp {
|
||||||
|
WorkerResponse::Pong { rank, .. } if *rank == w.rank => {}
|
||||||
|
WorkerResponse::Pong { rank, .. } => {
|
||||||
|
anyhow::bail!("rank mismatch: expected {}, got {rank}", w.rank);
|
||||||
|
}
|
||||||
|
other => anyhow::bail!("expected Pong from rank {}, got {other:?}", w.rank),
|
||||||
|
}
|
||||||
|
out.push(resp);
|
||||||
|
}
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load this rank's shard of a dense Qwen3 model on every rank.
|
||||||
|
///
|
||||||
|
/// The leader builds rank 0's `TpQwen3ForCausalLM` directly into
|
||||||
|
/// the returned `Arc<Mutex<_>>` — workers build their rank-local
|
||||||
|
/// shards in their own address spaces and confirm via
|
||||||
|
/// `LoadDenseShardOk`. All ranks see the same `safetensors_paths`;
|
||||||
|
/// `ShardedVarBuilder` slices each tensor by rank at materialisation
|
||||||
|
/// time, so the per-rank VRAM footprint is roughly `1/world_size`
|
||||||
|
/// of the full model (plus the replicated embedding/norm/lm_head).
|
||||||
|
///
|
||||||
|
/// `leader_device` is the candle `Device` the leader's shard lives
|
||||||
|
/// on — typically `Device::new_cuda(leader_cuda_device)` matching
|
||||||
|
/// the same index passed to `init_nccl`. `dtype` is the on-device
|
||||||
|
/// element type; bf16 is the canonical Qwen3 distribution dtype.
|
||||||
|
///
|
||||||
|
/// `init_nccl` must have completed first. Bails if the leader's
|
||||||
|
/// NCCL comm isn't set up yet.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub async fn load_dense_shard(
|
||||||
|
&mut self,
|
||||||
|
model_id: &str,
|
||||||
|
config_json: &str,
|
||||||
|
safetensors_paths: &[std::path::PathBuf],
|
||||||
|
_leader_device: &candle_core::Device,
|
||||||
|
dtype: candle_core::DType,
|
||||||
|
quant: Option<String>,
|
||||||
|
) -> Result<super::device_worker::TpHandle> {
|
||||||
|
let world_size = self.world_size;
|
||||||
|
let safetensors_str: Vec<String> = safetensors_paths
|
||||||
|
.iter()
|
||||||
|
.map(|p| p.to_string_lossy().into_owned())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// 1. Fan out the LoadDenseShard request to every subprocess
|
||||||
|
// worker without awaiting their replies — they'll build
|
||||||
|
// their shards in parallel with the leader below.
|
||||||
|
for w in &mut self.workers {
|
||||||
|
w.send_only(&WorkerRequest::LoadDenseShard {
|
||||||
|
model_id: model_id.to_string(),
|
||||||
|
config_json: config_json.to_string(),
|
||||||
|
safetensors_paths: safetensors_str.clone(),
|
||||||
|
quant: quant.clone(),
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Build rank 0's shard on the leader's device worker
|
||||||
|
// thread. Phase 4 moved the load itself onto the worker —
|
||||||
|
// the dispatch handler reads `state.nccl.comm()` directly
|
||||||
|
// so the leader's `Arc<Comm>` clones embedded in the
|
||||||
|
// row-parallel layers are constructed and used on the same
|
||||||
|
// OS thread for the model's entire lifetime. No
|
||||||
|
// spawn_blocking, no SendComm bridge.
|
||||||
|
let handle = self
|
||||||
|
.leader_worker
|
||||||
|
.tp_load_shard(
|
||||||
|
model_id.to_string(),
|
||||||
|
config_json.to_string(),
|
||||||
|
safetensors_paths.to_vec(),
|
||||||
|
dtype,
|
||||||
|
quant.clone(),
|
||||||
|
world_size,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow::anyhow!("leader TP shard load via device worker: {e}"))?;
|
||||||
|
|
||||||
|
// 3. Collect worker confirmations. Anything other than
|
||||||
|
// LoadDenseShardOk aborts the whole load — the leader's
|
||||||
|
// already-inserted shard would leak in the worker slab
|
||||||
|
// until the daemon restarts; an explicit DropTp would be
|
||||||
|
// cleaner but the failure here is rare and the operator's
|
||||||
|
// next step is to restart anyway.
|
||||||
|
for w in &mut self.workers {
|
||||||
|
let resp = w.recv_only().await?;
|
||||||
|
match resp {
|
||||||
|
WorkerResponse::LoadDenseShardOk => {}
|
||||||
|
WorkerResponse::Error { kind, message } => {
|
||||||
|
anyhow::bail!("worker rank {} LoadDenseShard [{kind}]: {message}", w.rank)
|
||||||
|
}
|
||||||
|
other => anyhow::bail!(
|
||||||
|
"worker rank {} LoadDenseShard: expected LoadDenseShardOk, got {other:?}",
|
||||||
|
w.rank
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(handle)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run one forward step across every rank. The leader's forward
|
||||||
|
/// runs on the device worker thread via `Job::TpForwardLogits` and
|
||||||
|
/// returns CPU-side `[vocab]` logits as `Vec<f32>`; the async
|
||||||
|
/// caller wraps them in a CPU tensor for `apply_repeat_penalty` +
|
||||||
|
/// sampling without holding a device-resident tensor on a tokio
|
||||||
|
/// thread.
|
||||||
|
///
|
||||||
|
/// Subprocess workers run their own forwards in parallel (the
|
||||||
|
/// AllReduce CustomOps inside row-parallel layers are what let
|
||||||
|
/// the leader's collective complete) and reply with
|
||||||
|
/// `GenerateStepOk` over the RPC stream — they do not ship logits.
|
||||||
|
///
|
||||||
|
/// `tokens` is the input for this step (prompt for prefill, the
|
||||||
|
/// previously-sampled token for decode). `offset` is the KV-cache
|
||||||
|
/// position before this step.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub async fn generate_step(
|
||||||
|
&mut self,
|
||||||
|
model_id: &str,
|
||||||
|
leader_handle: super::device_worker::TpHandle,
|
||||||
|
tokens: Vec<u32>,
|
||||||
|
offset: usize,
|
||||||
|
) -> Result<Vec<f32>> {
|
||||||
|
let step_start = std::time::Instant::now();
|
||||||
|
let tokens_len = tokens.len();
|
||||||
|
tracing::debug!(
|
||||||
|
model = %model_id,
|
||||||
|
tokens = tokens_len,
|
||||||
|
offset,
|
||||||
|
"WorkerPool::generate_step: fan-out"
|
||||||
|
);
|
||||||
|
// 1. Fan-out to subprocess workers.
|
||||||
|
for w in &mut self.workers {
|
||||||
|
w.send_only(&WorkerRequest::GenerateStep {
|
||||||
|
model_id: model_id.to_string(),
|
||||||
|
tokens: tokens.clone(),
|
||||||
|
offset,
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Leader's forward on its device worker thread. The
|
||||||
|
// AllReduce CustomOps inside the row-parallel layers block
|
||||||
|
// until every subprocess worker's forward issues the
|
||||||
|
// matching collective. Returning CPU-side `Vec<f32>` keeps
|
||||||
|
// the device tensor from escaping the worker thread —
|
||||||
|
// that's the invariant the whole refactor exists to
|
||||||
|
// preserve.
|
||||||
|
let leader_start = std::time::Instant::now();
|
||||||
|
let leader_result = self
|
||||||
|
.leader_worker
|
||||||
|
.tp_forward_logits(leader_handle, tokens, offset)
|
||||||
|
.await;
|
||||||
|
let leader_ok = leader_result.is_ok();
|
||||||
|
let leader_ms = leader_start.elapsed().as_millis();
|
||||||
|
// Surface the leader's own error at WARN before draining
|
||||||
|
// workers so the operator can correlate it with whatever the
|
||||||
|
// subprocess workers logged. Previously this was silently
|
||||||
|
// coerced to a bool.
|
||||||
|
if !leader_ok {
|
||||||
|
let detail = leader_result
|
||||||
|
.as_ref()
|
||||||
|
.err()
|
||||||
|
.map(|e| format!("{e:#}"))
|
||||||
|
.unwrap_or_default();
|
||||||
|
tracing::warn!(
|
||||||
|
model = %model_id,
|
||||||
|
tokens = tokens_len,
|
||||||
|
offset,
|
||||||
|
leader_ms,
|
||||||
|
error = %detail,
|
||||||
|
"WorkerPool::generate_step: leader forward failed"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
tracing::debug!(
|
||||||
|
model = %model_id,
|
||||||
|
tokens = tokens_len,
|
||||||
|
leader_ms,
|
||||||
|
leader_ok,
|
||||||
|
"WorkerPool::generate_step: leader forward returned"
|
||||||
|
);
|
||||||
|
|
||||||
|
// 3. ALWAYS drain worker responses, regardless of whether the
|
||||||
|
// leader succeeded. Skipping this on the leader's error
|
||||||
|
// path leaves stale GenerateStepOk replies in the worker
|
||||||
|
// pipes that poison the NEXT request's recv (was seeing
|
||||||
|
// "ClearKvCache: expected KvCacheCleared, got
|
||||||
|
// GenerateStepOk" the call after any forward-time failure).
|
||||||
|
let drain_start = std::time::Instant::now();
|
||||||
|
let worker_errors = drain_workers(&mut self.workers, |r| match r {
|
||||||
|
WorkerResponse::GenerateStepOk => Ok(()),
|
||||||
|
WorkerResponse::Error { kind, message } => Err(format!("[{kind}]: {message}")),
|
||||||
|
other => Err(format!("expected GenerateStepOk, got {other:?}")),
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
tracing::debug!(
|
||||||
|
model = %model_id,
|
||||||
|
drain_ms = drain_start.elapsed().as_millis(),
|
||||||
|
errors = worker_errors.len(),
|
||||||
|
total_ms = step_start.elapsed().as_millis(),
|
||||||
|
"WorkerPool::generate_step: workers drained"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Combine the leader's Result + the workers' string-error
|
||||||
|
// list. Phase 3 inlines this because the upstream
|
||||||
|
// `combine_leader_workers` expects the spawn_blocking-shaped
|
||||||
|
// `Result<Result<T>>`; the new device-worker path produces a
|
||||||
|
// single `Result<T, WorkerError>` instead.
|
||||||
|
match leader_result {
|
||||||
|
Ok(values) => {
|
||||||
|
if worker_errors.is_empty() {
|
||||||
|
Ok(values)
|
||||||
|
} else {
|
||||||
|
anyhow::bail!(
|
||||||
|
"GenerateStep: leader succeeded but workers failed: {}",
|
||||||
|
worker_errors.join("; ")
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
if worker_errors.is_empty() {
|
||||||
|
Err(anyhow::Error::new(e).context("GenerateStep: leader forward failed"))
|
||||||
|
} else {
|
||||||
|
Err(anyhow::Error::new(e).context(format!(
|
||||||
|
"GenerateStep: leader forward failed and workers also failed: {}",
|
||||||
|
worker_errors.join("; ")
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Reset the KV cache for `model_id` on every rank. Called at the
|
||||||
|
/// start of every inference so a fresh request doesn't attend over
|
||||||
|
/// the previous one's tokens.
|
||||||
|
pub async fn clear_kv_cache(
|
||||||
|
&mut self,
|
||||||
|
model_id: &str,
|
||||||
|
#[cfg(feature = "cuda")] leader_handle: super::device_worker::TpHandle,
|
||||||
|
) -> Result<()> {
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
tracing::debug!(model = %model_id, "WorkerPool::clear_kv_cache: fan-out");
|
||||||
|
for w in &mut self.workers {
|
||||||
|
w.send_only(&WorkerRequest::ClearKvCache {
|
||||||
|
model_id: model_id.to_string(),
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
{
|
||||||
|
// Leader-side clear on the device worker thread —
|
||||||
|
// `TpLeaderModel::clear_kv_cache` is infallible but still
|
||||||
|
// routes through Job::TpClearKv so the cache reset runs
|
||||||
|
// on the same thread that owns the model's CUDA tensors.
|
||||||
|
if let Err(e) = self.leader_worker.tp_clear_kv(leader_handle).await {
|
||||||
|
anyhow::bail!("leader TP clear_kv_cache via device worker: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Drain workers — same rationale as `generate_step`. The
|
||||||
|
// leader's clear_kv_cache is now async-via-channel but still
|
||||||
|
// returns before the drain so the workers' KvCacheCleared
|
||||||
|
// replies are processed in order.
|
||||||
|
let worker_errors = drain_workers(&mut self.workers, |r| match r {
|
||||||
|
WorkerResponse::KvCacheCleared => Ok(()),
|
||||||
|
WorkerResponse::Error { kind, message } => Err(format!("[{kind}]: {message}")),
|
||||||
|
other => Err(format!("expected KvCacheCleared, got {other:?}")),
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
tracing::debug!(
|
||||||
|
model = %model_id,
|
||||||
|
elapsed_ms = start.elapsed().as_millis(),
|
||||||
|
errors = worker_errors.len(),
|
||||||
|
"WorkerPool::clear_kv_cache: workers drained"
|
||||||
|
);
|
||||||
|
if !worker_errors.is_empty() {
|
||||||
|
anyhow::bail!("ClearKvCache: {}", worker_errors.join("; "));
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Drop this model's shards on every rank. The leader's shard is
|
||||||
|
/// expected to have been dropped by the caller (its `Arc` was held
|
||||||
|
/// in the TpLoadedModel and goes away when that's removed).
|
||||||
|
pub async fn unload_model(&mut self, model_id: &str) -> Result<()> {
|
||||||
|
for w in &mut self.workers {
|
||||||
|
w.send_only(&WorkerRequest::UnloadModel {
|
||||||
|
model_id: model_id.to_string(),
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
for w in &mut self.workers {
|
||||||
|
let resp = w.recv_only().await?;
|
||||||
|
match resp {
|
||||||
|
WorkerResponse::Unloaded => {}
|
||||||
|
WorkerResponse::Error { kind, message } => {
|
||||||
|
anyhow::bail!("worker rank {} UnloadModel [{kind}]: {message}", w.rank)
|
||||||
|
}
|
||||||
|
other => anyhow::bail!(
|
||||||
|
"worker rank {} UnloadModel: expected Unloaded, got {other:?}",
|
||||||
|
w.rank
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send `Shutdown` to every worker, await each `Bye`, and reap the
|
||||||
|
/// children. Best-effort — individual worker failures are logged
|
||||||
|
/// but don't abort the rest of the sweep.
|
||||||
|
pub async fn shutdown(mut self) -> Result<()> {
|
||||||
|
for w in &mut self.workers {
|
||||||
|
match w.request(&WorkerRequest::Shutdown).await {
|
||||||
|
Ok(WorkerResponse::Bye) => {}
|
||||||
|
Ok(other) => tracing::warn!(
|
||||||
|
rank = w.rank,
|
||||||
|
response = ?other,
|
||||||
|
"expected Bye on shutdown"
|
||||||
|
),
|
||||||
|
Err(e) => tracing::warn!(rank = w.rank, error = %e, "shutdown request failed"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for w in &mut self.workers {
|
||||||
|
match w.child.wait().await {
|
||||||
|
Ok(status) => tracing::info!(rank = w.rank, %status, "worker exited"),
|
||||||
|
Err(e) => tracing::warn!(rank = w.rank, error = %e, "wait on worker failed"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn world_size(&self) -> u32 {
|
||||||
|
self.world_size
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn binary_path(&self) -> &PathBuf {
|
||||||
|
&self.exe
|
||||||
|
}
|
||||||
|
}
|
||||||
293
crates/neuron/src/harness/tp/nccl_state.rs
Normal file
293
crates/neuron/src/harness/tp/nccl_state.rs
Normal file
@@ -0,0 +1,293 @@
|
|||||||
|
//! NCCL state held by both the worker process and the leader's pool.
|
||||||
|
//!
|
||||||
|
//! Split into its own module so the worker (`tp/worker.rs`) and the
|
||||||
|
//! leader (`tp/mod.rs`) share the same hex-encoding/decoding code and
|
||||||
|
//! the same shape of `Option<Comm>` state machine.
|
||||||
|
//!
|
||||||
|
//! When the `cuda` feature is off, `NcclState` is a zero-sized
|
||||||
|
//! placeholder that returns `Error{kind="cuda_feature_not_enabled"}`
|
||||||
|
//! from every operation. When it's on, the same struct holds the
|
||||||
|
//! actual `cudarc::nccl::Comm`.
|
||||||
|
|
||||||
|
use super::rpc::WorkerResponse;
|
||||||
|
use super::worker::WorkerConfig;
|
||||||
|
|
||||||
|
/// Encode bytes as lowercase hex. Used for ferrying NCCL `Id::internal()`
|
||||||
|
/// across the leader→worker RPC boundary inside a JSON string.
|
||||||
|
pub fn encode_hex(bytes: &[u8]) -> String {
|
||||||
|
let mut out = String::with_capacity(bytes.len() * 2);
|
||||||
|
for b in bytes {
|
||||||
|
use std::fmt::Write;
|
||||||
|
let _ = write!(out, "{b:02x}");
|
||||||
|
}
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decode lowercase-or-uppercase hex into bytes. Errors on odd length
|
||||||
|
/// or non-hex characters; the caller bubbles those up via the RPC's
|
||||||
|
/// `Error{kind="bad_request"}` variant.
|
||||||
|
pub fn decode_hex(s: &str) -> Result<Vec<u8>, String> {
|
||||||
|
if !s.len().is_multiple_of(2) {
|
||||||
|
return Err(format!("hex string has odd length {}", s.len()));
|
||||||
|
}
|
||||||
|
(0..s.len())
|
||||||
|
.step_by(2)
|
||||||
|
.map(|i| {
|
||||||
|
u8::from_str_radix(&s[i..i + 2], 16).map_err(|e| format!("bad hex byte at {i}: {e}"))
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
pub struct NcclState;
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
impl Default for NcclState {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
impl NcclState {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn init(&mut self, _cfg: WorkerConfig, _comm_id_hex: &str) -> WorkerResponse {
|
||||||
|
WorkerResponse::Error {
|
||||||
|
kind: "cuda_feature_not_enabled".into(),
|
||||||
|
message: "this neuron binary was built without --features cuda; \
|
||||||
|
NCCL Init requires CUDA"
|
||||||
|
.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn sanity_check(&mut self) -> WorkerResponse {
|
||||||
|
WorkerResponse::Error {
|
||||||
|
kind: "cuda_feature_not_enabled".into(),
|
||||||
|
message: "NCCL sanity check requires --features cuda".into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
mod cuda_impl {
|
||||||
|
use super::*;
|
||||||
|
use cudarc::driver::CudaContext;
|
||||||
|
use cudarc::nccl::{Comm, Id, ReduceOp};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
/// Number of bytes in NCCL's unique-id type; matches `Id::internal()`'s
|
||||||
|
/// `[c_char; 128]`. Wire-encoded as 256 lowercase hex chars.
|
||||||
|
const NCCL_ID_BYTES: usize = 128;
|
||||||
|
|
||||||
|
pub struct NcclState {
|
||||||
|
/// Wrapped in `Arc` so we can hand a clone to `TpQwen3ForCausalLM`
|
||||||
|
/// at load time (every row-parallel layer needs a reference to
|
||||||
|
/// run its trailing `AllReduce`). The `Arc` is the single source
|
||||||
|
/// of truth for the comm's lifetime — when the pool drops and
|
||||||
|
/// every layer that captured a clone drops, NCCL releases the
|
||||||
|
/// underlying `ncclComm_t`.
|
||||||
|
comm: Option<Arc<Comm>>,
|
||||||
|
/// Held alongside the Comm so the device isn't dropped
|
||||||
|
/// underneath the NCCL handle.
|
||||||
|
#[allow(dead_code)]
|
||||||
|
ctx: Option<Arc<CudaContext>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for NcclState {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl NcclState {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
comm: None,
|
||||||
|
ctx: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clone the comm out as an `Arc` so callers (the leader-side
|
||||||
|
/// `TpQwen3ForCausalLM::load`, or the worker's own model load)
|
||||||
|
/// can hold a reference for the lifetime of the model. Returns
|
||||||
|
/// `None` before `init` has run.
|
||||||
|
pub fn comm(&self) -> Option<Arc<Comm>> {
|
||||||
|
self.comm.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `Arc<Comm>` doesn't impl `Send` because `Comm` wraps a raw
|
||||||
|
/// `ncclComm_t` pointer. The NCCL contract is "operations against a
|
||||||
|
/// given comm must be serialised", not "the handle must stay on the
|
||||||
|
/// thread that created it" — so it's safe to move an `Arc<Comm>`
|
||||||
|
/// across threads as long as no concurrent ops are issued. The
|
||||||
|
/// pool's outer Mutex serialises us into `spawn_blocking`, so this
|
||||||
|
/// wrapper at the move boundary is the only thing missing.
|
||||||
|
///
|
||||||
|
/// `Sync` is also marked safe because the `Arc<Comm>` clones held
|
||||||
|
/// by the row-parallel layers are only used from the
|
||||||
|
/// `spawn_blocking` thread driving the forward pass; concurrent
|
||||||
|
/// access from another thread would still be a bug.
|
||||||
|
pub struct SendComm(pub Arc<Comm>);
|
||||||
|
|
||||||
|
// SAFETY: see the doc-comment above; the invariant is enforced at
|
||||||
|
// the call site (pool Mutex + single spawn_blocking thread), not at
|
||||||
|
// the type level.
|
||||||
|
unsafe impl Send for SendComm {}
|
||||||
|
unsafe impl Sync for SendComm {}
|
||||||
|
|
||||||
|
impl SendComm {
|
||||||
|
pub fn into_inner(self) -> Arc<Comm> {
|
||||||
|
self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SAFETY: `cudarc::nccl::Comm` contains a raw `ncclComm_t` pointer
|
||||||
|
// (libnccl-allocated state). NCCL requires that operations against
|
||||||
|
// one Comm be issued one at a time; we serialise access by storing
|
||||||
|
// NcclState behind a Mutex in `WorkerPool`. The Comm itself is
|
||||||
|
// move-safe — NCCL doesn't track the calling OS thread, only the
|
||||||
|
// stream the operations are dispatched against.
|
||||||
|
unsafe impl Send for NcclState {}
|
||||||
|
unsafe impl Sync for NcclState {}
|
||||||
|
|
||||||
|
/// Generate a fresh NCCL `Id` and return it hex-encoded. Used by
|
||||||
|
/// the leader to mint the shared communicator id which is then
|
||||||
|
/// broadcast to every worker via the RPC `Init` message.
|
||||||
|
pub fn generate_comm_id_hex() -> Result<String, String> {
|
||||||
|
// NcclError lacks a Display impl in cudarc 0.19.x — surface
|
||||||
|
// via Debug throughout this module.
|
||||||
|
let id = Id::new().map_err(|e| format!("Id::new(): {e:?}"))?;
|
||||||
|
let bytes_u8: [u8; NCCL_ID_BYTES] = std::array::from_fn(|i| id.internal()[i] as u8);
|
||||||
|
Ok(encode_hex(&bytes_u8))
|
||||||
|
}
|
||||||
|
|
||||||
|
impl NcclState {
|
||||||
|
pub fn init(&mut self, cfg: WorkerConfig, comm_id_hex: &str) -> WorkerResponse {
|
||||||
|
match try_init(self, cfg, comm_id_hex) {
|
||||||
|
Ok(()) => WorkerResponse::InitOk,
|
||||||
|
Err(msg) => WorkerResponse::Error {
|
||||||
|
kind: "nccl_init_failed".into(),
|
||||||
|
message: msg,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn sanity_check(&mut self) -> WorkerResponse {
|
||||||
|
let Some(comm) = self.comm.as_ref() else {
|
||||||
|
return WorkerResponse::Error {
|
||||||
|
kind: "nccl_not_initialised".into(),
|
||||||
|
message: "sanity_check requires Init to have completed first".into(),
|
||||||
|
};
|
||||||
|
};
|
||||||
|
match try_sanity_check(comm.as_ref()) {
|
||||||
|
Ok(sum) => WorkerResponse::NcclSanityResult { observed_sum: sum },
|
||||||
|
Err(msg) => WorkerResponse::Error {
|
||||||
|
kind: "nccl_sanity_failed".into(),
|
||||||
|
message: msg,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn try_init(state: &mut NcclState, cfg: WorkerConfig, comm_id_hex: &str) -> Result<(), String> {
|
||||||
|
let bytes = decode_hex(comm_id_hex)?;
|
||||||
|
if bytes.len() != NCCL_ID_BYTES {
|
||||||
|
return Err(format!(
|
||||||
|
"comm_id is {} bytes, expected {NCCL_ID_BYTES}",
|
||||||
|
bytes.len()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
let id_bytes: [std::ffi::c_char; NCCL_ID_BYTES] =
|
||||||
|
std::array::from_fn(|i| bytes[i] as std::ffi::c_char);
|
||||||
|
let id = Id::uninit(id_bytes);
|
||||||
|
|
||||||
|
let ctx = CudaContext::new(cfg.cuda_device as usize)
|
||||||
|
.map_err(|e| format!("CudaContext::new({}) failed: {e}", cfg.cuda_device))?;
|
||||||
|
let stream = ctx.default_stream();
|
||||||
|
let comm = Comm::from_rank(stream, cfg.rank as usize, cfg.world_size as usize, id)
|
||||||
|
.map_err(|e| {
|
||||||
|
format!(
|
||||||
|
"Comm::from_rank(rank={}, world={}) failed: {e:?}",
|
||||||
|
cfg.rank, cfg.world_size
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
state.ctx = Some(ctx);
|
||||||
|
// `Comm` is !Send + !Sync at the type level because it wraps a
|
||||||
|
// raw `ncclComm_t`. The `Arc` is fine in practice — we
|
||||||
|
// serialise operations through the pool's outer Mutex and the
|
||||||
|
// SendComm wrapper at thread-crossing boundaries enforces this
|
||||||
|
// at every move site. clippy's `arc_with_non_send_sync` lint
|
||||||
|
// can't see that invariant; allow once at the canonical
|
||||||
|
// construction site.
|
||||||
|
#[allow(clippy::arc_with_non_send_sync)]
|
||||||
|
{
|
||||||
|
state.comm = Some(Arc::new(comm));
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn try_sanity_check(comm: &Comm) -> Result<u32, String> {
|
||||||
|
let stream = comm.stream().clone();
|
||||||
|
let input = stream
|
||||||
|
.clone_htod(&[1u32])
|
||||||
|
.map_err(|e| format!("htod sentinel: {e}"))?;
|
||||||
|
let mut output = stream
|
||||||
|
.alloc_zeros::<u32>(1)
|
||||||
|
.map_err(|e| format!("alloc output: {e}"))?;
|
||||||
|
// cudarc::nccl::NcclError doesn't impl Display in 0.19.x —
|
||||||
|
// surface via Debug so we still see the variant + ncclResult
|
||||||
|
// code instead of a generic "{e}" failure.
|
||||||
|
comm.all_reduce(&input, &mut output, &ReduceOp::Sum)
|
||||||
|
.map_err(|e| format!("all_reduce: {e:?}"))?;
|
||||||
|
let result = stream
|
||||||
|
.clone_dtoh(&output)
|
||||||
|
.map_err(|e| format!("dtoh result: {e}"))?;
|
||||||
|
Ok(result[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub use cuda_impl::{NcclState, SendComm, generate_comm_id_hex};
|
||||||
|
|
||||||
|
/// Non-cuda stub for the leader: returns a clear marker error rather
|
||||||
|
/// than letting `init_nccl` succeed vacuously.
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
pub fn generate_comm_id_hex() -> Result<String, String> {
|
||||||
|
Err("cuda_feature_not_enabled: build with --features cuda".into())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn hex_roundtrip() {
|
||||||
|
let original: Vec<u8> = (0u8..=255).collect();
|
||||||
|
let encoded = encode_hex(&original);
|
||||||
|
assert_eq!(encoded.len(), 512);
|
||||||
|
let decoded = decode_hex(&encoded).expect("decode");
|
||||||
|
assert_eq!(decoded, original);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn hex_decode_rejects_odd_length() {
|
||||||
|
assert!(decode_hex("a").is_err());
|
||||||
|
assert!(decode_hex("abc").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn hex_decode_rejects_non_hex() {
|
||||||
|
assert!(decode_hex("zz").is_err());
|
||||||
|
assert!(decode_hex("ab_d").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn hex_encode_is_lowercase_padded() {
|
||||||
|
assert_eq!(encode_hex(&[0x0a, 0xff]), "0aff");
|
||||||
|
}
|
||||||
|
}
|
||||||
257
crates/neuron/src/harness/tp/rpc.rs
Normal file
257
crates/neuron/src/harness/tp/rpc.rs
Normal file
@@ -0,0 +1,257 @@
|
|||||||
|
//! Wire protocol between the neuron leader process and its
|
||||||
|
//! `--worker` subprocesses.
|
||||||
|
//!
|
||||||
|
//! Every frame is one newline-delimited JSON object on the worker's
|
||||||
|
//! stdin (request) or stdout (response). Both directions are tagged
|
||||||
|
//! sum types from the start so new ops in Stage 7b/7c slot in without
|
||||||
|
//! breaking compatibility — no "14 message types and a version field"
|
||||||
|
//! drift later. Adding a new variant is the canonical way to evolve
|
||||||
|
//! the protocol; existing peers that don't recognise an op return
|
||||||
|
//! `WorkerResponse::Error { kind: "unknown_op", .. }`.
|
||||||
|
//!
|
||||||
|
//! The serialised shape uses `tag = "op"` so a request looks like:
|
||||||
|
//! {"op":"ping"}
|
||||||
|
//! {"op":"init","comm_id":"a1b2..."}
|
||||||
|
//! and a response:
|
||||||
|
//! {"op":"pong","rank":0,"world_size":2,"cuda_device":0}
|
||||||
|
//! {"op":"error","kind":"nccl_init_failed","message":"..."}
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
/// Leader → worker. Worker handles one at a time; replies with exactly
|
||||||
|
/// one `WorkerResponse` per request.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(tag = "op", rename_all = "snake_case")]
|
||||||
|
pub enum WorkerRequest {
|
||||||
|
/// Liveness probe. Worker replies with `Pong` containing its own
|
||||||
|
/// identity. Used by the leader to confirm the subprocess is up
|
||||||
|
/// and ready before kicking off any heavier work.
|
||||||
|
Ping,
|
||||||
|
|
||||||
|
/// One-shot NCCL communicator setup. The leader generates the
|
||||||
|
/// `comm_id` once (rank 0 of NCCL), broadcasts it to every worker
|
||||||
|
/// via this message, then every rank (leader included) calls
|
||||||
|
/// `Comm::from_rank` with the same id — NCCL blocks until all
|
||||||
|
/// `world_size` ranks check in. The hex-encoded bytes are the
|
||||||
|
/// canonical `cudarc::nccl::Id::internal()` content.
|
||||||
|
Init {
|
||||||
|
/// Hex-encoded NCCL id bytes (128 bytes → 256 hex chars).
|
||||||
|
comm_id: String,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Sanity check: after Init, every rank runs an `all_reduce` over
|
||||||
|
/// a sentinel value (`1u32`). The expected sum is `world_size`.
|
||||||
|
/// Worker replies with the observed value so the leader can verify
|
||||||
|
/// the NCCL handshake is genuinely live, not just configured.
|
||||||
|
NcclSanityCheck,
|
||||||
|
|
||||||
|
/// Load this rank's shard of a dense Qwen3 model from mmaped
|
||||||
|
/// safetensors. The same `safetensors_paths` list is sent to every
|
||||||
|
/// rank — the ShardedVarBuilder reads only the rank-local slice of
|
||||||
|
/// each tensor at materialisation time, so the worker's VRAM
|
||||||
|
/// footprint is `1 / world_size` of the full model (plus replicated
|
||||||
|
/// embedding/norm/lm_head).
|
||||||
|
LoadDenseShard {
|
||||||
|
/// Caller-supplied id for later `GenerateStep` / `UnloadModel`
|
||||||
|
/// lookups. Typically the HF model id verbatim.
|
||||||
|
model_id: String,
|
||||||
|
/// JSON-serialised `candle_transformers::models::qwen3::Config`
|
||||||
|
/// — the same blob the leader parsed from the HF cache's
|
||||||
|
/// `config.json`. Threaded through verbatim so the worker uses
|
||||||
|
/// identical hyperparameters.
|
||||||
|
config_json: String,
|
||||||
|
/// Absolute paths the worker should mmap. The same set on every
|
||||||
|
/// rank; ShardedVarBuilder slices into them per rank.
|
||||||
|
safetensors_paths: Vec<String>,
|
||||||
|
/// Optional in-situ quantization dtype (e.g. "q5k", "q8_0",
|
||||||
|
/// "q6k"). When set, each linear-layer weight is quantized
|
||||||
|
/// at load time to the named ggml format — saves ~3-5x vs
|
||||||
|
/// bf16/f16 at the cost of some accuracy. `None` keeps the
|
||||||
|
/// weights in the on-disk dtype (typically bf16).
|
||||||
|
#[serde(default)]
|
||||||
|
quant: Option<String>,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Run one forward step on this rank's loaded model. The worker
|
||||||
|
/// reaches into its NCCL Comm for the row-parallel `AllReduce`s
|
||||||
|
/// inside the model — and so blocks on every other rank issuing the
|
||||||
|
/// same op. The leader does *not* receive logits back over RPC; it
|
||||||
|
/// runs its own rank-0 forward in parallel and uses its own logits
|
||||||
|
/// for sampling.
|
||||||
|
GenerateStep {
|
||||||
|
model_id: String,
|
||||||
|
/// Input token ids for this step. For prefill, the whole prompt;
|
||||||
|
/// for decode, a single token. Identical on every rank.
|
||||||
|
tokens: Vec<u32>,
|
||||||
|
/// KV cache offset (count of tokens already in the cache before
|
||||||
|
/// this step).
|
||||||
|
offset: usize,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Reset the KV cache for this model on this rank. Sent at the
|
||||||
|
/// start of every inference so a fresh request doesn't accidentally
|
||||||
|
/// attend over the previous one's tokens.
|
||||||
|
ClearKvCache { model_id: String },
|
||||||
|
|
||||||
|
/// Drop this rank's shard for the given model. Releases the VRAM
|
||||||
|
/// the shard's weights occupied; subsequent `GenerateStep` calls
|
||||||
|
/// against the same `model_id` return an `Error`.
|
||||||
|
UnloadModel { model_id: String },
|
||||||
|
|
||||||
|
/// Worker should release resources and exit. Worker replies `Bye`
|
||||||
|
/// and then closes stdout / exits zero. The leader reaps the
|
||||||
|
/// child via the `tokio::process::Child` it kept.
|
||||||
|
Shutdown,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Worker → leader. Always exactly one of these per `WorkerRequest`.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(tag = "op", rename_all = "snake_case")]
|
||||||
|
pub enum WorkerResponse {
|
||||||
|
/// Reply to `Ping`. Carries enough identity for the leader to log
|
||||||
|
/// what it actually got back.
|
||||||
|
Pong {
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
cuda_device: u32,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Reply to `Init`. Empty payload — success is the absence of
|
||||||
|
/// `Error`. NCCL's internal blocking handshake means by the time
|
||||||
|
/// this comes back, every other rank has also reached
|
||||||
|
/// `Comm::from_rank`.
|
||||||
|
InitOk,
|
||||||
|
|
||||||
|
/// Reply to `NcclSanityCheck`. The observed sum after a single
|
||||||
|
/// `all_reduce(SUM, 1u32)` across all ranks. The leader checks
|
||||||
|
/// this matches `world_size`.
|
||||||
|
NcclSanityResult { observed_sum: u32 },
|
||||||
|
|
||||||
|
/// Reply to `LoadDenseShard`. Empty payload — success is the
|
||||||
|
/// absence of `Error`. By the time this comes back, the rank's
|
||||||
|
/// `TpQwen3ForCausalLM` is constructed in memory and ready for
|
||||||
|
/// `GenerateStep`.
|
||||||
|
LoadDenseShardOk,
|
||||||
|
|
||||||
|
/// Reply to `GenerateStep`. Empty payload — workers don't ship
|
||||||
|
/// logits over the wire. The leader uses its own rank-0 logits;
|
||||||
|
/// workers only need to confirm the collective completed.
|
||||||
|
GenerateStepOk,
|
||||||
|
|
||||||
|
/// Reply to `ClearKvCache`. Empty payload.
|
||||||
|
KvCacheCleared,
|
||||||
|
|
||||||
|
/// Reply to `UnloadModel`. Empty payload. The named model is no
|
||||||
|
/// longer present on this rank.
|
||||||
|
Unloaded,
|
||||||
|
|
||||||
|
/// Reply to `Shutdown`. Worker exits immediately after writing this.
|
||||||
|
Bye,
|
||||||
|
|
||||||
|
/// Any request can produce this instead of its dedicated success
|
||||||
|
/// variant. `kind` is a machine-readable category so the leader
|
||||||
|
/// can branch on failure mode without string-matching `message`.
|
||||||
|
Error {
|
||||||
|
/// Short tag — `nccl_init_failed`, `unknown_op`, etc.
|
||||||
|
kind: String,
|
||||||
|
/// Human-readable detail for logs.
|
||||||
|
message: String,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
fn roundtrip<T>(value: &T) -> T
|
||||||
|
where
|
||||||
|
T: Serialize + for<'de> Deserialize<'de>,
|
||||||
|
{
|
||||||
|
serde_json::from_str(&serde_json::to_string(value).expect("serialise"))
|
||||||
|
.expect("deserialise")
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn request_ping_round_trip() {
|
||||||
|
let req = WorkerRequest::Ping;
|
||||||
|
let wire = serde_json::to_string(&req).unwrap();
|
||||||
|
assert_eq!(wire, r#"{"op":"ping"}"#);
|
||||||
|
match roundtrip(&req) {
|
||||||
|
WorkerRequest::Ping => {}
|
||||||
|
other => panic!("expected Ping, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn request_init_carries_hex_id() {
|
||||||
|
let req = WorkerRequest::Init {
|
||||||
|
comm_id: "deadbeef".into(),
|
||||||
|
};
|
||||||
|
let wire = serde_json::to_string(&req).unwrap();
|
||||||
|
assert_eq!(wire, r#"{"op":"init","comm_id":"deadbeef"}"#);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn request_shutdown_round_trip() {
|
||||||
|
assert_eq!(
|
||||||
|
serde_json::to_string(&WorkerRequest::Shutdown).unwrap(),
|
||||||
|
r#"{"op":"shutdown"}"#
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn response_pong_round_trip() {
|
||||||
|
let resp = WorkerResponse::Pong {
|
||||||
|
rank: 1,
|
||||||
|
world_size: 4,
|
||||||
|
cuda_device: 1,
|
||||||
|
};
|
||||||
|
let wire = serde_json::to_string(&resp).unwrap();
|
||||||
|
assert!(wire.contains(r#""op":"pong""#));
|
||||||
|
assert!(wire.contains(r#""rank":1"#));
|
||||||
|
assert!(wire.contains(r#""world_size":4"#));
|
||||||
|
match roundtrip(&resp) {
|
||||||
|
WorkerResponse::Pong {
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
cuda_device,
|
||||||
|
} => {
|
||||||
|
assert_eq!(rank, 1);
|
||||||
|
assert_eq!(world_size, 4);
|
||||||
|
assert_eq!(cuda_device, 1);
|
||||||
|
}
|
||||||
|
other => panic!("expected Pong, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn response_error_carries_kind_and_message() {
|
||||||
|
let resp = WorkerResponse::Error {
|
||||||
|
kind: "nccl_init_failed".into(),
|
||||||
|
message: "could not bind device".into(),
|
||||||
|
};
|
||||||
|
let wire = serde_json::to_string(&resp).unwrap();
|
||||||
|
assert!(wire.contains(r#""op":"error""#));
|
||||||
|
assert!(wire.contains(r#""kind":"nccl_init_failed""#));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn response_sanity_result_round_trip() {
|
||||||
|
let resp = WorkerResponse::NcclSanityResult { observed_sum: 4 };
|
||||||
|
match roundtrip(&resp) {
|
||||||
|
WorkerResponse::NcclSanityResult { observed_sum } => {
|
||||||
|
assert_eq!(observed_sum, 4);
|
||||||
|
}
|
||||||
|
other => panic!("expected NcclSanityResult, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Unknown ops on the wire deserialise to an error rather than
|
||||||
|
/// silently mis-matching — confirms our `serde(tag = "op")`
|
||||||
|
/// configuration rejects unknowns instead of doing fuzzy matching.
|
||||||
|
#[test]
|
||||||
|
fn unknown_op_fails_to_parse() {
|
||||||
|
let result: Result<WorkerRequest, _> = serde_json::from_str(r#"{"op":"explode"}"#);
|
||||||
|
assert!(result.is_err(), "should reject unknown op, got {result:?}");
|
||||||
|
}
|
||||||
|
}
|
||||||
283
crates/neuron/src/harness/tp/tp_linear.rs
Normal file
283
crates/neuron/src/harness/tp/tp_linear.rs
Normal file
@@ -0,0 +1,283 @@
|
|||||||
|
//! Tensor-parallel linear layers built on candle's `ShardedVarBuilder`
|
||||||
|
//! and `Shard` sharding hints.
|
||||||
|
//!
|
||||||
|
//! candle reads only the rank's slice of each weight tensor from
|
||||||
|
//! safetensors via `view.slice(start..stop)` — no full-tensor host
|
||||||
|
//! materialisation. That's a memory-efficiency win over hand-rolled
|
||||||
|
//! "load full + narrow" sharding (which the earlier
|
||||||
|
//! `sharded_linear.rs` exploration demonstrated but didn't pay for).
|
||||||
|
//!
|
||||||
|
//! Two layer types:
|
||||||
|
//!
|
||||||
|
//! - [`ColumnParallelLinear`] — output-sharded; forward is a plain
|
||||||
|
//! local matmul. The downstream consumer either accepts a sharded
|
||||||
|
//! activation (next layer is also column-parallel) or all-gathers.
|
||||||
|
//! - [`RowParallelLinear`] — input-sharded; forward = local matmul
|
||||||
|
//! then `AllReduce` `CustomOp1` to sum partials across ranks.
|
||||||
|
//!
|
||||||
|
//! Both assume **no bias** — every Qwen3-family weight layout we
|
||||||
|
//! actually target (Qwen3, Qwen3-Coder, Qwen3.6 base, etc.) sets
|
||||||
|
//! `attention_bias=false` and the MLP layers are no-bias. Adding bias
|
||||||
|
//! support is mechanical when a future model needs it; the design
|
||||||
|
//! choice would be: column-parallel shards the bias along dim 0;
|
||||||
|
//! row-parallel holds the bias only on rank 0 so the post-`AllReduce`
|
||||||
|
//! sum carries it exactly once.
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use candle_core::quantized::{GgmlDType, QMatMul, QTensor};
|
||||||
|
use candle_core::{Module, Tensor};
|
||||||
|
use candle_nn::Linear;
|
||||||
|
use candle_nn::var_builder::{Shard, ShardedVarBuilder};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
use super::all_reduce::AllReduce;
|
||||||
|
|
||||||
|
/// Linear primitive that holds either a plain `Linear` (bf16/f16/f32)
|
||||||
|
/// or a quantized `QMatMul` (Q4K/Q5K/Q6K/Q8_0/etc.).
|
||||||
|
///
|
||||||
|
/// Constructed via [`MaybeQuantLinear::from_weight`] — pass `None` to
|
||||||
|
/// keep the weight in its loaded dtype (no quantization), or
|
||||||
|
/// `Some(dtype)` to quantize at load time.
|
||||||
|
///
|
||||||
|
/// On the forward path the two arms dispatch identically: `Module::forward`
|
||||||
|
/// returns an output in the caller's input dtype (f32 fallback for the
|
||||||
|
/// quantized matmul). Subsequent ops don't need to know whether the
|
||||||
|
/// layer was quantized.
|
||||||
|
pub enum MaybeQuantLinear {
|
||||||
|
Plain(Linear),
|
||||||
|
Quant(QMatMul),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MaybeQuantLinear {
|
||||||
|
/// Build a linear from a loaded weight tensor. If `quant` is set,
|
||||||
|
/// the weight is quantized in-situ and stored as a `QMatMul`;
|
||||||
|
/// otherwise it's wrapped in a plain `Linear`.
|
||||||
|
pub fn from_weight(weight: Tensor, quant: Option<GgmlDType>) -> Result<Self> {
|
||||||
|
match quant {
|
||||||
|
Some(dtype) => {
|
||||||
|
let qt = QTensor::quantize(&weight, dtype).with_context(|| {
|
||||||
|
format!(
|
||||||
|
"QTensor::quantize to {dtype:?} for shape {:?}",
|
||||||
|
weight.shape()
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
let qmm = QMatMul::from_arc(Arc::new(qt))
|
||||||
|
.context("QMatMul::from_arc on freshly quantized weight")?;
|
||||||
|
Ok(Self::Quant(qmm))
|
||||||
|
}
|
||||||
|
None => Ok(Self::Plain(Linear::new(weight, None))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Above this M (the product of all input dims except the last)
|
||||||
|
/// dispatch the quantized matmul through `QMatMul::forward_via_f16`,
|
||||||
|
/// which dequantizes the weight to f16 once and runs cuBLAS GEMM.
|
||||||
|
/// At or below this M the GGUF GEMV kernel inside
|
||||||
|
/// `QMatMul::forward` wins (it operates on quantized blocks directly
|
||||||
|
/// and accumulates in registers).
|
||||||
|
///
|
||||||
|
/// Empirical: at M=30 on Qwen3.6-27B / RTX 5090, forward_via_f16 was
|
||||||
|
/// slightly *slower* than the GGUF GEMV kernel — the per-call dequant
|
||||||
|
/// cost (~30 MB f16 written to global memory per linear × ~480 calls
|
||||||
|
/// per prefill) eats the cuBLAS GEMM speedup at small M. The
|
||||||
|
/// crossover where the GEMM scaling actually beats the fixed dequant
|
||||||
|
/// tax sits well above M=8.
|
||||||
|
///
|
||||||
|
/// 64 is a conservative crossover that keeps short-prompt prefills
|
||||||
|
/// on the GGUF kernel (where the per-call cost is comparable to the
|
||||||
|
/// f16 path but the dequant tax is zero) and only activates the
|
||||||
|
/// dequant-then-GEMM path for long prefills where the GEMM size
|
||||||
|
/// makes amortising worth it. A proper fix is either a dequant
|
||||||
|
/// cache or a fused dequant+gemm cuda kernel — both larger projects.
|
||||||
|
const QUANT_PREFILL_M_THRESHOLD: usize = 64;
|
||||||
|
|
||||||
|
impl Module for MaybeQuantLinear {
|
||||||
|
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
|
||||||
|
match self {
|
||||||
|
Self::Plain(l) => l.forward(x),
|
||||||
|
Self::Quant(qm) => {
|
||||||
|
// Decode vs prefill split. `M` is the "rows of x" the
|
||||||
|
// matmul will iterate over — every dim except the last
|
||||||
|
// (which is in_features). For decode (`seq_len == 1`
|
||||||
|
// with batch 1) M is 1; for prefill with L>>1 it's L
|
||||||
|
// (or B*L).
|
||||||
|
let dims = x.dims();
|
||||||
|
let m: usize = dims.iter().take(dims.len() - 1).product();
|
||||||
|
|
||||||
|
if m > QUANT_PREFILL_M_THRESHOLD {
|
||||||
|
// Prefill: dequantize the weight once into f16,
|
||||||
|
// then run a real cuBLAS-backed GEMM. The cost of
|
||||||
|
// the dequant is amortised across all M tokens.
|
||||||
|
// `forward_via_f16` handles the dtype round-trip
|
||||||
|
// internally (output matches input dtype).
|
||||||
|
return qm.forward_via_f16(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode (M <= threshold): use the on-the-fly GGUF
|
||||||
|
// GEMV kernel via `QMatMul::forward`. That kernel
|
||||||
|
// requires f32 inputs (it accumulates in f32 from the
|
||||||
|
// dequantized quant blocks); cast in/out at the
|
||||||
|
// boundary.
|
||||||
|
let in_dtype = x.dtype();
|
||||||
|
let x_f32 = if in_dtype == candle_core::DType::F32 {
|
||||||
|
x.clone()
|
||||||
|
} else {
|
||||||
|
x.to_dtype(candle_core::DType::F32)?
|
||||||
|
};
|
||||||
|
let y = qm.forward(&x_f32)?;
|
||||||
|
if y.dtype() == in_dtype {
|
||||||
|
Ok(y)
|
||||||
|
} else {
|
||||||
|
y.to_dtype(in_dtype)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper to build a [`Shard`] hint for a given dimension.
|
||||||
|
pub(crate) fn shard(dim: usize, rank: u32, world_size: u32) -> Shard {
|
||||||
|
Shard {
|
||||||
|
dim,
|
||||||
|
rank: rank as usize,
|
||||||
|
world_size: world_size as usize,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Output-dim sharded linear (column-parallel). Holds a
|
||||||
|
/// [`MaybeQuantLinear`] whose underlying weight is this rank's slice
|
||||||
|
/// of the full `[out_features, in_features]` tensor along dim 0.
|
||||||
|
pub struct ColumnParallelLinear {
|
||||||
|
inner: MaybeQuantLinear,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ColumnParallelLinear {
|
||||||
|
/// Load this rank's column-parallel slice from a
|
||||||
|
/// `ShardedVarBuilder`. The provided `vb` must already be `pp`-ed
|
||||||
|
/// to the layer's path (e.g. `vb.pp("model.layers.0.self_attn.q_proj")`).
|
||||||
|
///
|
||||||
|
/// Backward-compatible variant — no in-situ quantization. For
|
||||||
|
/// quantized loads, use [`Self::load_with_quant`].
|
||||||
|
pub fn load(vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
|
||||||
|
Self::load_with_quant(vb, rank, world_size, None)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Like [`Self::load`] but quantizes the per-rank weight in-situ
|
||||||
|
/// when `quant` is `Some(dtype)`. Saves ~3-5x vs bf16/f16.
|
||||||
|
pub fn load_with_quant(
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
quant: Option<GgmlDType>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let weight = vb
|
||||||
|
.get_with_hints((), "weight", shard(0, rank, world_size))
|
||||||
|
.with_context(|| format!("load column-parallel '{}' weight", vb.prefix()))?;
|
||||||
|
let inner = MaybeQuantLinear::from_weight(weight, quant)
|
||||||
|
.with_context(|| format!("wrap column-parallel '{}'", vb.prefix()))?;
|
||||||
|
Ok(Self { inner })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for ColumnParallelLinear {
|
||||||
|
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
|
||||||
|
self.inner.forward(x)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Input-dim sharded linear (row-parallel).
|
||||||
|
///
|
||||||
|
/// Holds a sharded [`MaybeQuantLinear`] plus an `AllReduce` op the
|
||||||
|
/// forward chains after the local matmul to recover the full activation.
|
||||||
|
pub struct RowParallelLinear {
|
||||||
|
inner: MaybeQuantLinear,
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
all_reduce: AllReduce,
|
||||||
|
/// Whether the AllReduce should run. Column-parallel ↔ row-parallel
|
||||||
|
/// is a pair: the column output is sharded, the row input is
|
||||||
|
/// sharded, and the AllReduce gives back the full output. For
|
||||||
|
/// `world_size = 1` the AllReduce is a no-op so we skip it.
|
||||||
|
needs_reduce: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RowParallelLinear {
|
||||||
|
/// Load this rank's row-parallel slice from a `ShardedVarBuilder`.
|
||||||
|
///
|
||||||
|
/// Under `cuda`, `comm` is the NCCL communicator the row-parallel
|
||||||
|
/// `AllReduce` runs against. On CPU builds the parameter is
|
||||||
|
/// elided — forward returns the partial sum, which is the *wrong*
|
||||||
|
/// answer for inference but lets us compile-check the model.
|
||||||
|
///
|
||||||
|
/// Backward-compatible variant — no in-situ quantization. For
|
||||||
|
/// quantized loads, use [`Self::load_with_quant`].
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub fn load(
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
comm: std::sync::Arc<cudarc::nccl::Comm>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Self::load_with_quant(vb, rank, world_size, comm, None)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Like [`Self::load`] but quantizes the per-rank weight in-situ.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub fn load_with_quant(
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
comm: std::sync::Arc<cudarc::nccl::Comm>,
|
||||||
|
quant: Option<GgmlDType>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let weight = vb
|
||||||
|
.get_with_hints((), "weight", shard(1, rank, world_size))
|
||||||
|
.with_context(|| format!("load row-parallel '{}' weight", vb.prefix()))?;
|
||||||
|
let inner = MaybeQuantLinear::from_weight(weight, quant)
|
||||||
|
.with_context(|| format!("wrap row-parallel '{}'", vb.prefix()))?;
|
||||||
|
Ok(Self {
|
||||||
|
inner,
|
||||||
|
all_reduce: AllReduce::new(comm),
|
||||||
|
needs_reduce: world_size > 1,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
pub fn load(vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
|
||||||
|
Self::load_with_quant(vb, rank, world_size, None)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
pub fn load_with_quant(
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
quant: Option<GgmlDType>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let weight = vb
|
||||||
|
.get_with_hints((), "weight", shard(1, rank, world_size))
|
||||||
|
.with_context(|| format!("load row-parallel '{}' weight", vb.prefix()))?;
|
||||||
|
let inner = MaybeQuantLinear::from_weight(weight, quant)
|
||||||
|
.with_context(|| format!("wrap row-parallel '{}'", vb.prefix()))?;
|
||||||
|
Ok(Self {
|
||||||
|
inner,
|
||||||
|
needs_reduce: world_size > 1,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for RowParallelLinear {
|
||||||
|
/// Local matmul followed by an `AllReduce` (when `cuda` and
|
||||||
|
/// `world_size > 1`). On CPU or single-rank, returns the partial
|
||||||
|
/// output directly — which is *only* correct for `world_size == 1`.
|
||||||
|
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
|
||||||
|
let local = self.inner.forward(x)?;
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
if self.needs_reduce {
|
||||||
|
return local.apply_op1_no_bwd(&self.all_reduce);
|
||||||
|
}
|
||||||
|
let _ = self.needs_reduce;
|
||||||
|
Ok(local)
|
||||||
|
}
|
||||||
|
}
|
||||||
678
crates/neuron/src/harness/tp/tp_qwen3.rs
Normal file
678
crates/neuron/src/harness/tp/tp_qwen3.rs
Normal file
@@ -0,0 +1,678 @@
|
|||||||
|
//! Tensor-parallel Qwen3 dense model.
|
||||||
|
//!
|
||||||
|
//! Mirrors `candle_transformers::models::qwen3` structurally, but with:
|
||||||
|
//!
|
||||||
|
//! - Attention's `q_proj` / `k_proj` / `v_proj` as
|
||||||
|
//! [`ColumnParallelLinear`] (output sharded along the head dimension —
|
||||||
|
//! per-rank `num_heads = total/world_size`, ditto for kv heads).
|
||||||
|
//! - Attention's `o_proj` as [`RowParallelLinear`] (input sharded; the
|
||||||
|
//! trailing `AllReduce` recovers the full activation).
|
||||||
|
//! - MLP's `gate_proj` / `up_proj` as [`ColumnParallelLinear`] (sharded
|
||||||
|
//! along `intermediate_size`).
|
||||||
|
//! - MLP's `down_proj` as [`RowParallelLinear`].
|
||||||
|
//! - `embed_tokens`, all `RmsNorm`s, and `lm_head` **replicated** on
|
||||||
|
//! every rank. The per-rank duplicate weight is bounded and lets us
|
||||||
|
//! skip the embedding all-gather and the lm-head column-shard +
|
||||||
|
//! all-gather; both are pure latency optimisations that don't change
|
||||||
|
//! correctness.
|
||||||
|
//! - `kv_cache` holds the per-rank slice of K/V already (because they
|
||||||
|
//! came out of a column-parallel projection). No cache resharding
|
||||||
|
//! needed across ranks.
|
||||||
|
//!
|
||||||
|
//! Divisibility requirement, checked at load time:
|
||||||
|
//!
|
||||||
|
//! - `num_attention_heads % world_size == 0`
|
||||||
|
//! - `num_key_value_heads % world_size == 0`
|
||||||
|
//! - `intermediate_size % world_size == 0`
|
||||||
|
//!
|
||||||
|
//! Anything else bails — the safetensors slice would lose data otherwise.
|
||||||
|
//! This is the same divisibility-bail pattern that landed in
|
||||||
|
//! `EricLBuehler/mistral.rs` PR #2054.
|
||||||
|
//!
|
||||||
|
//! Replicated tensors (norms, embedding, lm_head) are loaded by asking
|
||||||
|
//! the `ShardedVarBuilder` for the full tensor via `vb.get(shape, name)`
|
||||||
|
//! — which defaults to `Shard { world_size: 1 }` and falls through to
|
||||||
|
//! the unsharded backend path.
|
||||||
|
|
||||||
|
use anyhow::{Context, Result, bail};
|
||||||
|
use candle_core::{DType, Device, IndexOp, Module, Tensor};
|
||||||
|
use candle_nn::var_builder::ShardedVarBuilder;
|
||||||
|
use candle_nn::{Activation, Embedding, Linear, RmsNorm, kv_cache::ConcatKvCache};
|
||||||
|
use candle_transformers::utils::repeat_kv;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
use cudarc::nccl::Comm;
|
||||||
|
|
||||||
|
use super::tp_linear::{ColumnParallelLinear, RowParallelLinear};
|
||||||
|
|
||||||
|
pub use candle_transformers::models::qwen3::Config;
|
||||||
|
|
||||||
|
/// Replicated rotary-embedding lookup. Re-implementation of the
|
||||||
|
/// `pub(crate)` candle equivalent — we can't reach into the upstream
|
||||||
|
/// type, so the inv-freq / sin / cos construction lives here.
|
||||||
|
pub(crate) struct Qwen3RotaryEmbedding {
|
||||||
|
sin: Tensor,
|
||||||
|
cos: Tensor,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Qwen3RotaryEmbedding {
|
||||||
|
pub(crate) fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
|
||||||
|
let dim = cfg.head_dim;
|
||||||
|
let max_seq_len = cfg.max_position_embeddings;
|
||||||
|
let inv_freq: Vec<_> = (0..dim)
|
||||||
|
.step_by(2)
|
||||||
|
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
|
||||||
|
.collect();
|
||||||
|
let inv_freq_len = inv_freq.len();
|
||||||
|
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
|
||||||
|
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.reshape((max_seq_len, 1))?;
|
||||||
|
let freqs = t.matmul(&inv_freq)?;
|
||||||
|
Ok(Self {
|
||||||
|
sin: freqs.sin()?.to_dtype(dtype)?,
|
||||||
|
cos: freqs.cos()?.to_dtype(dtype)?,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn apply(
|
||||||
|
&self,
|
||||||
|
q: &Tensor,
|
||||||
|
k: &Tensor,
|
||||||
|
offset: usize,
|
||||||
|
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
|
let (_, _, seq_len, _) = q.dims4()?;
|
||||||
|
let cos = self.cos.narrow(0, offset, seq_len)?;
|
||||||
|
let sin = self.sin.narrow(0, offset, seq_len)?;
|
||||||
|
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
|
||||||
|
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
|
||||||
|
Ok((q_embed, k_embed))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper: load a replicated tensor by asking the ShardedVarBuilder for
|
||||||
|
/// the full tensor (world_size=1 hint falls through to SimpleBackend).
|
||||||
|
fn load_replicated<S: Into<candle_core::Shape>>(
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
shape: S,
|
||||||
|
name: &str,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
vb.get(shape, name)
|
||||||
|
.with_context(|| format!("load replicated '{}/{name}'", vb.prefix()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_rms_norm(vb: &ShardedVarBuilder, size: usize, eps: f64) -> Result<RmsNorm> {
|
||||||
|
let weight = load_replicated(vb, size, "weight")?;
|
||||||
|
Ok(RmsNorm::new(weight, eps))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// TP MLP. SwiGLU = `down(silu(gate(x)) * up(x))`.
|
||||||
|
pub(crate) struct TpQwen3MLP {
|
||||||
|
gate_proj: ColumnParallelLinear,
|
||||||
|
up_proj: ColumnParallelLinear,
|
||||||
|
down_proj: RowParallelLinear,
|
||||||
|
act_fn: Activation,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TpQwen3MLP {
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub fn load(
|
||||||
|
cfg: &Config,
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
comm: Arc<Comm>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
if !cfg.intermediate_size.is_multiple_of(world_size as usize) {
|
||||||
|
bail!(
|
||||||
|
"intermediate_size {} not divisible by world_size {}",
|
||||||
|
cfg.intermediate_size,
|
||||||
|
world_size
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Ok(Self {
|
||||||
|
gate_proj: ColumnParallelLinear::load(&vb.pp("gate_proj"), rank, world_size)?,
|
||||||
|
up_proj: ColumnParallelLinear::load(&vb.pp("up_proj"), rank, world_size)?,
|
||||||
|
down_proj: RowParallelLinear::load(&vb.pp("down_proj"), rank, world_size, comm)?,
|
||||||
|
act_fn: cfg.hidden_act,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
pub fn load(cfg: &Config, vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
|
||||||
|
if !cfg.intermediate_size.is_multiple_of(world_size as usize) {
|
||||||
|
bail!(
|
||||||
|
"intermediate_size {} not divisible by world_size {}",
|
||||||
|
cfg.intermediate_size,
|
||||||
|
world_size
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Ok(Self {
|
||||||
|
gate_proj: ColumnParallelLinear::load(&vb.pp("gate_proj"), rank, world_size)?,
|
||||||
|
up_proj: ColumnParallelLinear::load(&vb.pp("up_proj"), rank, world_size)?,
|
||||||
|
down_proj: RowParallelLinear::load(&vb.pp("down_proj"), rank, world_size)?,
|
||||||
|
act_fn: cfg.hidden_act,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for TpQwen3MLP {
|
||||||
|
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
|
||||||
|
let lhs = x.apply(&self.gate_proj)?.apply(&self.act_fn)?;
|
||||||
|
let rhs = x.apply(&self.up_proj)?;
|
||||||
|
(lhs * rhs)?.apply(&self.down_proj)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// TP attention. Carries per-rank head counts and the q/k per-head
|
||||||
|
/// RmsNorms (which are replicated and operate on a flattened B*H*L
|
||||||
|
/// axis, so the same code path works irrespective of how H was split).
|
||||||
|
pub(crate) struct TpQwen3Attention {
|
||||||
|
q_proj: ColumnParallelLinear,
|
||||||
|
k_proj: ColumnParallelLinear,
|
||||||
|
v_proj: ColumnParallelLinear,
|
||||||
|
o_proj: RowParallelLinear,
|
||||||
|
q_norm: RmsNorm,
|
||||||
|
k_norm: RmsNorm,
|
||||||
|
local_num_heads: usize,
|
||||||
|
local_num_kv_heads: usize,
|
||||||
|
num_kv_groups: usize,
|
||||||
|
head_dim: usize,
|
||||||
|
local_hidden_size: usize,
|
||||||
|
rotary_emb: Arc<Qwen3RotaryEmbedding>,
|
||||||
|
kv_cache: ConcatKvCache,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TpQwen3Attention {
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub fn load(
|
||||||
|
cfg: &Config,
|
||||||
|
rotary_emb: Arc<Qwen3RotaryEmbedding>,
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
comm: Arc<Comm>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Self::load_inner(
|
||||||
|
cfg,
|
||||||
|
rotary_emb,
|
||||||
|
vb,
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
comm,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
pub fn load(
|
||||||
|
cfg: &Config,
|
||||||
|
rotary_emb: Arc<Qwen3RotaryEmbedding>,
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Self::load_inner(cfg, rotary_emb, vb, rank, world_size)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_inner(
|
||||||
|
cfg: &Config,
|
||||||
|
rotary_emb: Arc<Qwen3RotaryEmbedding>,
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
#[cfg(feature = "cuda")] comm: Arc<Comm>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
if cfg.use_sliding_window {
|
||||||
|
bail!("sliding window is not yet supported in the TP path");
|
||||||
|
}
|
||||||
|
if cfg.attention_bias {
|
||||||
|
bail!("attention_bias=true is not supported by ColumnParallel/RowParallelLinear yet");
|
||||||
|
}
|
||||||
|
let ws = world_size as usize;
|
||||||
|
if !cfg.num_attention_heads.is_multiple_of(ws) {
|
||||||
|
bail!(
|
||||||
|
"num_attention_heads {} not divisible by world_size {}",
|
||||||
|
cfg.num_attention_heads,
|
||||||
|
world_size
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if !cfg.num_key_value_heads.is_multiple_of(ws) {
|
||||||
|
bail!(
|
||||||
|
"num_key_value_heads {} not divisible by world_size {}",
|
||||||
|
cfg.num_key_value_heads,
|
||||||
|
world_size
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let head_dim = cfg.head_dim;
|
||||||
|
let local_num_heads = cfg.num_attention_heads / ws;
|
||||||
|
let local_num_kv_heads = cfg.num_key_value_heads / ws;
|
||||||
|
let num_kv_groups = local_num_heads / local_num_kv_heads;
|
||||||
|
let local_hidden_size = head_dim * local_num_heads;
|
||||||
|
|
||||||
|
let q_proj = ColumnParallelLinear::load(&vb.pp("q_proj"), rank, world_size)?;
|
||||||
|
let k_proj = ColumnParallelLinear::load(&vb.pp("k_proj"), rank, world_size)?;
|
||||||
|
let v_proj = ColumnParallelLinear::load(&vb.pp("v_proj"), rank, world_size)?;
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
let o_proj = RowParallelLinear::load(&vb.pp("o_proj"), rank, world_size, comm)?;
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
let o_proj = RowParallelLinear::load(&vb.pp("o_proj"), rank, world_size)?;
|
||||||
|
|
||||||
|
let q_norm = load_rms_norm(&vb.pp("q_norm"), head_dim, cfg.rms_norm_eps)?;
|
||||||
|
let k_norm = load_rms_norm(&vb.pp("k_norm"), head_dim, cfg.rms_norm_eps)?;
|
||||||
|
|
||||||
|
// dim=2 because we cat along the seq axis of (B, H, L, D) tensors.
|
||||||
|
let kv_cache = ConcatKvCache::new(2);
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
q_proj,
|
||||||
|
k_proj,
|
||||||
|
v_proj,
|
||||||
|
o_proj,
|
||||||
|
q_norm,
|
||||||
|
k_norm,
|
||||||
|
local_num_heads,
|
||||||
|
local_num_kv_heads,
|
||||||
|
num_kv_groups,
|
||||||
|
head_dim,
|
||||||
|
local_hidden_size,
|
||||||
|
rotary_emb,
|
||||||
|
kv_cache,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(
|
||||||
|
&mut self,
|
||||||
|
x: &Tensor,
|
||||||
|
attn_mask: Option<&Tensor>,
|
||||||
|
offset: usize,
|
||||||
|
) -> candle_core::Result<Tensor> {
|
||||||
|
let (b, l, _) = x.dims3()?;
|
||||||
|
|
||||||
|
// 1. Projections (column-parallel → output is sharded).
|
||||||
|
let q = self.q_proj.forward(x)?;
|
||||||
|
let k = self.k_proj.forward(x)?;
|
||||||
|
let v = self.v_proj.forward(x)?;
|
||||||
|
|
||||||
|
// 2. Reshape: (B, L, H, D) → (B, H, L, D).
|
||||||
|
let q = q
|
||||||
|
.reshape((b, l, self.local_num_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?;
|
||||||
|
let k = k
|
||||||
|
.reshape((b, l, self.local_num_kv_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?;
|
||||||
|
let v = v
|
||||||
|
.reshape((b, l, self.local_num_kv_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?;
|
||||||
|
|
||||||
|
// 3. Per-head RmsNorm (replicated weight, flat input).
|
||||||
|
let q_flat = q.flatten(0, 2)?;
|
||||||
|
let k_flat = k.flatten(0, 2)?;
|
||||||
|
let q_flat = self.q_norm.forward(&q_flat)?;
|
||||||
|
let k_flat = self.k_norm.forward(&k_flat)?;
|
||||||
|
let q = q_flat.reshape((b, self.local_num_heads, l, self.head_dim))?;
|
||||||
|
let k = k_flat.reshape((b, self.local_num_kv_heads, l, self.head_dim))?;
|
||||||
|
|
||||||
|
// 4. Rotary.
|
||||||
|
let (q, k) = self.rotary_emb.apply(&q, &k, offset)?;
|
||||||
|
|
||||||
|
// 5. Accumulate KV.
|
||||||
|
let (k, v) = self.kv_cache.append(&k, &v)?;
|
||||||
|
|
||||||
|
// 6. GQA repeat_kv on the rank-local K/V.
|
||||||
|
let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
|
||||||
|
let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;
|
||||||
|
|
||||||
|
// 7. Attention scores.
|
||||||
|
let scale = 1.0 / (self.head_dim as f64).sqrt();
|
||||||
|
let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
|
||||||
|
if let Some(m) = attn_mask {
|
||||||
|
scores = scores.broadcast_add(m)?;
|
||||||
|
}
|
||||||
|
let probs = candle_nn::ops::softmax_last_dim(&scores)?;
|
||||||
|
let ctx = probs.matmul(&v)?;
|
||||||
|
|
||||||
|
// 8. Output projection (row-parallel → AllReduce inside).
|
||||||
|
ctx.transpose(1, 2)?
|
||||||
|
.reshape((b, l, self.local_hidden_size))?
|
||||||
|
.apply(&self.o_proj)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
self.kv_cache.reset();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct TpDecoderLayer {
|
||||||
|
self_attn: TpQwen3Attention,
|
||||||
|
mlp: TpQwen3MLP,
|
||||||
|
ln1: RmsNorm,
|
||||||
|
ln2: RmsNorm,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TpDecoderLayer {
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
fn load(
|
||||||
|
cfg: &Config,
|
||||||
|
rotary_emb: Arc<Qwen3RotaryEmbedding>,
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
comm: Arc<Comm>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let self_attn = TpQwen3Attention::load(
|
||||||
|
cfg,
|
||||||
|
rotary_emb,
|
||||||
|
&vb.pp("self_attn"),
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
comm.clone(),
|
||||||
|
)?;
|
||||||
|
let mlp = TpQwen3MLP::load(cfg, &vb.pp("mlp"), rank, world_size, comm)?;
|
||||||
|
let ln1 = load_rms_norm(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?;
|
||||||
|
let ln2 = load_rms_norm(
|
||||||
|
&vb.pp("post_attention_layernorm"),
|
||||||
|
cfg.hidden_size,
|
||||||
|
cfg.rms_norm_eps,
|
||||||
|
)?;
|
||||||
|
Ok(Self {
|
||||||
|
self_attn,
|
||||||
|
mlp,
|
||||||
|
ln1,
|
||||||
|
ln2,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
fn load(
|
||||||
|
cfg: &Config,
|
||||||
|
rotary_emb: Arc<Qwen3RotaryEmbedding>,
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let self_attn =
|
||||||
|
TpQwen3Attention::load(cfg, rotary_emb, &vb.pp("self_attn"), rank, world_size)?;
|
||||||
|
let mlp = TpQwen3MLP::load(cfg, &vb.pp("mlp"), rank, world_size)?;
|
||||||
|
let ln1 = load_rms_norm(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?;
|
||||||
|
let ln2 = load_rms_norm(
|
||||||
|
&vb.pp("post_attention_layernorm"),
|
||||||
|
cfg.hidden_size,
|
||||||
|
cfg.rms_norm_eps,
|
||||||
|
)?;
|
||||||
|
Ok(Self {
|
||||||
|
self_attn,
|
||||||
|
mlp,
|
||||||
|
ln1,
|
||||||
|
ln2,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(
|
||||||
|
&mut self,
|
||||||
|
x: &Tensor,
|
||||||
|
mask: Option<&Tensor>,
|
||||||
|
offset: usize,
|
||||||
|
) -> candle_core::Result<Tensor> {
|
||||||
|
let h = self.ln1.forward(x)?;
|
||||||
|
let h = self.self_attn.forward(&h, mask, offset)?;
|
||||||
|
let x = (x + h)?;
|
||||||
|
let h2 = self.ln2.forward(&x)?;
|
||||||
|
let h2 = h2.apply(&self.mlp)?;
|
||||||
|
x + h2
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clear_kv_cache(&mut self) {
|
||||||
|
self.self_attn.clear_kv_cache();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Base TP Qwen3 transformer — embedding, decoder stack, final norm.
|
||||||
|
/// The lm_head sits on top in [`TpQwen3ForCausalLM`].
|
||||||
|
pub struct TpQwen3Model {
|
||||||
|
embed_tokens: Embedding,
|
||||||
|
layers: Vec<TpDecoderLayer>,
|
||||||
|
norm: RmsNorm,
|
||||||
|
device: Device,
|
||||||
|
dtype: DType,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TpQwen3Model {
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub fn load(
|
||||||
|
cfg: &Config,
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
comm: Arc<Comm>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let dtype = vb.dtype();
|
||||||
|
let device = vb.device().clone();
|
||||||
|
let rotary = Arc::new(Qwen3RotaryEmbedding::new(dtype, cfg, &device)?);
|
||||||
|
|
||||||
|
let embed_vb = vb.pp("model.embed_tokens");
|
||||||
|
let embed_weight = load_replicated(&embed_vb, (cfg.vocab_size, cfg.hidden_size), "weight")?;
|
||||||
|
let embed_tokens = Embedding::new(embed_weight, cfg.hidden_size);
|
||||||
|
|
||||||
|
let vb_l = vb.pp("model.layers");
|
||||||
|
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||||
|
for i in 0..cfg.num_hidden_layers {
|
||||||
|
layers.push(TpDecoderLayer::load(
|
||||||
|
cfg,
|
||||||
|
rotary.clone(),
|
||||||
|
&vb_l.pp(i),
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
comm.clone(),
|
||||||
|
)?);
|
||||||
|
}
|
||||||
|
let norm = load_rms_norm(&vb.pp("model.norm"), cfg.hidden_size, cfg.rms_norm_eps)?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
embed_tokens,
|
||||||
|
layers,
|
||||||
|
norm,
|
||||||
|
device,
|
||||||
|
dtype,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
pub fn load(cfg: &Config, vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
|
||||||
|
let dtype = vb.dtype();
|
||||||
|
let device = vb.device().clone();
|
||||||
|
let rotary = Arc::new(Qwen3RotaryEmbedding::new(dtype, cfg, &device)?);
|
||||||
|
|
||||||
|
let embed_vb = vb.pp("model.embed_tokens");
|
||||||
|
let embed_weight = load_replicated(&embed_vb, (cfg.vocab_size, cfg.hidden_size), "weight")?;
|
||||||
|
let embed_tokens = Embedding::new(embed_weight, cfg.hidden_size);
|
||||||
|
|
||||||
|
let vb_l = vb.pp("model.layers");
|
||||||
|
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||||
|
for i in 0..cfg.num_hidden_layers {
|
||||||
|
layers.push(TpDecoderLayer::load(
|
||||||
|
cfg,
|
||||||
|
rotary.clone(),
|
||||||
|
&vb_l.pp(i),
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
)?);
|
||||||
|
}
|
||||||
|
let norm = load_rms_norm(&vb.pp("model.norm"), cfg.hidden_size, cfg.rms_norm_eps)?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
embed_tokens,
|
||||||
|
layers,
|
||||||
|
norm,
|
||||||
|
device,
|
||||||
|
dtype,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn embed_weight(&self) -> &Tensor {
|
||||||
|
self.embed_tokens.embeddings()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
for l in &mut self.layers {
|
||||||
|
l.clear_kv_cache();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> candle_core::Result<Tensor> {
|
||||||
|
let minf = f32::NEG_INFINITY;
|
||||||
|
let mask: Vec<_> = (0..tgt)
|
||||||
|
.flat_map(|i| (0..(tgt + offset)).map(move |j| if j <= i + offset { 0. } else { minf }))
|
||||||
|
.collect();
|
||||||
|
Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
|
||||||
|
let (b, l) = input.dims2()?;
|
||||||
|
let mut h = self.embed_tokens.forward(input)?;
|
||||||
|
|
||||||
|
let causal = if l == 1 {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(self.causal_mask(b, l, offset)?)
|
||||||
|
};
|
||||||
|
|
||||||
|
for layer in &mut self.layers {
|
||||||
|
h = layer.forward(&h, causal.as_ref(), offset)?;
|
||||||
|
}
|
||||||
|
self.norm.forward(&h)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// TP Qwen3 with a (replicated) language-model head on top.
|
||||||
|
pub struct TpQwen3ForCausalLM {
|
||||||
|
base: TpQwen3Model,
|
||||||
|
lm_head: Linear,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TpQwen3ForCausalLM {
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub fn load(
|
||||||
|
cfg: &Config,
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
rank: u32,
|
||||||
|
world_size: u32,
|
||||||
|
comm: Arc<Comm>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let base = TpQwen3Model::load(cfg, vb, rank, world_size, comm)?;
|
||||||
|
let lm_head = build_lm_head(cfg, vb, &base)?;
|
||||||
|
let model = Self { base, lm_head };
|
||||||
|
log_construction_complete(cfg, rank, world_size, model.device());
|
||||||
|
Ok(model)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
pub fn load(cfg: &Config, vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
|
||||||
|
let base = TpQwen3Model::load(cfg, vb, rank, world_size)?;
|
||||||
|
let lm_head = build_lm_head(cfg, vb, &base)?;
|
||||||
|
let model = Self { base, lm_head };
|
||||||
|
log_construction_complete(cfg, rank, world_size, model.device());
|
||||||
|
Ok(model)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
|
||||||
|
let (_, l) = input.dims2()?;
|
||||||
|
let hidden = self.base.forward(input, offset)?;
|
||||||
|
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
self.base.clear_kv_cache();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn device(&self) -> &Device {
|
||||||
|
&self.base.device
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn dtype(&self) -> DType {
|
||||||
|
self.base.dtype
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_lm_head(cfg: &Config, vb: &ShardedVarBuilder, base: &TpQwen3Model) -> Result<Linear> {
|
||||||
|
if cfg.tie_word_embeddings {
|
||||||
|
Ok(Linear::new(base.embed_weight().clone(), None))
|
||||||
|
} else {
|
||||||
|
let weight = load_replicated(
|
||||||
|
&vb.pp("lm_head"),
|
||||||
|
(cfg.vocab_size, cfg.hidden_size),
|
||||||
|
"weight",
|
||||||
|
)?;
|
||||||
|
Ok(Linear::new(weight, None))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// VRAM accounting + config dump emitted at the end of
|
||||||
|
/// `TpQwen3ForCausalLM::load`. Same intent as the Qwen3-Next variant
|
||||||
|
/// in tp_qwen3_5.rs — surface the resolved hyperparameters and
|
||||||
|
/// per-rank free VRAM in one line so an operator chasing an OOM or a
|
||||||
|
/// numerical issue doesn't have to grep the per-layer load logs.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
fn log_construction_complete(cfg: &Config, rank: u32, world_size: u32, device: &Device) {
|
||||||
|
use candle_core::cuda::cudarc::driver::result;
|
||||||
|
use candle_core::cuda_backend::WrapErr;
|
||||||
|
let (free_mb, total_mb) = if let Device::Cuda(dev) = device {
|
||||||
|
if dev.cuda_stream().context().bind_to_thread().w().is_ok() {
|
||||||
|
match result::mem_get_info() {
|
||||||
|
Ok((free, total)) => (free / (1024 * 1024), total / (1024 * 1024)),
|
||||||
|
Err(_) => (0, 0),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
(0, 0)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
(0, 0)
|
||||||
|
};
|
||||||
|
// Per-rank KV cache cost at one token: K + V × bf16. Vanilla
|
||||||
|
// Qwen3 is dense attention end-to-end, so every layer
|
||||||
|
// contributes. Knowing per-token bytes lets the operator estimate
|
||||||
|
// headroom for a given prompt length before hitting an edge.
|
||||||
|
let per_rank_num_kv_heads = (cfg.num_key_value_heads / world_size as usize).max(1);
|
||||||
|
let kv_bytes_per_token_per_layer = per_rank_num_kv_heads * cfg.head_dim * 2 * 2;
|
||||||
|
let kv_bytes_per_token = kv_bytes_per_token_per_layer * cfg.num_hidden_layers;
|
||||||
|
tracing::info!(
|
||||||
|
target: "neuron::tp::load",
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
free_mb,
|
||||||
|
total_mb,
|
||||||
|
vocab_size = cfg.vocab_size,
|
||||||
|
hidden_size = cfg.hidden_size,
|
||||||
|
num_hidden_layers = cfg.num_hidden_layers,
|
||||||
|
num_attention_heads = cfg.num_attention_heads,
|
||||||
|
num_key_value_heads = cfg.num_key_value_heads,
|
||||||
|
head_dim = cfg.head_dim,
|
||||||
|
max_position_embeddings = cfg.max_position_embeddings,
|
||||||
|
per_rank_num_kv_heads,
|
||||||
|
kv_bytes_per_token,
|
||||||
|
"Qwen3 model construction complete"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
fn log_construction_complete(cfg: &Config, rank: u32, world_size: u32, _device: &Device) {
|
||||||
|
let per_rank_num_kv_heads = (cfg.num_key_value_heads / world_size as usize).max(1);
|
||||||
|
let kv_bytes_per_token_per_layer = per_rank_num_kv_heads * cfg.head_dim * 2 * 2;
|
||||||
|
let kv_bytes_per_token = kv_bytes_per_token_per_layer * cfg.num_hidden_layers;
|
||||||
|
tracing::info!(
|
||||||
|
target: "neuron::tp::load",
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
vocab_size = cfg.vocab_size,
|
||||||
|
hidden_size = cfg.hidden_size,
|
||||||
|
num_hidden_layers = cfg.num_hidden_layers,
|
||||||
|
num_attention_heads = cfg.num_attention_heads,
|
||||||
|
num_key_value_heads = cfg.num_key_value_heads,
|
||||||
|
head_dim = cfg.head_dim,
|
||||||
|
max_position_embeddings = cfg.max_position_embeddings,
|
||||||
|
per_rank_num_kv_heads,
|
||||||
|
kv_bytes_per_token,
|
||||||
|
"Qwen3 model construction complete"
|
||||||
|
);
|
||||||
|
}
|
||||||
1207
crates/neuron/src/harness/tp/tp_qwen3_5.rs
Normal file
1207
crates/neuron/src/harness/tp/tp_qwen3_5.rs
Normal file
File diff suppressed because it is too large
Load Diff
502
crates/neuron/src/harness/tp/worker.rs
Normal file
502
crates/neuron/src/harness/tp/worker.rs
Normal file
@@ -0,0 +1,502 @@
|
|||||||
|
//! Entry point for `neuron --worker`.
|
||||||
|
//!
|
||||||
|
//! The worker reads one newline-delimited JSON `WorkerRequest` from
|
||||||
|
//! stdin per loop iteration, dispatches synchronously, and writes
|
||||||
|
//! exactly one `WorkerResponse` JSON line to stdout. tracing goes to
|
||||||
|
//! stderr so it doesn't collide with the RPC stream.
|
||||||
|
//!
|
||||||
|
//! NCCL operations (`Init`, `NcclSanityCheck`) and model lifecycle ops
|
||||||
|
//! (`LoadDenseShard`, `GenerateStep`, `ClearKvCache`, `UnloadModel`)
|
||||||
|
//! are real when built with the `cuda` feature; without it they reply
|
||||||
|
//! with `Error{kind="cuda_feature_not_enabled"}` so the leader can tell
|
||||||
|
//! the difference between a misconfigured build and a genuine NCCL or
|
||||||
|
//! model failure.
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||||
|
|
||||||
|
use super::nccl_state::NcclState;
|
||||||
|
use super::rpc::{WorkerRequest, WorkerResponse};
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
use super::tp_qwen3::TpQwen3ForCausalLM;
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
use super::tp_qwen3_5::TpQwen3_5ForCausalLM;
|
||||||
|
|
||||||
|
/// Worker-side discriminator over the architectures we can load via
|
||||||
|
/// `LoadDenseShard`. Mirrors `super::TpLeaderModel` on the leader
|
||||||
|
/// side — the dispatch happens on the `model_type` extracted from the
|
||||||
|
/// config JSON.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
enum WorkerModel {
|
||||||
|
Qwen3(TpQwen3ForCausalLM),
|
||||||
|
Qwen3_5(TpQwen3_5ForCausalLM),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
impl WorkerModel {
|
||||||
|
fn forward(
|
||||||
|
&mut self,
|
||||||
|
input: &candle_core::Tensor,
|
||||||
|
offset: usize,
|
||||||
|
) -> candle_core::Result<candle_core::Tensor> {
|
||||||
|
match self {
|
||||||
|
WorkerModel::Qwen3(m) => m.forward(input, offset),
|
||||||
|
WorkerModel::Qwen3_5(m) => m.forward(input, offset),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clear_kv_cache(&mut self) {
|
||||||
|
match self {
|
||||||
|
WorkerModel::Qwen3(m) => m.clear_kv_cache(),
|
||||||
|
WorkerModel::Qwen3_5(m) => m.clear_kv_cache(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn device(&self) -> &candle_core::Device {
|
||||||
|
match self {
|
||||||
|
WorkerModel::Qwen3(m) => m.device(),
|
||||||
|
WorkerModel::Qwen3_5(m) => m.device(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct WorkerConfig {
|
||||||
|
pub rank: u32,
|
||||||
|
pub world_size: u32,
|
||||||
|
pub cuda_device: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Drive the worker RPC loop until `Shutdown` or EOF on stdin.
|
||||||
|
pub async fn run(config: WorkerConfig) -> Result<()> {
|
||||||
|
tracing::info!(
|
||||||
|
rank = config.rank,
|
||||||
|
world_size = config.world_size,
|
||||||
|
cuda_device = config.cuda_device,
|
||||||
|
"tp worker starting"
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut state = WorkerState::new(config);
|
||||||
|
let stdin = tokio::io::stdin();
|
||||||
|
let mut reader = BufReader::new(stdin).lines();
|
||||||
|
let mut stdout = tokio::io::stdout();
|
||||||
|
|
||||||
|
while let Some(line) = reader.next_line().await? {
|
||||||
|
if line.trim().is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let req: WorkerRequest = match serde_json::from_str(&line) {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => {
|
||||||
|
let resp = WorkerResponse::Error {
|
||||||
|
kind: "bad_request".into(),
|
||||||
|
message: format!("parse {line:?}: {e}"),
|
||||||
|
};
|
||||||
|
write_response(&mut stdout, &resp).await?;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let resp = state.handle(req).await;
|
||||||
|
let is_bye = matches!(resp, WorkerResponse::Bye);
|
||||||
|
write_response(&mut stdout, &resp).await?;
|
||||||
|
if is_bye {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!(rank = config.rank, "tp worker exiting");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn write_response(stdout: &mut tokio::io::Stdout, resp: &WorkerResponse) -> Result<()> {
|
||||||
|
let mut line = serde_json::to_string(resp)?;
|
||||||
|
line.push('\n');
|
||||||
|
stdout.write_all(line.as_bytes()).await?;
|
||||||
|
stdout.flush().await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// One rank's local state. Owns the rank's NCCL communicator (via
|
||||||
|
/// `NcclState`) and the rank's shard of every loaded model.
|
||||||
|
struct WorkerState {
|
||||||
|
config: WorkerConfig,
|
||||||
|
nccl: NcclState,
|
||||||
|
/// Loaded model shards keyed by `model_id`. Each entry wraps the
|
||||||
|
/// rank's TP architecture handle (Qwen3 or Qwen3-Next) — the
|
||||||
|
/// column/row-parallel layers hold an `Arc<Comm>` cloned from
|
||||||
|
/// `nccl`. Cuda-only: the underlying types reference cudarc types
|
||||||
|
/// that don't exist without the cuda feature.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
models: HashMap<String, WorkerModel>,
|
||||||
|
/// Placeholder so the non-cuda build keeps the same field name set
|
||||||
|
/// and `WorkerState::new` reads the same on both.
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
#[allow(dead_code)]
|
||||||
|
models: HashMap<String, ()>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WorkerState {
|
||||||
|
fn new(config: WorkerConfig) -> Self {
|
||||||
|
Self {
|
||||||
|
config,
|
||||||
|
nccl: NcclState::new(),
|
||||||
|
models: HashMap::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle(&mut self, req: WorkerRequest) -> WorkerResponse {
|
||||||
|
match req {
|
||||||
|
WorkerRequest::Ping => WorkerResponse::Pong {
|
||||||
|
rank: self.config.rank,
|
||||||
|
world_size: self.config.world_size,
|
||||||
|
cuda_device: self.config.cuda_device,
|
||||||
|
},
|
||||||
|
WorkerRequest::Init { comm_id } => self.nccl.init(self.config, &comm_id),
|
||||||
|
WorkerRequest::NcclSanityCheck => self.nccl.sanity_check(),
|
||||||
|
WorkerRequest::LoadDenseShard {
|
||||||
|
model_id,
|
||||||
|
config_json,
|
||||||
|
safetensors_paths,
|
||||||
|
quant,
|
||||||
|
} => self.handle_load_dense_shard(model_id, config_json, safetensors_paths, quant),
|
||||||
|
WorkerRequest::GenerateStep {
|
||||||
|
model_id,
|
||||||
|
tokens,
|
||||||
|
offset,
|
||||||
|
} => self.handle_generate_step(&model_id, tokens, offset),
|
||||||
|
WorkerRequest::ClearKvCache { model_id } => self.handle_clear_kv_cache(&model_id),
|
||||||
|
WorkerRequest::UnloadModel { model_id } => self.handle_unload_model(&model_id),
|
||||||
|
WorkerRequest::Shutdown => WorkerResponse::Bye,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
fn handle_load_dense_shard(
|
||||||
|
&mut self,
|
||||||
|
model_id: String,
|
||||||
|
config_json: String,
|
||||||
|
safetensors_paths: Vec<String>,
|
||||||
|
quant: Option<String>,
|
||||||
|
) -> WorkerResponse {
|
||||||
|
use crate::harness::arch::qwen3_5 as qwen3_5_arch;
|
||||||
|
use candle_core::{DType, Device};
|
||||||
|
use candle_nn::var_builder::ShardedSafeTensors;
|
||||||
|
use candle_transformers::models::qwen3 as qwen3_dense;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
let quant_dtype = match parse_quant_string(quant.as_deref()) {
|
||||||
|
Ok(q) => q,
|
||||||
|
Err(e) => {
|
||||||
|
return WorkerResponse::Error {
|
||||||
|
kind: "bad_request".into(),
|
||||||
|
message: format!("parse quant: {e}"),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if self.models.contains_key(&model_id) {
|
||||||
|
return WorkerResponse::Error {
|
||||||
|
kind: "already_loaded".into(),
|
||||||
|
message: format!("model '{model_id}' already loaded on this rank"),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
let comm = match self.nccl.comm() {
|
||||||
|
Some(c) => c,
|
||||||
|
None => {
|
||||||
|
return WorkerResponse::Error {
|
||||||
|
kind: "nccl_not_initialised".into(),
|
||||||
|
message: "LoadDenseShard requires Init to have completed first".into(),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Peek at model_type so we know which architecture to build.
|
||||||
|
let model_type = serde_json::from_str::<serde_json::Value>(&config_json)
|
||||||
|
.ok()
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|v| v.get("model_type"))
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or("")
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
let device = match Device::new_cuda(self.config.cuda_device as usize) {
|
||||||
|
Ok(d) => d,
|
||||||
|
Err(e) => {
|
||||||
|
return WorkerResponse::Error {
|
||||||
|
kind: "cuda_unavailable".into(),
|
||||||
|
message: format!("Device::new_cuda({}) failed: {e}", self.config.cuda_device),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let paths: Vec<PathBuf> = safetensors_paths.into_iter().map(PathBuf::from).collect();
|
||||||
|
// SAFETY: same invariant as the single-GPU dense path — the HF
|
||||||
|
// cache files are treated as immutable while the mmap is held.
|
||||||
|
let vb = match unsafe { ShardedSafeTensors::var_builder(&paths, DType::BF16, &device) } {
|
||||||
|
Ok(v) => v,
|
||||||
|
Err(e) => {
|
||||||
|
return WorkerResponse::Error {
|
||||||
|
kind: "load_failed".into(),
|
||||||
|
message: format!("ShardedSafeTensors::var_builder: {e}"),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
// Separate mmap of the same paths for the direct fused-region
|
||||||
|
// loader in `fused_load`. Linux's page cache shares the
|
||||||
|
// underlying pages between the two mmaps; the cost is one
|
||||||
|
// extra set of safetensors-header parses.
|
||||||
|
let mmap = match unsafe { candle_core::safetensors::MmapedSafetensors::multi(&paths) } {
|
||||||
|
Ok(m) => m,
|
||||||
|
Err(e) => {
|
||||||
|
return WorkerResponse::Error {
|
||||||
|
kind: "load_failed".into(),
|
||||||
|
message: format!("MmapedSafetensors::multi: {e}"),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let loaded = match model_type.as_str() {
|
||||||
|
"qwen3" => {
|
||||||
|
let cfg: qwen3_dense::Config = match serde_json::from_str(&config_json) {
|
||||||
|
Ok(c) => c,
|
||||||
|
Err(e) => {
|
||||||
|
return WorkerResponse::Error {
|
||||||
|
kind: "bad_request".into(),
|
||||||
|
message: format!("parse Qwen3 Config JSON: {e}"),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
match TpQwen3ForCausalLM::load(
|
||||||
|
&cfg,
|
||||||
|
&vb,
|
||||||
|
self.config.rank,
|
||||||
|
self.config.world_size,
|
||||||
|
comm,
|
||||||
|
) {
|
||||||
|
Ok(m) => WorkerModel::Qwen3(m),
|
||||||
|
Err(e) => {
|
||||||
|
return WorkerResponse::Error {
|
||||||
|
kind: "load_failed".into(),
|
||||||
|
message: format!("TpQwen3ForCausalLM::load: {e:#}"),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"qwen3_5" => {
|
||||||
|
let cfg: qwen3_5_arch::Config = match serde_json::from_str(&config_json) {
|
||||||
|
Ok(c) => c,
|
||||||
|
Err(e) => {
|
||||||
|
return WorkerResponse::Error {
|
||||||
|
kind: "bad_request".into(),
|
||||||
|
message: format!("parse Qwen3-Next Config JSON: {e}"),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
match TpQwen3_5ForCausalLM::load(
|
||||||
|
cfg,
|
||||||
|
&vb,
|
||||||
|
&mmap,
|
||||||
|
self.config.rank,
|
||||||
|
self.config.world_size,
|
||||||
|
comm,
|
||||||
|
quant_dtype,
|
||||||
|
) {
|
||||||
|
Ok(m) => WorkerModel::Qwen3_5(m),
|
||||||
|
Err(e) => {
|
||||||
|
return WorkerResponse::Error {
|
||||||
|
kind: "load_failed".into(),
|
||||||
|
message: format!("TpQwen3_5ForCausalLM::load: {e:#}"),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
other => {
|
||||||
|
return WorkerResponse::Error {
|
||||||
|
kind: "unsupported_arch".into(),
|
||||||
|
message: format!(
|
||||||
|
"worker: unsupported model_type '{other}' (supported: qwen3, qwen3_5)"
|
||||||
|
),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
self.models.insert(model_id.clone(), loaded);
|
||||||
|
tracing::info!(
|
||||||
|
rank = self.config.rank,
|
||||||
|
model = %model_id,
|
||||||
|
model_type = %model_type,
|
||||||
|
"loaded TP shard"
|
||||||
|
);
|
||||||
|
WorkerResponse::LoadDenseShardOk
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
fn handle_load_dense_shard(
|
||||||
|
&mut self,
|
||||||
|
_model_id: String,
|
||||||
|
_config_json: String,
|
||||||
|
_safetensors_paths: Vec<String>,
|
||||||
|
_quant: Option<String>,
|
||||||
|
) -> WorkerResponse {
|
||||||
|
WorkerResponse::Error {
|
||||||
|
kind: "cuda_feature_not_enabled".into(),
|
||||||
|
message: "LoadDenseShard requires --features cuda".into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
fn handle_generate_step(
|
||||||
|
&mut self,
|
||||||
|
model_id: &str,
|
||||||
|
tokens: Vec<u32>,
|
||||||
|
offset: usize,
|
||||||
|
) -> WorkerResponse {
|
||||||
|
use candle_core::Tensor;
|
||||||
|
|
||||||
|
let Some(model) = self.models.get_mut(model_id) else {
|
||||||
|
return WorkerResponse::Error {
|
||||||
|
kind: "model_not_loaded".into(),
|
||||||
|
message: format!("model '{model_id}' not loaded on rank {}", self.config.rank),
|
||||||
|
};
|
||||||
|
};
|
||||||
|
let device = model.device().clone();
|
||||||
|
let input = match Tensor::new(tokens.as_slice(), &device).and_then(|t| t.unsqueeze(0)) {
|
||||||
|
Ok(t) => t,
|
||||||
|
Err(e) => {
|
||||||
|
return WorkerResponse::Error {
|
||||||
|
kind: "forward_failed".into(),
|
||||||
|
message: format!("build input tensor: {e}"),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
tracing::debug!(
|
||||||
|
rank = self.config.rank,
|
||||||
|
model = %model_id,
|
||||||
|
tokens = tokens.len(),
|
||||||
|
offset,
|
||||||
|
"worker GenerateStep: forward starting"
|
||||||
|
);
|
||||||
|
// Drop the resulting logits — the leader uses its own copy from
|
||||||
|
// rank 0. The forward's value here is the NCCL collectives it
|
||||||
|
// issues, which let the leader's rank-0 forward make progress.
|
||||||
|
if let Err(e) = model.forward(&input, offset) {
|
||||||
|
tracing::warn!(
|
||||||
|
rank = self.config.rank,
|
||||||
|
model = %model_id,
|
||||||
|
elapsed_ms = start.elapsed().as_millis(),
|
||||||
|
error = %e,
|
||||||
|
"worker GenerateStep: forward failed"
|
||||||
|
);
|
||||||
|
return WorkerResponse::Error {
|
||||||
|
kind: "forward_failed".into(),
|
||||||
|
message: format!("TP forward: {e}"),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
tracing::debug!(
|
||||||
|
rank = self.config.rank,
|
||||||
|
model = %model_id,
|
||||||
|
elapsed_ms = start.elapsed().as_millis(),
|
||||||
|
"worker GenerateStep: forward done"
|
||||||
|
);
|
||||||
|
WorkerResponse::GenerateStepOk
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
fn handle_generate_step(
|
||||||
|
&mut self,
|
||||||
|
_model_id: &str,
|
||||||
|
_tokens: Vec<u32>,
|
||||||
|
_offset: usize,
|
||||||
|
) -> WorkerResponse {
|
||||||
|
WorkerResponse::Error {
|
||||||
|
kind: "cuda_feature_not_enabled".into(),
|
||||||
|
message: "GenerateStep requires --features cuda".into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
fn handle_clear_kv_cache(&mut self, model_id: &str) -> WorkerResponse {
|
||||||
|
let Some(model) = self.models.get_mut(model_id) else {
|
||||||
|
return WorkerResponse::Error {
|
||||||
|
kind: "model_not_loaded".into(),
|
||||||
|
message: format!("model '{model_id}' not loaded on rank {}", self.config.rank),
|
||||||
|
};
|
||||||
|
};
|
||||||
|
model.clear_kv_cache();
|
||||||
|
WorkerResponse::KvCacheCleared
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
fn handle_clear_kv_cache(&mut self, _model_id: &str) -> WorkerResponse {
|
||||||
|
WorkerResponse::Error {
|
||||||
|
kind: "cuda_feature_not_enabled".into(),
|
||||||
|
message: "ClearKvCache requires --features cuda".into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
fn handle_unload_model(&mut self, model_id: &str) -> WorkerResponse {
|
||||||
|
if self.models.remove(model_id).is_none() {
|
||||||
|
return WorkerResponse::Error {
|
||||||
|
kind: "model_not_loaded".into(),
|
||||||
|
message: format!("model '{model_id}' not loaded on rank {}", self.config.rank),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
tracing::info!(rank = self.config.rank, model = %model_id, "unloaded TP shard");
|
||||||
|
WorkerResponse::Unloaded
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
fn handle_unload_model(&mut self, _model_id: &str) -> WorkerResponse {
|
||||||
|
WorkerResponse::Error {
|
||||||
|
kind: "cuda_feature_not_enabled".into(),
|
||||||
|
message: "UnloadModel requires --features cuda".into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse a `ModelSpec.quant` string into a `GgmlDType`. Accepts the
|
||||||
|
/// common ggml format names (case-insensitive). `None` and `Some("")`
|
||||||
|
/// both map to "no quantization".
|
||||||
|
///
|
||||||
|
/// Supported: `q4_0`, `q4_1`, `q5_0`, `q5_1`, `q8_0`, `q8_1`,
|
||||||
|
/// `q2k`/`q2_k`, `q3k`/`q3_k`, `q4k`/`q4_k`, `q5k`/`q5_k`,
|
||||||
|
/// `q6k`/`q6_k`, `q8k`/`q8_k`, `f16`, `bf16`, `f32`. The underscore
|
||||||
|
/// is optional and the prefix is case-insensitive.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
pub(crate) fn parse_quant_string(
|
||||||
|
s: Option<&str>,
|
||||||
|
) -> anyhow::Result<Option<candle_core::quantized::GgmlDType>> {
|
||||||
|
use candle_core::quantized::GgmlDType;
|
||||||
|
let s = match s {
|
||||||
|
Some(s) if !s.is_empty() => s,
|
||||||
|
_ => return Ok(None),
|
||||||
|
};
|
||||||
|
let normalised = s.to_ascii_lowercase().replace('_', "");
|
||||||
|
let dtype = match normalised.as_str() {
|
||||||
|
"q40" => GgmlDType::Q4_0,
|
||||||
|
"q41" => GgmlDType::Q4_1,
|
||||||
|
"q50" => GgmlDType::Q5_0,
|
||||||
|
"q51" => GgmlDType::Q5_1,
|
||||||
|
"q80" => GgmlDType::Q8_0,
|
||||||
|
"q81" => GgmlDType::Q8_1,
|
||||||
|
"q2k" => GgmlDType::Q2K,
|
||||||
|
"q3k" => GgmlDType::Q3K,
|
||||||
|
"q4k" | "q4km" => GgmlDType::Q4K,
|
||||||
|
"q5k" | "q5km" => GgmlDType::Q5K,
|
||||||
|
"q6k" => GgmlDType::Q6K,
|
||||||
|
"q8k" => GgmlDType::Q8K,
|
||||||
|
"f16" => GgmlDType::F16,
|
||||||
|
"bf16" => GgmlDType::BF16,
|
||||||
|
"f32" => GgmlDType::F32,
|
||||||
|
other => anyhow::bail!(
|
||||||
|
"unknown quant '{other}' (expected one of: q4_0, q4_1, q5_0, q5_1, q8_0, \
|
||||||
|
q8_1, q2k, q3k, q4k, q5k, q6k, q8k, f16, bf16, f32)"
|
||||||
|
),
|
||||||
|
};
|
||||||
|
Ok(Some(dtype))
|
||||||
|
}
|
||||||
76
crates/neuron/src/health.rs
Normal file
76
crates/neuron/src/health.rs
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
//! Cached GPU health monitoring via periodic nvidia-smi polling.
|
||||||
|
|
||||||
|
use cortex_core::discovery::HealthResponse;
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
|
||||||
|
const POLL_INTERVAL: Duration = Duration::from_secs(5);
|
||||||
|
|
||||||
|
/// Thread-safe cache for the latest GPU health reading.
|
||||||
|
pub struct HealthCache {
|
||||||
|
inner: RwLock<HealthResponse>,
|
||||||
|
has_gpus: RwLock<bool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for HealthCache {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HealthCache {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
inner: RwLock::new(HealthResponse {
|
||||||
|
uptime_secs: 0,
|
||||||
|
devices: vec![],
|
||||||
|
// The cache only owns the device-state half of /health;
|
||||||
|
// the api handler overlays activation from the tracker.
|
||||||
|
// Initialise with the default (Ready, empty lists) so a
|
||||||
|
// direct read from the cache stays a well-typed
|
||||||
|
// HealthResponse on the wire.
|
||||||
|
activation: Default::default(),
|
||||||
|
}),
|
||||||
|
has_gpus: RwLock::new(false),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Mark whether this node has GPUs (set after discovery).
|
||||||
|
pub async fn set_has_gpus(&self, has_gpus: bool) {
|
||||||
|
*self.has_gpus.write().await = has_gpus;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get a snapshot of the current health state.
|
||||||
|
pub async fn snapshot(&self) -> HealthResponse {
|
||||||
|
self.inner.read().await.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run forever, polling nvidia-smi every 5 seconds and updating the cache.
|
||||||
|
pub async fn poll_loop(&self, start_time: Instant) {
|
||||||
|
loop {
|
||||||
|
tokio::time::sleep(POLL_INTERVAL).await;
|
||||||
|
|
||||||
|
let uptime = start_time.elapsed().as_secs();
|
||||||
|
|
||||||
|
if !*self.has_gpus.read().await {
|
||||||
|
let mut health = self.inner.write().await;
|
||||||
|
health.uptime_secs = uptime;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
match crate::discovery::query_health().await {
|
||||||
|
Ok(devices) => {
|
||||||
|
let mut health = self.inner.write().await;
|
||||||
|
health.uptime_secs = uptime;
|
||||||
|
health.devices = devices;
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(error = %e, "failed to poll GPU health");
|
||||||
|
// Keep last known reading, just update uptime.
|
||||||
|
let mut health = self.inner.write().await;
|
||||||
|
health.uptime_secs = uptime;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
9
crates/neuron/src/lib.rs
Normal file
9
crates/neuron/src/lib.rs
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
pub mod activation;
|
||||||
|
pub mod api;
|
||||||
|
pub mod config;
|
||||||
|
pub mod cuda;
|
||||||
|
pub mod discovery;
|
||||||
|
pub mod harness;
|
||||||
|
pub mod health;
|
||||||
|
pub mod startup;
|
||||||
|
pub mod wire;
|
||||||
251
crates/neuron/src/main.rs
Normal file
251
crates/neuron/src/main.rs
Normal file
@@ -0,0 +1,251 @@
|
|||||||
|
use anyhow::{Context, Result};
|
||||||
|
use clap::Parser;
|
||||||
|
use neuron::{
|
||||||
|
activation, api,
|
||||||
|
config::NeuronConfig,
|
||||||
|
discovery,
|
||||||
|
harness::{HarnessRegistry, tp},
|
||||||
|
health, startup,
|
||||||
|
};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Instant;
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
use tracing_subscriber::EnvFilter;
|
||||||
|
|
||||||
|
/// Top-level CLI. The same binary runs as either the public neuron
|
||||||
|
/// daemon (default), a tensor-parallel worker subprocess (when
|
||||||
|
/// `--worker` is set, spawned by the leader on the same host), or a
|
||||||
|
/// one-shot TP NCCL handshake check (when `--tp-smoke` is set).
|
||||||
|
#[derive(Parser)]
|
||||||
|
#[command(name = "neuron")]
|
||||||
|
#[command(about = "Per-node daemon for cortex inference clusters")]
|
||||||
|
#[command(version)]
|
||||||
|
struct Args {
|
||||||
|
/// Run in tensor-parallel worker mode. The leader process spawns
|
||||||
|
/// one of these per non-zero NCCL rank and drives it over
|
||||||
|
/// newline-delimited JSON on stdin/stdout. Worker mode skips
|
||||||
|
/// discovery, the HTTP listener, and the health poller — it's a
|
||||||
|
/// pure RPC loop.
|
||||||
|
#[arg(long, default_value_t = false)]
|
||||||
|
worker: bool,
|
||||||
|
|
||||||
|
/// Run a one-shot TP smoke test: spawn `--tp-size - 1` worker
|
||||||
|
/// subprocesses on `--cuda-devices`, build the NCCL communicator,
|
||||||
|
/// run an `AllReduce` sanity check across every rank, and exit.
|
||||||
|
/// Used to validate the TP plumbing in isolation from model load
|
||||||
|
/// and inference. Diagnostic-only — not exposed through the daemon
|
||||||
|
/// HTTP API.
|
||||||
|
#[arg(long, default_value_t = false)]
|
||||||
|
tp_smoke: bool,
|
||||||
|
|
||||||
|
/// NCCL rank for worker mode. Ignored when `--worker` is not set.
|
||||||
|
#[arg(long, default_value_t = 0)]
|
||||||
|
rank: u32,
|
||||||
|
|
||||||
|
/// Total NCCL world size for worker mode or TP smoke mode.
|
||||||
|
#[arg(long, default_value_t = 1)]
|
||||||
|
tp_size: u32,
|
||||||
|
|
||||||
|
/// CUDA device index for worker mode. Ignored when `--worker` is
|
||||||
|
/// not set.
|
||||||
|
#[arg(long, default_value_t = 0)]
|
||||||
|
cuda_device: u32,
|
||||||
|
|
||||||
|
/// Comma-separated CUDA device indices for TP smoke mode (one per
|
||||||
|
/// rank, starting with rank 0). Must have `tp_size` entries.
|
||||||
|
#[arg(long, value_delimiter = ',')]
|
||||||
|
cuda_devices: Vec<u32>,
|
||||||
|
|
||||||
|
/// Port to listen on (overrides config file). Daemon mode only.
|
||||||
|
#[arg(short, long)]
|
||||||
|
port: Option<u16>,
|
||||||
|
|
||||||
|
/// Path to the neuron config file. Daemon mode only.
|
||||||
|
#[arg(short, long, default_value = "neuron.toml")]
|
||||||
|
config: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<()> {
|
||||||
|
tracing_subscriber::fmt()
|
||||||
|
.with_writer(std::io::stderr)
|
||||||
|
.with_env_filter(
|
||||||
|
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")),
|
||||||
|
)
|
||||||
|
.init();
|
||||||
|
|
||||||
|
let args = Args::parse();
|
||||||
|
|
||||||
|
if args.worker {
|
||||||
|
return tp::worker::run(tp::worker::WorkerConfig {
|
||||||
|
rank: args.rank,
|
||||||
|
world_size: args.tp_size,
|
||||||
|
cuda_device: args.cuda_device,
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
if args.tp_smoke {
|
||||||
|
return tp_smoke(args.tp_size, args.cuda_devices).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
daemon(args).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// One-shot tensor-parallel handshake. Spawns N-1 worker subprocesses
|
||||||
|
/// (rank 0 stays in this process), builds the NCCL communicator across
|
||||||
|
/// the full world, runs an AllReduce sanity check, and shuts everyone
|
||||||
|
/// down. Output is plain log lines on stderr + a final summary on
|
||||||
|
/// stdout in `key=value` form so an outer script can parse it.
|
||||||
|
async fn tp_smoke(tp_size: u32, cuda_devices: Vec<u32>) -> Result<()> {
|
||||||
|
if tp_size < 2 {
|
||||||
|
anyhow::bail!("--tp-size must be at least 2 (got {tp_size})");
|
||||||
|
}
|
||||||
|
if cuda_devices.len() as u32 != tp_size {
|
||||||
|
anyhow::bail!(
|
||||||
|
"--cuda-devices must list exactly {tp_size} entries (got {})",
|
||||||
|
cuda_devices.len()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let exe = std::env::current_exe().context("resolve current_exe for worker spawn")?;
|
||||||
|
let leader_device = cuda_devices[0];
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
tp_size,
|
||||||
|
?cuda_devices,
|
||||||
|
binary = %exe.display(),
|
||||||
|
"tp-smoke: spawning worker pool"
|
||||||
|
);
|
||||||
|
// tp_smoke is a diagnostic tool; spawn the leader's device worker
|
||||||
|
// directly. (In the daemon path, CandleHarness::ensure_device_worker
|
||||||
|
// caches one per device.)
|
||||||
|
let leader_worker = neuron::harness::device_worker::DeviceWorkerHandle::spawn(leader_device)
|
||||||
|
.context("spawn leader device worker for tp-smoke")?;
|
||||||
|
let mut pool =
|
||||||
|
tp::WorkerPool::spawn(&exe, tp_size, &cuda_devices, leader_worker.clone()).await?;
|
||||||
|
|
||||||
|
tracing::info!("tp-smoke: pinging every worker");
|
||||||
|
let pongs = pool.ping_all().await?;
|
||||||
|
for p in &pongs {
|
||||||
|
tracing::info!(?p, "tp-smoke: pong");
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!(leader_device, "tp-smoke: initialising NCCL");
|
||||||
|
pool.init_nccl(leader_device).await?;
|
||||||
|
|
||||||
|
tracing::info!("tp-smoke: running AllReduce sanity check");
|
||||||
|
pool.nccl_sanity_check().await?;
|
||||||
|
|
||||||
|
tracing::info!("tp-smoke: shutting down pool");
|
||||||
|
pool.shutdown().await?;
|
||||||
|
|
||||||
|
println!("status=ok");
|
||||||
|
println!("tp_size={tp_size}");
|
||||||
|
println!(
|
||||||
|
"cuda_devices={}",
|
||||||
|
cuda_devices
|
||||||
|
.iter()
|
||||||
|
.map(|d| d.to_string())
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(",")
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn daemon(args: Args) -> Result<()> {
|
||||||
|
let cfg = NeuronConfig::load(&args.config).unwrap_or_else(|e| {
|
||||||
|
tracing::warn!(path = %args.config, error = %e, "config not found, using defaults");
|
||||||
|
NeuronConfig::default()
|
||||||
|
});
|
||||||
|
|
||||||
|
let port = args.port.unwrap_or(cfg.port);
|
||||||
|
let bind_url = format!("http://localhost:{port}");
|
||||||
|
let start_time = Instant::now();
|
||||||
|
|
||||||
|
tracing::info!("running hardware discovery");
|
||||||
|
let mut discovery_result = discovery::discover_system().await?;
|
||||||
|
tracing::info!(
|
||||||
|
hostname = %discovery_result.hostname,
|
||||||
|
devices = discovery_result.devices.len(),
|
||||||
|
"discovery complete"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Build harness registry from config. In-process harnesses (candle)
|
||||||
|
// need to know neuron's own bind URL so they can return it from
|
||||||
|
// inference_endpoint.
|
||||||
|
let registry = HarnessRegistry::from_configs(&cfg.harnesses, &bind_url, &cfg.harness);
|
||||||
|
discovery_result.harnesses = registry.names();
|
||||||
|
let candle = registry.candle();
|
||||||
|
|
||||||
|
let health_cache = Arc::new(health::HealthCache::new());
|
||||||
|
health_cache
|
||||||
|
.set_has_gpus(!discovery_result.devices.is_empty())
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let poller_cache = Arc::clone(&health_cache);
|
||||||
|
tokio::spawn(async move {
|
||||||
|
poller_cache.poll_loop(start_time).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
// Track pre-warm progress so `/health` can tell callers whether
|
||||||
|
// configured default_models are still loading. Primed with the
|
||||||
|
// pending list now; the spawned task below flips entries through
|
||||||
|
// in_progress → completed/failed and finally toggles state=ready.
|
||||||
|
let activation = Arc::new(activation::ActivationTracker::new(&cfg.default_models));
|
||||||
|
|
||||||
|
let state = Arc::new(api::NeuronState {
|
||||||
|
discovery: discovery_result,
|
||||||
|
health_cache,
|
||||||
|
registry: RwLock::new(registry),
|
||||||
|
candle,
|
||||||
|
activation: Arc::clone(&activation),
|
||||||
|
});
|
||||||
|
|
||||||
|
// Bind the HTTP listener BEFORE kicking off default_models loading.
|
||||||
|
// Previously load_default_models ran synchronously on this task,
|
||||||
|
// which delayed the bind by minutes for big TP models and made the
|
||||||
|
// host look down to anything probing `/health` during pre-warm.
|
||||||
|
// The pre-warm task runs in the background instead — `/health`
|
||||||
|
// surfaces its progress via the activation field.
|
||||||
|
let app = api::neuron_routes().with_state(Arc::clone(&state));
|
||||||
|
let addr: std::net::SocketAddr = format!("0.0.0.0:{port}").parse()?;
|
||||||
|
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||||
|
tracing::info!("neuron listening on {addr}");
|
||||||
|
|
||||||
|
if !cfg.default_models.is_empty() {
|
||||||
|
let state_for_prewarm = Arc::clone(&state);
|
||||||
|
let default_models = cfg.default_models.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
// Read lock held for the whole pre-warm run. The unload
|
||||||
|
// path takes the same read lock per call (no writers) and
|
||||||
|
// serialises through the candle harness's own internal
|
||||||
|
// mutex, so concurrent on-demand loads and pre-warm loads
|
||||||
|
// do not race on the same model.
|
||||||
|
let registry = state_for_prewarm.registry.read().await;
|
||||||
|
startup::load_default_models(®istry, &default_models, &state_for_prewarm.activation)
|
||||||
|
.await;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
axum::serve(listener, app)
|
||||||
|
.with_graceful_shutdown(startup::shutdown_signal())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// Deactivation: serve has returned (graceful shutdown signal
|
||||||
|
// received and connections drained). Release CUDA contexts / VRAM
|
||||||
|
// by unloading every model before exiting; systemd's TimeoutStopSec
|
||||||
|
// bounds how long this phase may take.
|
||||||
|
let registry = state.registry.read().await;
|
||||||
|
startup::unload_all_models(®istry).await;
|
||||||
|
tracing::info!("shutdown complete");
|
||||||
|
// Fast-exit instead of returning. Returning lets `#[tokio::main]`
|
||||||
|
// drop the runtime, which in turn waits on the blocking thread
|
||||||
|
// pool to drain. After a CUDA driver error (OOM → illegal address)
|
||||||
|
// a spawn_blocking thread can be wedged inside `cuCtxGetCurrent`,
|
||||||
|
// and tokio's drain has no timeout. systemd then SIGABRTs us and
|
||||||
|
// dumps core. Skipping the drain hands the OS a clean exit code;
|
||||||
|
// the OS reaps the stuck threads. See the 2026-05-26 incident
|
||||||
|
// captured under "Stack trace of thread 2951308" in the journal.
|
||||||
|
std::process::exit(0);
|
||||||
|
}
|
||||||
176
crates/neuron/src/startup.rs
Normal file
176
crates/neuron/src/startup.rs
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
//! Activation- and deactivation-time orchestration.
|
||||||
|
//!
|
||||||
|
//! Wired from `main.rs` around the HTTP listener — activation runs
|
||||||
|
//! before bind, deactivation runs after axum returns from its
|
||||||
|
//! graceful-shutdown future. Kept in its own module so the logic is
|
||||||
|
//! unit-testable without spinning up a full neuron process.
|
||||||
|
|
||||||
|
use crate::activation::ActivationTracker;
|
||||||
|
use crate::harness::HarnessRegistry;
|
||||||
|
use crate::harness::preflight::PreflightError;
|
||||||
|
use cortex_core::harness::ModelSpec;
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
use tokio::signal;
|
||||||
|
|
||||||
|
/// Maximum time we wait on a single `unload_model` call during
|
||||||
|
/// shutdown. The TP unload path tries `Arc::try_unwrap`, which fails
|
||||||
|
/// fast when an inference is in flight, so a healthy unload returns
|
||||||
|
/// in milliseconds. The timeout exists to bound a *future* unload
|
||||||
|
/// path that might genuinely block on a stuck worker, so a single
|
||||||
|
/// wedged model can't burn the whole systemd TimeoutStopSec window.
|
||||||
|
const UNLOAD_TIMEOUT: Duration = Duration::from_secs(20);
|
||||||
|
|
||||||
|
/// Load each spec sequentially against the registry, treating
|
||||||
|
/// individual failures as warnings rather than fatal errors.
|
||||||
|
///
|
||||||
|
/// VRAM contention makes parallel loads risky; the sequential path is
|
||||||
|
/// boring but correct. The function logs elapsed time per load and
|
||||||
|
/// updates `activation` so the `/health` endpoint can tell callers
|
||||||
|
/// which models are still pre-warming. Caller is expected to run this
|
||||||
|
/// in a background `tokio::spawn` task — the HTTP listener binds
|
||||||
|
/// independently so the host is reachable during the pre-warm window.
|
||||||
|
pub async fn load_default_models(
|
||||||
|
registry: &HarnessRegistry,
|
||||||
|
specs: &[ModelSpec],
|
||||||
|
activation: &ActivationTracker,
|
||||||
|
) {
|
||||||
|
if specs.is_empty() {
|
||||||
|
activation.mark_ready().await;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
tracing::info!(count = specs.len(), "loading default models");
|
||||||
|
for spec in specs {
|
||||||
|
let start = Instant::now();
|
||||||
|
activation.start_loading(&spec.model_id).await;
|
||||||
|
match registry.load_model(spec).await {
|
||||||
|
Ok(()) => {
|
||||||
|
activation.complete_loading(&spec.model_id).await;
|
||||||
|
tracing::info!(
|
||||||
|
model = %spec.model_id,
|
||||||
|
elapsed_ms = start.elapsed().as_millis() as u64,
|
||||||
|
"loaded default model"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
let rendered = format!("{e:#}");
|
||||||
|
activation.fail_loading(&spec.model_id, &rendered).await;
|
||||||
|
// When the underlying failure is a preflight rejection,
|
||||||
|
// pull the structured fields out so journalctl shows
|
||||||
|
// `reason=tp_requires_safetensors detail="..."` instead
|
||||||
|
// of an opaque "fetch config.json … 404". The operator
|
||||||
|
// can act on the structured form directly.
|
||||||
|
if let Some(pf) = e.downcast_ref::<PreflightError>() {
|
||||||
|
tracing::warn!(
|
||||||
|
model = %spec.model_id,
|
||||||
|
reason = preflight_kind(pf),
|
||||||
|
detail = %pf,
|
||||||
|
elapsed_ms = start.elapsed().as_millis() as u64,
|
||||||
|
"failed to load default model, continuing"
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
tracing::warn!(
|
||||||
|
model = %spec.model_id,
|
||||||
|
error = %rendered,
|
||||||
|
elapsed_ms = start.elapsed().as_millis() as u64,
|
||||||
|
"failed to load default model, continuing"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
activation.mark_ready().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Short kebab-case tag for a preflight failure. Used as a structured
|
||||||
|
/// log field so journalctl filtering can match on the failure class
|
||||||
|
/// (`reason=tp_requires_safetensors`, `reason=quant_not_found`, etc.).
|
||||||
|
fn preflight_kind(err: &PreflightError) -> &'static str {
|
||||||
|
match err {
|
||||||
|
PreflightError::RepoFetchFailed { .. } => "repo_fetch_failed",
|
||||||
|
PreflightError::EmptyRepo { .. } => "empty_repo",
|
||||||
|
PreflightError::TpRequiresSafetensors { .. } => "tp_requires_safetensors",
|
||||||
|
PreflightError::QuantNotFound { .. } => "quant_not_found",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Future that resolves on SIGINT (Ctrl-C) or SIGTERM (systemd stop).
|
||||||
|
///
|
||||||
|
/// Wired into `axum::serve(...).with_graceful_shutdown(shutdown_signal())`
|
||||||
|
/// so the HTTP listener stops accepting new connections, lets in-flight
|
||||||
|
/// requests drain, and then yields control back to main for cleanup.
|
||||||
|
pub async fn shutdown_signal() {
|
||||||
|
let ctrl_c = async {
|
||||||
|
signal::ctrl_c().await.ok();
|
||||||
|
};
|
||||||
|
let terminate = async {
|
||||||
|
signal::unix::signal(signal::unix::SignalKind::terminate())
|
||||||
|
.expect("install SIGTERM handler")
|
||||||
|
.recv()
|
||||||
|
.await;
|
||||||
|
};
|
||||||
|
tokio::select! {
|
||||||
|
_ = ctrl_c => tracing::info!("received SIGINT, shutting down"),
|
||||||
|
_ = terminate => tracing::info!("received SIGTERM, shutting down"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Unload every model currently registered. Called from `main.rs` after
|
||||||
|
/// axum's graceful shutdown future resolves, so CUDA contexts and VRAM
|
||||||
|
/// are released before the process exits rather than left to the OS to
|
||||||
|
/// reclaim. Per-model failures are logged and skipped — keep cleanup
|
||||||
|
/// going even when one harness is unhealthy.
|
||||||
|
pub async fn unload_all_models(registry: &HarnessRegistry) {
|
||||||
|
let listed = match registry.list_all_models().await {
|
||||||
|
Ok(m) => m,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(error = %e, "failed to list models during shutdown");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if listed.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!(count = listed.len(), "unloading models for shutdown");
|
||||||
|
let mut stuck = 0;
|
||||||
|
for model in listed {
|
||||||
|
let start = Instant::now();
|
||||||
|
match tokio::time::timeout(UNLOAD_TIMEOUT, registry.unload_model(&model.id)).await {
|
||||||
|
Ok(Ok(())) => tracing::info!(
|
||||||
|
model = %model.id,
|
||||||
|
elapsed_ms = start.elapsed().as_millis() as u64,
|
||||||
|
"unloaded"
|
||||||
|
),
|
||||||
|
// Most common shape today: TP unload bails because an
|
||||||
|
// inference is still mid-flight (the spawned task holds
|
||||||
|
// an `Arc<TpLoadedModel>` clone). Promoted from warn to
|
||||||
|
// error and tagged with the request-state so the operator
|
||||||
|
// can correlate with the chat_completion logs above.
|
||||||
|
Ok(Err(e)) => {
|
||||||
|
stuck += 1;
|
||||||
|
tracing::error!(
|
||||||
|
model = %model.id,
|
||||||
|
error = %e,
|
||||||
|
elapsed_ms = start.elapsed().as_millis() as u64,
|
||||||
|
"unload failed during shutdown"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
stuck += 1;
|
||||||
|
tracing::error!(
|
||||||
|
model = %model.id,
|
||||||
|
timeout_secs = UNLOAD_TIMEOUT.as_secs(),
|
||||||
|
"unload timed out during shutdown, continuing"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if stuck > 0 {
|
||||||
|
tracing::error!(
|
||||||
|
stuck,
|
||||||
|
"shutdown leaving {stuck} model(s) loaded; VRAM will be \
|
||||||
|
reclaimed by the OS on process exit"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
306
crates/neuron/src/wire/event.rs
Normal file
306
crates/neuron/src/wire/event.rs
Normal file
@@ -0,0 +1,306 @@
|
|||||||
|
//! Format-agnostic inference event stream.
|
||||||
|
//!
|
||||||
|
//! The candle harness emits a sequence of these for every streaming
|
||||||
|
//! request. Wire-format projections in sibling modules
|
||||||
|
//! ([`super::openai_chat`], the eventual `openai_responses` /
|
||||||
|
//! `anthropic_messages` projections) read this stream and produce
|
||||||
|
//! the chunks / events their HTTP clients expect.
|
||||||
|
//!
|
||||||
|
//! Design notes:
|
||||||
|
//!
|
||||||
|
//! - [`Start`] carries no token of its own. It only signals "the
|
||||||
|
//! model has accepted the prompt and is about to begin emitting
|
||||||
|
//! text". OpenAI chat materialises this as a `role: assistant`
|
||||||
|
//! chunk; OpenAI Responses as the `response.created` +
|
||||||
|
//! `response.output_item.added` pair; Anthropic as
|
||||||
|
//! `message_start`. All three of those would otherwise have to
|
||||||
|
//! peek at the *first* token to know when to emit, which couples
|
||||||
|
//! the wire layer to the producer's pacing.
|
||||||
|
//! - [`TextDelta`] is *visible* output. Reasoning / `<think>`
|
||||||
|
//! blocks go through a future [`ReasoningDelta`] variant once
|
||||||
|
//! the harness learns to split them (today they pass through as
|
||||||
|
//! plain text inside `TextDelta`; helexa-acp picks them apart on
|
||||||
|
//! the consumer side).
|
||||||
|
//! - [`Finish`] is the only place a stream is allowed to end
|
||||||
|
//! cleanly. Projections rely on this to emit final usage
|
||||||
|
//! bookkeeping; absence means the producer crashed and the
|
||||||
|
//! consumer should treat the stream as truncated.
|
||||||
|
//!
|
||||||
|
//! [`Start`]: InferenceEvent::Start
|
||||||
|
//! [`TextDelta`]: InferenceEvent::TextDelta
|
||||||
|
//! [`Finish`]: InferenceEvent::Finish
|
||||||
|
|
||||||
|
/// One unit of output from the inference loop.
|
||||||
|
///
|
||||||
|
/// Producers send these on an `mpsc::Sender<InferenceEvent>`;
|
||||||
|
/// projection layers in sibling modules consume them and emit
|
||||||
|
/// wire-format-specific frames downstream.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum InferenceEvent {
|
||||||
|
/// The producer has accepted the prompt and is about to emit
|
||||||
|
/// the first token. Sent at most once per stream.
|
||||||
|
Start,
|
||||||
|
/// A piece of visible assistant text. Multiple deltas
|
||||||
|
/// concatenate into the complete reply.
|
||||||
|
TextDelta(String),
|
||||||
|
/// Reasoning / scratchpad text the model emitted inside a
|
||||||
|
/// `<think>` block (or equivalent). The harness routes
|
||||||
|
/// content between marker tokens here so wire projectors can
|
||||||
|
/// decide what to do with it (chat completions drops by
|
||||||
|
/// default; Responses API has a dedicated event family).
|
||||||
|
ReasoningDelta(String),
|
||||||
|
/// A tool call has been parsed out of a `<tool_call>{json}</tool_call>`
|
||||||
|
/// block. Carries the parsed name + arguments JSON string
|
||||||
|
/// (Anthropic / OpenAI projectors emit their own wire shape
|
||||||
|
/// from this).
|
||||||
|
///
|
||||||
|
/// `index` is the call slot — incremented per tool call in a
|
||||||
|
/// turn so wire formats that order calls by index
|
||||||
|
/// (OpenAI chat completions) can correlate.
|
||||||
|
ToolCall {
|
||||||
|
index: usize,
|
||||||
|
id: String,
|
||||||
|
name: String,
|
||||||
|
/// Complete JSON arguments string. The model could in
|
||||||
|
/// principle stream these token-by-token, but our
|
||||||
|
/// extraction buffers the whole block until `</tool_call>`
|
||||||
|
/// arrives and emits exactly one event per call.
|
||||||
|
arguments: String,
|
||||||
|
},
|
||||||
|
/// The stream is complete. Carries the reason so wire formats
|
||||||
|
/// that use it (OpenAI's `finish_reason`, Anthropic's
|
||||||
|
/// `stop_reason`) can render it without re-parsing.
|
||||||
|
Finish { reason: FinishReason },
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Why a stream stopped. Stays small on purpose — anything that
|
||||||
|
/// doesn't map cleanly to one of these collapses to [`Stop`].
|
||||||
|
///
|
||||||
|
/// Mappings to wire formats:
|
||||||
|
///
|
||||||
|
/// | variant | OpenAI `finish_reason` | OpenAI Responses `status` | Anthropic `stop_reason` |
|
||||||
|
/// |---------|------------------------|---------------------------|-------------------------|
|
||||||
|
/// | `Stop` | `"stop"` | `"completed"` | `"end_turn"` |
|
||||||
|
/// | `Length`| `"length"` | `"incomplete"` | `"max_tokens"` |
|
||||||
|
/// | `ToolCalls` | `"tool_calls"` | `"completed"` | `"tool_use"` |
|
||||||
|
///
|
||||||
|
/// [`Stop`]: FinishReason::Stop
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum FinishReason {
|
||||||
|
/// Model emitted EOS naturally.
|
||||||
|
Stop,
|
||||||
|
/// Hit `max_tokens` before EOS.
|
||||||
|
Length,
|
||||||
|
/// Stopped because the model called a tool and is waiting for
|
||||||
|
/// the result. Not yet emitted by the candle harness —
|
||||||
|
/// reserved for the day tool-call extraction lands.
|
||||||
|
#[allow(dead_code)]
|
||||||
|
ToolCalls,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FinishReason {
|
||||||
|
/// String form used by OpenAI chat completions and OpenAI
|
||||||
|
/// completions. Wire modules can call this directly or do their
|
||||||
|
/// own mapping for non-string formats.
|
||||||
|
pub fn as_openai_str(self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
FinishReason::Stop => "stop",
|
||||||
|
FinishReason::Length => "length",
|
||||||
|
FinishReason::ToolCalls => "tool_calls",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Open/close token IDs for the reasoning marker a loaded model uses
|
||||||
|
/// (or `None` for non-reasoning models). The harness reads this once
|
||||||
|
/// at load time from the tokenizer's added-tokens table, then the
|
||||||
|
/// inference loop checks `next_token` against the pair to flip
|
||||||
|
/// between [`InferenceEvent::TextDelta`] and
|
||||||
|
/// [`InferenceEvent::ReasoningDelta`].
|
||||||
|
///
|
||||||
|
/// `open` and `close` text are kept alongside the IDs so wire
|
||||||
|
/// projectors that want to re-emit the literal markers (the
|
||||||
|
/// opt-in `include_thinking` path on chat completions) don't have
|
||||||
|
/// to reach back into the tokenizer for the strings.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ReasoningTokenPair {
|
||||||
|
pub open_id: u32,
|
||||||
|
pub close_id: u32,
|
||||||
|
pub open_text: String,
|
||||||
|
pub close_text: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Known reasoning-marker conventions. Each is a `(open, close)`
|
||||||
|
/// pair of literal token strings. Each modern reasoning model
|
||||||
|
/// declares its markers in the tokenizer's `added_tokens` table;
|
||||||
|
/// at load time we probe for whichever pair the loaded tokenizer
|
||||||
|
/// has and stash both IDs.
|
||||||
|
///
|
||||||
|
/// Ordering matters only for tie-breaking when a model declares
|
||||||
|
/// multiple pairs (shouldn't happen in practice); the first hit
|
||||||
|
/// wins.
|
||||||
|
const KNOWN_REASONING_MARKERS: &[(&str, &str)] = &[
|
||||||
|
// Qwen3, DeepSeek-R1, gpt-oss, and most other open-weight
|
||||||
|
// reasoning models.
|
||||||
|
("<think>", "</think>"),
|
||||||
|
// Mistral Magistral.
|
||||||
|
("[THINK]", "[/THINK]"),
|
||||||
|
// Some older derivatives; harmless to probe.
|
||||||
|
("<thought>", "</thought>"),
|
||||||
|
("<reasoning>", "</reasoning>"),
|
||||||
|
];
|
||||||
|
|
||||||
|
/// Open/close token IDs for the model's tool-call marker
|
||||||
|
/// convention (or `None` for models that don't emit structured
|
||||||
|
/// tool calls). Same shape as [`ReasoningTokenPair`]: probed once
|
||||||
|
/// at load time, consumed by the inference loop to switch between
|
||||||
|
/// "emit visible deltas" and "buffer JSON for the next tool
|
||||||
|
/// call".
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ToolCallTokenPair {
|
||||||
|
pub open_id: u32,
|
||||||
|
pub close_id: u32,
|
||||||
|
pub open_text: String,
|
||||||
|
pub close_text: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Tool-call marker conventions. Open-weight tool-use models
|
||||||
|
/// converged on `<tool_call>` / `</tool_call>` (Qwen3-Coder /
|
||||||
|
/// -Instruct, the Hermes function-call format, DeepSeek-Coder,
|
||||||
|
/// gpt-oss). The pair lives alongside the reasoning markers in
|
||||||
|
/// the same `added_tokens` table.
|
||||||
|
const KNOWN_TOOL_CALL_MARKERS: &[(&str, &str)] = &[("<tool_call>", "</tool_call>")];
|
||||||
|
|
||||||
|
/// Probe a tokenizer for known tool-call marker pairs. Mirrors
|
||||||
|
/// [`detect_reasoning_token_pair`] — both open AND close must
|
||||||
|
/// resolve for the pair to be returned. `None` means the model
|
||||||
|
/// doesn't emit structured tool calls (or its tokenizer split
|
||||||
|
/// the markers across tokens).
|
||||||
|
pub fn detect_tool_call_token_pair<F>(token_to_id: F) -> Option<ToolCallTokenPair>
|
||||||
|
where
|
||||||
|
F: Fn(&str) -> Option<u32>,
|
||||||
|
{
|
||||||
|
for (open_text, close_text) in KNOWN_TOOL_CALL_MARKERS {
|
||||||
|
let open_id = token_to_id(open_text);
|
||||||
|
let close_id = token_to_id(close_text);
|
||||||
|
if let (Some(open_id), Some(close_id)) = (open_id, close_id) {
|
||||||
|
return Some(ToolCallTokenPair {
|
||||||
|
open_id,
|
||||||
|
close_id,
|
||||||
|
open_text: (*open_text).into(),
|
||||||
|
close_text: (*close_text).into(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Inspect a tokenizer for known reasoning-marker pairs and return
|
||||||
|
/// the first match. The tokenizer types this trait is defined over
|
||||||
|
/// just need to expose `token_to_id(&str) -> Option<u32>` so this
|
||||||
|
/// stays decoupled from the candle crate — the production caller
|
||||||
|
/// passes a `tokenizers::Tokenizer`, but tests can fake one.
|
||||||
|
///
|
||||||
|
/// Returns `None` when no known marker pair is fully declared
|
||||||
|
/// (both open AND close token ids must resolve). That's the
|
||||||
|
/// pass-through case — non-reasoning models, or reasoning models
|
||||||
|
/// whose tokenizer split the markers across multiple tokens (rare
|
||||||
|
/// in practice; modern reasoning tokenizers list them as
|
||||||
|
/// `added_tokens`).
|
||||||
|
pub fn detect_reasoning_token_pair<F>(token_to_id: F) -> Option<ReasoningTokenPair>
|
||||||
|
where
|
||||||
|
F: Fn(&str) -> Option<u32>,
|
||||||
|
{
|
||||||
|
for (open_text, close_text) in KNOWN_REASONING_MARKERS {
|
||||||
|
let open_id = token_to_id(open_text);
|
||||||
|
let close_id = token_to_id(close_text);
|
||||||
|
if let (Some(open_id), Some(close_id)) = (open_id, close_id) {
|
||||||
|
return Some(ReasoningTokenPair {
|
||||||
|
open_id,
|
||||||
|
close_id,
|
||||||
|
open_text: (*open_text).into(),
|
||||||
|
close_text: (*close_text).into(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
fn lookup<'a>(map: &'a HashMap<&'static str, u32>) -> impl Fn(&str) -> Option<u32> + 'a {
|
||||||
|
|s| map.get(s).copied()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn detects_qwen3_style_think_markers() {
|
||||||
|
let mut m = HashMap::new();
|
||||||
|
m.insert("<think>", 151648);
|
||||||
|
m.insert("</think>", 151649);
|
||||||
|
let pair = detect_reasoning_token_pair(lookup(&m)).expect("pair detected");
|
||||||
|
assert_eq!(pair.open_id, 151648);
|
||||||
|
assert_eq!(pair.close_id, 151649);
|
||||||
|
assert_eq!(pair.open_text, "<think>");
|
||||||
|
assert_eq!(pair.close_text, "</think>");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn detects_mistral_magistral_markers() {
|
||||||
|
let mut m = HashMap::new();
|
||||||
|
m.insert("[THINK]", 100);
|
||||||
|
m.insert("[/THINK]", 101);
|
||||||
|
let pair = detect_reasoning_token_pair(lookup(&m)).expect("pair detected");
|
||||||
|
assert_eq!(pair.open_text, "[THINK]");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn returns_none_when_only_open_marker_present() {
|
||||||
|
// A pathological tokenizer that has `<think>` but not
|
||||||
|
// `</think>` shouldn't half-detect. Pass-through.
|
||||||
|
let mut m = HashMap::new();
|
||||||
|
m.insert("<think>", 1);
|
||||||
|
assert!(detect_reasoning_token_pair(lookup(&m)).is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn returns_none_for_non_reasoning_tokenizer() {
|
||||||
|
let m: HashMap<&'static str, u32> = HashMap::new();
|
||||||
|
assert!(detect_reasoning_token_pair(lookup(&m)).is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn detects_tool_call_markers() {
|
||||||
|
let mut m = HashMap::new();
|
||||||
|
m.insert("<tool_call>", 151657);
|
||||||
|
m.insert("</tool_call>", 151658);
|
||||||
|
let pair = detect_tool_call_token_pair(lookup(&m)).expect("pair detected");
|
||||||
|
assert_eq!(pair.open_id, 151657);
|
||||||
|
assert_eq!(pair.close_id, 151658);
|
||||||
|
assert_eq!(pair.open_text, "<tool_call>");
|
||||||
|
assert_eq!(pair.close_text, "</tool_call>");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn returns_none_for_non_tool_use_tokenizer() {
|
||||||
|
let m: HashMap<&'static str, u32> = HashMap::new();
|
||||||
|
assert!(detect_tool_call_token_pair(lookup(&m)).is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn first_match_wins_when_multiple_pairs_declared() {
|
||||||
|
// Hypothetical tokenizer with both Qwen-style AND Mistral-style
|
||||||
|
// markers — the `<think>` pair is earlier in the convention
|
||||||
|
// table so it wins.
|
||||||
|
let mut m = HashMap::new();
|
||||||
|
m.insert("<think>", 1);
|
||||||
|
m.insert("</think>", 2);
|
||||||
|
m.insert("[THINK]", 3);
|
||||||
|
m.insert("[/THINK]", 4);
|
||||||
|
let pair = detect_reasoning_token_pair(lookup(&m)).unwrap();
|
||||||
|
assert_eq!(pair.open_id, 1);
|
||||||
|
assert_eq!(pair.close_id, 2);
|
||||||
|
}
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user