Compare commits
144 Commits
v0.1.16
...
phase-2-pr
| Author | SHA1 | Date | |
|---|---|---|---|
|
61adff347a
|
|||
|
0af8c8d6e7
|
|||
|
435fd10902
|
|||
|
cb303832bc
|
|||
|
44008358c5
|
|||
|
2f387f33f8
|
|||
|
fc9a8c42a3
|
|||
|
7733eecba5
|
|||
|
fdc0adb738
|
|||
|
8fa1d1962e
|
|||
|
cad7552104
|
|||
|
1818dfb337
|
|||
|
5ed1140c97
|
|||
|
957f704efa
|
|||
|
1859777332
|
|||
|
6927286cab
|
|||
|
302ccfb982
|
|||
|
df0abfe4d4
|
|||
|
b9016571f6
|
|||
|
adbc52bfcd
|
|||
|
537a0fe7f2
|
|||
|
cbadfcf112
|
|||
|
3ecbb21ece
|
|||
|
0d841a4981
|
|||
|
0bbb9b752d
|
|||
|
5aac1ffc59
|
|||
|
ec2b6450b2
|
|||
|
a494c8d43c
|
|||
|
abbedf8d8a
|
|||
|
6cc14e925c
|
|||
|
1c16732668
|
|||
|
5a0861d639
|
|||
|
33652ac651
|
|||
|
c297a54074
|
|||
|
0121a1930f
|
|||
|
13f4c36aeb
|
|||
|
4a51a54554
|
|||
|
0609f1ac5d
|
|||
|
96fc379893
|
|||
|
e267f583e1
|
|||
|
e23d5011d0
|
|||
|
249b2e5c98
|
|||
|
c59da83636
|
|||
|
f05882369d
|
|||
|
bd04d7f580
|
|||
|
1e13889392
|
|||
|
6e1c1dd0fc
|
|||
|
35876954cd
|
|||
|
740299bd9d
|
|||
|
cdf0f4e66d
|
|||
|
c4954e0eed
|
|||
|
b4f3576d82
|
|||
|
76ab24d98c
|
|||
|
b179204fd3
|
|||
|
081b532387
|
|||
|
7c19da9361
|
|||
|
24e20dcb5c
|
|||
|
becf61b9c1
|
|||
|
b9e7a76a7a
|
|||
|
800498f530
|
|||
|
d3f2d50749
|
|||
|
2740e61a23
|
|||
|
67f79c868f
|
|||
|
fc6ef0ee0f
|
|||
|
1385979e3d
|
|||
|
0a1cfcd4d0
|
|||
|
ea0e0f7911
|
|||
|
aa88d37509
|
|||
|
0f00f72b47
|
|||
|
9b0ed0b57f
|
|||
|
dc2a803266
|
|||
|
e71181499e
|
|||
|
ee663e5e99
|
|||
|
34f9b77d9d
|
|||
|
f084aaab8e
|
|||
|
68a606a79c
|
|||
|
4aa71902d0
|
|||
|
bef159b21c
|
|||
|
8d7b099b36
|
|||
|
89d98d1fb2
|
|||
|
cc95fe28d9
|
|||
|
09c945f81e
|
|||
|
05dc0bad18
|
|||
|
10c151efa5
|
|||
|
44ae927e38
|
|||
|
1ebbe87651
|
|||
|
70eb6af42b
|
|||
|
d1a4aad91d
|
|||
|
95dc8745eb
|
|||
|
495d3f7c05
|
|||
|
5c4c8e0eba
|
|||
|
07c44d5db1
|
|||
|
e7eb3dab6a
|
|||
|
180274548d
|
|||
|
a70f317729
|
|||
|
c6022aa6b9
|
|||
|
9e31d8deca
|
|||
|
b400e8b704
|
|||
|
62ca125a68
|
|||
|
735945ee81
|
|||
|
f72dee094f
|
|||
|
d46d8d4f6c
|
|||
|
9b8bd146f6
|
|||
|
96d8755245
|
|||
|
12549c9aed
|
|||
|
46527d7804
|
|||
|
8d3194f992
|
|||
|
5436af9c73
|
|||
|
8e882c0757
|
|||
|
93421f48e2
|
|||
|
05e15f3597
|
|||
|
da068ded6d
|
|||
|
2a7ede0232
|
|||
|
18ae3c30ee
|
|||
|
1a0400131e
|
|||
|
1866b99a89
|
|||
|
60176e7c2e
|
|||
|
602e8e1471
|
|||
|
e9d0a75dd5
|
|||
|
6cf87e328f
|
|||
|
f9f5fa41b6
|
|||
|
ed4d71db09
|
|||
|
39010c779f
|
|||
|
57d7ef8d3c
|
|||
|
0e9671dd7d
|
|||
|
e29c9e35f0
|
|||
|
8a2334eacb
|
|||
|
aad314cdfa
|
|||
|
6779b7526a
|
|||
|
84f5662df1
|
|||
|
249c9442e8
|
|||
|
5e17081fb4
|
|||
|
03bed93fee
|
|||
|
4a5211d830
|
|||
|
6d2dc5ff1a
|
|||
|
b713dbe669
|
|||
|
5c957d08ec
|
|||
|
729317d1ef
|
|||
|
5c2bd1a1da
|
|||
|
3cccc2c56b
|
|||
|
7f797b0265
|
|||
|
5a0360c1d5
|
|||
|
472c0e8737
|
|||
|
|
b9d8e30058 |
343
.gitea/workflows/build-prerelease.yml
Normal file
343
.gitea/workflows/build-prerelease.yml
Normal file
@@ -0,0 +1,343 @@
|
||||
name: build-prerelease
|
||||
|
||||
# Manually-dispatched workflow that builds CUDA-flavoured neuron binaries
|
||||
# (and a single cortex binary), packages each as a Fedora RPM, signs
|
||||
# them, and publishes to the `unstable` channel at rpm.lair.cafe.
|
||||
#
|
||||
# Trigger from the Gitea UI: Actions → build-prerelease → Run workflow.
|
||||
# Optionally provide a `ref` to build from a non-default branch.
|
||||
#
|
||||
# The published packages are versioned as e.g.
|
||||
# helexa-neuron-blackwell-0.1.16-0.1.20260518T140530.gitabcdef0.fc43.x86_64
|
||||
# ^^^^^^^^^^^^^^^^^^ ^^^^^^^^
|
||||
# commit time (s) commit sha
|
||||
# so they sort BELOW the eventual 0.1.16-1 stable release, and so two
|
||||
# commits on the same day are still strictly ordered by their commit
|
||||
# timestamps (rather than by RPM-vercmp's alpha-vs-digit precedence
|
||||
# on the SHA fragment).
|
||||
|
||||
on:
|
||||
# Auto-build on every push to main so the unstable channel tracks
|
||||
# head without a manual dispatch step.
|
||||
push:
|
||||
branches: [main]
|
||||
# Manual dispatch still available to build from a non-main ref.
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
ref:
|
||||
description: "Git ref to build (branch / tag / commit). Defaults to the workflow's branch."
|
||||
required: false
|
||||
default: ""
|
||||
|
||||
concurrency:
|
||||
# Share the group with ci.yml so the two workflows can't run
|
||||
# concurrently on the same `rust` runner (act reuses the workspace
|
||||
# cache and races destroy each other's build files mid-compile).
|
||||
# cancel-in-progress=false → workflows queue; if a newer push lands,
|
||||
# the older run is still picked up by ci.yml's own ref-keyed
|
||||
# concurrency (same group, queued).
|
||||
group: cortex-runner-pool-${{ github.ref }}
|
||||
cancel-in-progress: false
|
||||
|
||||
env:
|
||||
CARGO_INCREMENTAL: "0"
|
||||
CARGO_TERM_COLOR: "always"
|
||||
|
||||
jobs:
|
||||
prepare:
|
||||
name: Resolve version stamps
|
||||
runs-on: rust
|
||||
outputs:
|
||||
version: ${{ steps.info.outputs.version }}
|
||||
release: ${{ steps.info.outputs.release }}
|
||||
short_sha: ${{ steps.info.outputs.short_sha }}
|
||||
commit_timestamp: ${{ steps.info.outputs.commit_timestamp }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ inputs.ref }}
|
||||
fetch-depth: 0
|
||||
|
||||
- id: info
|
||||
run: |
|
||||
set -eux
|
||||
VERSION=$(awk -F\" '/^version[[:space:]]*=/ { print $2; exit }' Cargo.toml)
|
||||
SHORT_SHA=$(git rev-parse --short=7 HEAD)
|
||||
# Second-precise commit timestamp gives the release stamp a
|
||||
# strictly monotonic numeric prefix. The earlier %Y%m%d-only
|
||||
# form let same-day builds be ordered by RPM's rpmvercmp
|
||||
# rules over the SHA, which is non-chronological — e.g.
|
||||
# "git602e8e1" sorts newer than "gitf9f5fa4" purely because
|
||||
# rpmvercmp ranks digit-prefixed segments above alpha ones.
|
||||
# The SHA stays only as a debug identifier; sort order is
|
||||
# decided entirely by the timestamp.
|
||||
COMMIT_TIMESTAMP=$(git log -1 --format=%cd --date=format:%Y%m%d%H%M%S HEAD)
|
||||
RELEASE="0.1.${COMMIT_TIMESTAMP}.git${SHORT_SHA}"
|
||||
echo "version=${VERSION}" >> "$GITHUB_OUTPUT"
|
||||
echo "release=${RELEASE}" >> "$GITHUB_OUTPUT"
|
||||
echo "short_sha=${SHORT_SHA}" >> "$GITHUB_OUTPUT"
|
||||
echo "commit_timestamp=${COMMIT_TIMESTAMP}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
build-cortex:
|
||||
name: Build cortex binary
|
||||
needs: prepare
|
||||
# runner-rust image already provides rust/cargo/clippy/rustfmt via
|
||||
# dnf — no rustup install step needed.
|
||||
runs-on: rust
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ inputs.ref }}
|
||||
|
||||
- name: Build cortex (release)
|
||||
run: cargo build --release -p cortex-cli
|
||||
|
||||
- name: Stage binary
|
||||
run: |
|
||||
mkdir --parents artifacts
|
||||
cp target/release/cortex artifacts/cortex
|
||||
./artifacts/cortex --version || true
|
||||
|
||||
- uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: cortex-fc43
|
||||
path: artifacts/cortex
|
||||
retention-days: 1
|
||||
|
||||
build-neuron:
|
||||
name: Build neuron-${{ matrix.flavour }}
|
||||
needs: prepare
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- flavour: ampere
|
||||
compute_cap: "86"
|
||||
runner: cuda-13.0
|
||||
cuda_home: /usr/local/cuda-13.0
|
||||
build_jobs: 8
|
||||
nvcc_threads: 4
|
||||
cargo_features: "cuda cudnn flash-attn"
|
||||
- flavour: ada
|
||||
compute_cap: "89"
|
||||
runner: cuda-13.0
|
||||
cuda_home: /usr/local/cuda-13.0
|
||||
build_jobs: 8
|
||||
nvcc_threads: 4
|
||||
cargo_features: "cuda cudnn flash-attn"
|
||||
- flavour: blackwell
|
||||
compute_cap: "120"
|
||||
runner: cuda-13.0
|
||||
cuda_home: /usr/local/cuda-13.0
|
||||
build_jobs: 8
|
||||
nvcc_threads: 4
|
||||
cargo_features: "cuda cudnn flash-attn"
|
||||
runs-on: ${{ matrix.runner }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ inputs.ref }}
|
||||
|
||||
- name: Build neuron with CUDA (${{ matrix.flavour }})
|
||||
run: |
|
||||
set -eux
|
||||
export PATH="${{ matrix.cuda_home }}/bin:${PATH}"
|
||||
export LD_LIBRARY_PATH="${{ matrix.cuda_home }}/targets/x86_64-linux/lib:${{ matrix.cuda_home }}/lib64:${LD_LIBRARY_PATH:-}"
|
||||
export LIBRARY_PATH="${{ matrix.cuda_home }}/targets/x86_64-linux/lib:${{ matrix.cuda_home }}/lib64:${LIBRARY_PATH:-}"
|
||||
cargo build --release -p neuron --features "${{ matrix.cargo_features }}"
|
||||
env:
|
||||
CUDA_COMPUTE_CAP: ${{ matrix.compute_cap }}
|
||||
CARGO_BUILD_JOBS: ${{ matrix.build_jobs }}
|
||||
NVCC_THREADS: ${{ matrix.nvcc_threads }}
|
||||
|
||||
- name: Stage binary
|
||||
run: |
|
||||
mkdir --parents artifacts
|
||||
cp target/release/neuron artifacts/neuron-${{ matrix.flavour }}
|
||||
file "artifacts/neuron-${{ matrix.flavour }}"
|
||||
|
||||
- uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: neuron-${{ matrix.flavour }}-fc43
|
||||
path: artifacts/neuron-${{ matrix.flavour }}
|
||||
retention-days: 1
|
||||
|
||||
package-cortex:
|
||||
name: Package cortex RPM
|
||||
needs: [prepare, build-cortex]
|
||||
runs-on: rpm
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ inputs.ref }}
|
||||
|
||||
- uses: actions/download-artifact@v3
|
||||
with:
|
||||
name: cortex-fc43
|
||||
path: artifacts/
|
||||
|
||||
- name: Build RPM
|
||||
run: |
|
||||
set -eux
|
||||
rm -f ~/.rpmmacros
|
||||
rpmdev-setuptree
|
||||
cp artifacts/cortex ~/rpmbuild/SOURCES/
|
||||
cp data/cortex.service ~/rpmbuild/SOURCES/
|
||||
cp data/cortex-sysusers.conf ~/rpmbuild/SOURCES/
|
||||
cp data/cortex-firewalld.xml ~/rpmbuild/SOURCES/
|
||||
cp cortex.example.toml ~/rpmbuild/SOURCES/
|
||||
cp models.example.toml ~/rpmbuild/SOURCES/
|
||||
cp LICENSE ~/rpmbuild/SOURCES/
|
||||
rpmbuild -bb rpm/cortex-prerelease.spec \
|
||||
--define "cortex_version ${{ needs.prepare.outputs.version }}" \
|
||||
--define "cortex_prerelease ${{ needs.prepare.outputs.release }}" \
|
||||
--undefine dist \
|
||||
--define "dist .fc43"
|
||||
|
||||
- uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: rpm-cortex-fc43
|
||||
path: ~/rpmbuild/RPMS/x86_64/*.rpm
|
||||
retention-days: 7
|
||||
|
||||
package-neuron:
|
||||
name: Package helexa-neuron-${{ matrix.flavour }} RPM
|
||||
needs: [prepare, build-neuron]
|
||||
runs-on: rpm
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- flavour: ampere
|
||||
- flavour: ada
|
||||
- flavour: blackwell
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ inputs.ref }}
|
||||
|
||||
- uses: actions/download-artifact@v3
|
||||
with:
|
||||
name: neuron-${{ matrix.flavour }}-fc43
|
||||
path: artifacts/
|
||||
|
||||
- name: Build RPM
|
||||
run: |
|
||||
set -eux
|
||||
rm -f ~/.rpmmacros
|
||||
rpmdev-setuptree
|
||||
cp artifacts/neuron-${{ matrix.flavour }} ~/rpmbuild/SOURCES/
|
||||
cp data/neuron.service ~/rpmbuild/SOURCES/
|
||||
cp data/neuron-sysusers.conf ~/rpmbuild/SOURCES/
|
||||
cp data/neuron-firewalld.xml ~/rpmbuild/SOURCES/
|
||||
cp neuron.example.toml ~/rpmbuild/SOURCES/
|
||||
cp LICENSE ~/rpmbuild/SOURCES/
|
||||
rpmbuild -bb rpm/helexa-neuron-prerelease.spec \
|
||||
--define "neuron_version ${{ needs.prepare.outputs.version }}" \
|
||||
--define "neuron_flavour ${{ matrix.flavour }}" \
|
||||
--define "neuron_prerelease ${{ needs.prepare.outputs.release }}" \
|
||||
--undefine dist \
|
||||
--define "dist .fc43"
|
||||
|
||||
- uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: rpm-neuron-${{ matrix.flavour }}-fc43
|
||||
path: ~/rpmbuild/RPMS/x86_64/*.rpm
|
||||
retention-days: 7
|
||||
|
||||
publish:
|
||||
name: Publish to rpm.lair.cafe (unstable)
|
||||
needs: [package-cortex, package-neuron]
|
||||
runs-on: rpm
|
||||
concurrency:
|
||||
group: rpm-publish
|
||||
cancel-in-progress: false
|
||||
env:
|
||||
RPM_REPO_HOST: oolon.kosherinata.internal
|
||||
FEDORA_VERSION: "43"
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ inputs.ref }}
|
||||
|
||||
- name: Download all built RPMs
|
||||
uses: actions/download-artifact@v3
|
||||
with:
|
||||
path: rpms/
|
||||
pattern: rpm-*-fc43
|
||||
|
||||
- name: Flatten RPM artifacts
|
||||
run: |
|
||||
set -eux
|
||||
find rpms/ -name '*.rpm' -exec mv --target-directory=rpms/ {} +
|
||||
find rpms/ -mindepth 1 -type d -empty -delete
|
||||
ls -la rpms/
|
||||
|
||||
- name: Check for sequoia-sq
|
||||
run: |
|
||||
if ! command -v sq &> /dev/null; then
|
||||
echo "ERROR: sequoia-sq is not installed. Install with: sudo dnf install sequoia-sq"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Import signing key
|
||||
env:
|
||||
# Pass secrets via env so values stay out of the rendered shell
|
||||
# script (which Gitea includes in step logs). Template
|
||||
# expansion of ${{ secrets.X }} inside `run:` writes the literal
|
||||
# value into the script and depends on Gitea's log masker to
|
||||
# scrub it — fragile for multi-line keys.
|
||||
RPM_SIGNING_KEY: ${{ secrets.RPM_SIGNING_KEY }}
|
||||
RPM_SIGNING_KEY_ID: ${{ secrets.RPM_SIGNING_KEY_ID }}
|
||||
run: |
|
||||
echo "$RPM_SIGNING_KEY" | gpg --batch --import
|
||||
fpr=$(gpg --batch --with-colons --list-keys "$RPM_SIGNING_KEY_ID" | awk -F: '/^fpr:/ { print $10; exit }')
|
||||
echo "${fpr}:6:" | gpg --batch --import-ownertrust
|
||||
sed "s/@GPG_NAME@/$RPM_SIGNING_KEY_ID/" rpm/rpmmacros > ~/.rpmmacros
|
||||
|
||||
- name: Sign RPMs
|
||||
run: |
|
||||
set -eux
|
||||
for rpm in rpms/*.rpm; do
|
||||
echo "signing ${rpm}..."
|
||||
rpm --addsign "${rpm}"
|
||||
done
|
||||
|
||||
- name: Set up SSH for rsync
|
||||
run: |
|
||||
install --directory --mode 700 ~/.ssh
|
||||
echo "${RSYNC_SSH_KEY}" | install --mode 600 /dev/stdin ~/.ssh/id_ed25519
|
||||
env:
|
||||
RSYNC_SSH_KEY: ${{ secrets.RSYNC_SSH_KEY }}
|
||||
|
||||
- name: Test SSH connectivity
|
||||
run: |
|
||||
ssh -o StrictHostKeyChecking=accept-new "gitea_ci@${RPM_REPO_HOST}" exit
|
||||
|
||||
- name: Ensure unstable repo directory exists
|
||||
run: |
|
||||
ssh "gitea_ci@${RPM_REPO_HOST}" \
|
||||
"mkdir --parents /var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable"
|
||||
|
||||
- name: Sync RPMs to unstable repo
|
||||
run: |
|
||||
rsync \
|
||||
--archive \
|
||||
--verbose \
|
||||
--chmod D755,F644 \
|
||||
rpms/*.rpm \
|
||||
"gitea_ci@${RPM_REPO_HOST}:/var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable/"
|
||||
|
||||
- name: Update unstable repo metadata
|
||||
run: |
|
||||
ssh "gitea_ci@${RPM_REPO_HOST}" \
|
||||
"cd /var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable && createrepo_c --update ."
|
||||
|
||||
- name: Generate packages.json manifest
|
||||
run: |
|
||||
scp script/generate-packages-json.py "gitea_ci@${RPM_REPO_HOST}:/tmp/"
|
||||
ssh "gitea_ci@${RPM_REPO_HOST}" \
|
||||
"python3 /tmp/generate-packages-json.py \
|
||||
--repodata-dir /var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable/repodata \
|
||||
--output /var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable/packages.json \
|
||||
--base-url https://rpm.lair.cafe/fedora/${FEDORA_VERSION}/x86_64/unstable"
|
||||
@@ -7,6 +7,16 @@ on:
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
# Share a concurrency group with build-prerelease.yml so the two
|
||||
# workflows don't race on the same `rust` runner workspace (act's
|
||||
# /root/.cache/act/<hash>/hostexecutor/ is shared across concurrent
|
||||
# jobs and one job's checkout step nukes another's in-flight build
|
||||
# files). cancel-in-progress=false → they queue; same-ref pushes
|
||||
# coalesce per workflow via cancel-in-progress on each.
|
||||
concurrency:
|
||||
group: cortex-runner-pool-${{ github.ref }}
|
||||
cancel-in-progress: false
|
||||
|
||||
env:
|
||||
CARGO_INCREMENTAL: "0"
|
||||
RUSTC_WRAPPER: sccache
|
||||
@@ -16,40 +26,134 @@ env:
|
||||
SCCACHE_S3_USE_SSL: "false"
|
||||
AWS_ACCESS_KEY_ID: ${{ secrets.SCCACHE_S3_ACCESS_KEY }}
|
||||
AWS_SECRET_ACCESS_KEY: ${{ secrets.SCCACHE_S3_SECRET_KEY }}
|
||||
# fmt, clippy, and test all run in parallel on the same `rust` runner
|
||||
# and would otherwise share /root/.cache/act/<hash>/hostexecutor/target/,
|
||||
# racing each other's cargo temp files (.tmpXXXXXX) and failing builds
|
||||
# mid-compile. Give each job its own target directory so the invocations
|
||||
# don't collide. sccache still backs the actual rustc cache, so the
|
||||
# rebuild penalty is small.
|
||||
CARGO_TARGET_DIR: target-${{ github.job }}
|
||||
|
||||
jobs:
|
||||
check:
|
||||
name: Format, lint, build, test
|
||||
runs-on: fedora
|
||||
fmt:
|
||||
name: Format
|
||||
runs-on: rust
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- run: cargo fmt --check --all
|
||||
|
||||
- name: Ensure sccache with S3 support
|
||||
env:
|
||||
RUSTC_WRAPPER: ""
|
||||
clippy:
|
||||
name: Clippy
|
||||
runs-on: rust
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
# sccache occasionally fails with spurious race-condition errors;
|
||||
# retrying the same invocation succeeds without code changes.
|
||||
# Allow up to 3 attempts before declaring real failure.
|
||||
- name: Clippy (with retry)
|
||||
run: |
|
||||
if sccache --version 2>/dev/null && sccache --show-stats 2>/dev/null; then
|
||||
echo "sccache with S3 support already installed"
|
||||
else
|
||||
cargo install sccache --features s3 --locked
|
||||
fi
|
||||
for attempt in 1 2 3; do
|
||||
echo "::group::clippy attempt ${attempt}"
|
||||
if cargo clippy --workspace -- -D warnings; then
|
||||
echo "::endgroup::"
|
||||
exit 0
|
||||
fi
|
||||
echo "::endgroup::"
|
||||
echo "clippy failed on attempt ${attempt}"
|
||||
if [ "${attempt}" -lt 3 ]; then
|
||||
sleep 5
|
||||
fi
|
||||
done
|
||||
echo "clippy failed after 3 attempts"
|
||||
exit 1
|
||||
- run: sccache --show-stats
|
||||
|
||||
- name: Check formatting
|
||||
run: cargo fmt --check --all
|
||||
test:
|
||||
name: Test
|
||||
runs-on: rust
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
# See the clippy job for why this is retried.
|
||||
- name: Test (with retry)
|
||||
run: |
|
||||
for attempt in 1 2 3; do
|
||||
echo "::group::test attempt ${attempt}"
|
||||
if cargo test --workspace; then
|
||||
echo "::endgroup::"
|
||||
exit 0
|
||||
fi
|
||||
echo "::endgroup::"
|
||||
echo "test failed on attempt ${attempt}"
|
||||
if [ "${attempt}" -lt 3 ]; then
|
||||
sleep 5
|
||||
fi
|
||||
done
|
||||
echo "test failed after 3 attempts"
|
||||
exit 1
|
||||
- run: sccache --show-stats
|
||||
|
||||
- name: Clippy
|
||||
run: cargo clippy --workspace -- -D warnings
|
||||
|
||||
- name: Test
|
||||
run: cargo test --workspace
|
||||
|
||||
- name: Show sccache stats
|
||||
run: sccache --show-stats
|
||||
# Type-check the CUDA-only code path. Borrow-check-only — we
|
||||
# never run the tests here (the runner has no GPU). This catches
|
||||
# the category of bug where a refactor compiles fine under the
|
||||
# default feature set (which is what the `clippy` and `test` jobs
|
||||
# exercise) but fails inside a `#[cfg(feature = "cuda")]` block.
|
||||
# `runs-on: cuda-13.0` selects the runner that ships nvcc /
|
||||
# cudarc's build prerequisites. The generic `rust` and `rpm`
|
||||
# runners don't have them (the previous label `rpm` was tried
|
||||
# first and tripped cudarc's `nvcc --version` build script —
|
||||
# see commit history).
|
||||
cuda-check:
|
||||
name: CUDA type-check
|
||||
runs-on: cuda-13.0
|
||||
# The workflow-level env sets `RUSTC_WRAPPER: sccache` for the
|
||||
# `rust` runner (where fmt/clippy/test live and sccache is
|
||||
# installed). The `cuda-13.0` runner doesn't have sccache on
|
||||
# PATH, so inheriting the wrapper makes cargo bail with
|
||||
# `could not execute process `sccache rustc -vV` (never executed)`
|
||||
# before borrow-check even starts. Clear it locally. Also clear
|
||||
# SCCACHE_* so cargo doesn't try to contact the cache (the
|
||||
# remote auth headers come from secrets that aren't present on
|
||||
# this runner either). Lose the cache, keep the gate.
|
||||
env:
|
||||
RUSTC_WRAPPER: ""
|
||||
SCCACHE_BUCKET: ""
|
||||
SCCACHE_ENDPOINT: ""
|
||||
SCCACHE_REGION: ""
|
||||
SCCACHE_S3_USE_SSL: ""
|
||||
AWS_ACCESS_KEY_ID: ""
|
||||
AWS_SECRET_ACCESS_KEY: ""
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: cargo check --features cuda (with retry)
|
||||
run: |
|
||||
# act launches the step shell without /etc/profile, so the
|
||||
# gitea_runner user's inherited PATH lacks /usr/local/cuda-13.0/bin.
|
||||
# cudarc's build.rs:157 shells out to `nvcc --version` (because
|
||||
# the neuron crate enables cuda-version-from-build-system) and
|
||||
# panics with ENOENT if nvcc isn't resolvable. build-prerelease.yml
|
||||
# does the same export — keep them in sync.
|
||||
export PATH="/usr/local/cuda-13.0/bin:${PATH}"
|
||||
export LD_LIBRARY_PATH="/usr/local/cuda-13.0/targets/x86_64-linux/lib:/usr/local/cuda-13.0/lib64:${LD_LIBRARY_PATH:-}"
|
||||
export LIBRARY_PATH="/usr/local/cuda-13.0/targets/x86_64-linux/lib:/usr/local/cuda-13.0/lib64:${LIBRARY_PATH:-}"
|
||||
for attempt in 1 2 3; do
|
||||
echo "::group::cuda-check attempt ${attempt}"
|
||||
if cargo check -p neuron --features cuda --all-targets; then
|
||||
echo "::endgroup::"
|
||||
exit 0
|
||||
fi
|
||||
echo "::endgroup::"
|
||||
echo "cuda-check failed on attempt ${attempt}"
|
||||
if [ "${attempt}" -lt 3 ]; then
|
||||
sleep 5
|
||||
fi
|
||||
done
|
||||
echo "cuda-check failed after 3 attempts"
|
||||
exit 1
|
||||
|
||||
srpm-cortex:
|
||||
name: Build cortex SRPM
|
||||
runs-on: fedora
|
||||
needs: check
|
||||
runs-on: rpm
|
||||
needs: [fmt, clippy, test, cuda-check]
|
||||
if: startsWith(github.ref, 'refs/tags/v')
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
@@ -108,8 +212,8 @@ jobs:
|
||||
|
||||
srpm-neuron:
|
||||
name: Build neuron SRPM
|
||||
runs-on: fedora
|
||||
needs: check
|
||||
runs-on: rpm
|
||||
needs: [fmt, clippy, test, cuda-check]
|
||||
if: startsWith(github.ref, 'refs/tags/v')
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
@@ -168,7 +272,7 @@ jobs:
|
||||
|
||||
copr-cortex:
|
||||
name: Publish cortex to COPR
|
||||
runs-on: fedora
|
||||
runs-on: fedora-43
|
||||
needs: srpm-cortex
|
||||
steps:
|
||||
- name: Download SRPM
|
||||
@@ -185,7 +289,7 @@ jobs:
|
||||
|
||||
copr-neuron:
|
||||
name: Publish neuron to COPR
|
||||
runs-on: fedora
|
||||
runs-on: fedora-43
|
||||
needs: srpm-neuron
|
||||
steps:
|
||||
- name: Download SRPM
|
||||
@@ -202,7 +306,7 @@ jobs:
|
||||
|
||||
bump-version:
|
||||
name: Bump version in source
|
||||
runs-on: fedora
|
||||
runs-on: rust
|
||||
needs: [copr-cortex, copr-neuron]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -4,5 +4,6 @@
|
||||
.idea/
|
||||
.vscode/
|
||||
cortex.toml
|
||||
models.toml
|
||||
doc/plan/*
|
||||
script/deploy.sh
|
||||
/target-cuda/
|
||||
|
||||
177
CLAUDE.md
177
CLAUDE.md
@@ -84,6 +84,63 @@ Per-request: model, node, prompt_tokens, completion_tokens, total_tokens,
|
||||
tok_per_sec, time_to_first_token_ms, total_latency_ms.
|
||||
Exposed as Prometheus histograms/counters on a separate port.
|
||||
|
||||
### Per-device worker thread (neuron)
|
||||
The neuron daemon dedicates one OS thread per CUDA device it loads
|
||||
onto. That thread binds the device's `CudaContext` once at startup and
|
||||
owns it for the daemon's lifetime; every model load, forward step,
|
||||
KV-cache reset, VRAM query, NCCL init/sanity, NCCL all_reduce, and
|
||||
model drop on that device routes through this thread via a
|
||||
`std::sync::mpsc` job channel. Replies cross back via
|
||||
`tokio::sync::oneshot`.
|
||||
|
||||
Three properties this gives us, in order of weight:
|
||||
|
||||
1. **Context locality.** cudarc binds the CUDA context per OS thread
|
||||
via `cuCtxSetCurrent`. Before this refactor, ad-hoc
|
||||
`tokio::task::spawn_blocking` calls bound the context onto a
|
||||
different thread per request — and `device_vram_mb()` from an
|
||||
async task bound it onto whichever tokio worker happened to be
|
||||
running. Pinning the context to one named thread ends that.
|
||||
2. **Drop safety.** Every `CudaSlice` in a `Tensor`, every
|
||||
`cudarc::nccl::Comm`, and the `CudaContext` itself call `cuMemFree` /
|
||||
`ncclCommDestroy` / `cuCtxDestroy` during `Drop` — and require the
|
||||
right context current. With the worker owning the model slab,
|
||||
`Drop` always runs on the right thread. The cudarc Drop constraint
|
||||
is structurally enforced.
|
||||
3. **Poisoning blast radius.** When a CUDA driver error makes the
|
||||
context unrecoverable, the poison flag lives on the
|
||||
`DeviceWorkerHandle` itself. Subsequent `submit()` calls fast-reject
|
||||
at the channel boundary with a clear "device worker is poisoned"
|
||||
error before any further CUDA work is attempted. The thread doesn't
|
||||
exit (dropping the slab would re-touch the broken context) — it
|
||||
enters a drain-only mode and replies error to everything until the
|
||||
daemon restarts.
|
||||
|
||||
Tensors never escape the worker thread alive. Inference replies carry
|
||||
`Vec<f32>` CPU-side logits; the async caller wraps them in a CPU
|
||||
candle tensor and runs `apply_repeat_penalty` + `LogitsProcessor::sample`
|
||||
without ever rebinding the device context. Sampled tokens come back as
|
||||
`u32`; VRAM queries as `(u64, u64)`. The opaque `ArchHandle(u64)` and
|
||||
`TpHandle(u64)` are the only "references" callers hold to loaded
|
||||
models — they're indices into the worker's state slab, not pointers.
|
||||
|
||||
The TP worker subprocesses in `harness/tp/worker.rs` are the same
|
||||
pattern out-of-process — a dedicated context-owning process per
|
||||
non-zero NCCL rank. The in-process worker in `harness/device_worker/`
|
||||
brings the discipline to rank 0.
|
||||
|
||||
CPU loads (`Device::Cpu` fallback when CUDA is unavailable) keep the
|
||||
legacy `tokio::task::spawn_blocking + Arc<Mutex<ModelArch>>` path —
|
||||
there's no context to own and the channel hop would only add latency.
|
||||
Four `spawn_blocking` references in `harness/candle.rs` are deliberate
|
||||
CPU fallback.
|
||||
|
||||
Canonical narrative lives in
|
||||
`crates/neuron/src/harness/device_worker/mod.rs`'s module
|
||||
doc-comment; touch points (the `Job` enum, the dispatch handlers, the
|
||||
`DeviceWorkerState` struct) are in the sibling `jobs.rs` and
|
||||
`dispatch.rs`.
|
||||
|
||||
## Tech stack
|
||||
|
||||
- **Rust 2024 edition** — workspace with 4 crates
|
||||
@@ -616,58 +673,84 @@ dnf install cortex # gateway host
|
||||
dnf install helexa-neuron # GPU nodes
|
||||
```
|
||||
|
||||
### Phase 11: llama.cpp harness stub
|
||||
## 2026-05-18 addendum: candle-native pivot
|
||||
|
||||
**Goal:** Prove the harness abstraction works with a second engine.
|
||||
Phases 11 (llama.cpp harness) and 12 (mistral.rs COPR) below are
|
||||
**superseded**. The project no longer treats mistral.rs or llama.cpp as
|
||||
dependencies — both are conceptually out of scope. neuron becomes a
|
||||
candle-native inference daemon, with `Harness` retained as an
|
||||
internal seam for adding future engines (vision/audio/diffusion) but
|
||||
its only implementation being in-process candle.
|
||||
|
||||
**Steps:**
|
||||
1. `crates/neuron/src/harness/llamacpp.rs` — implement the `Harness`
|
||||
trait for llama.cpp's `llama-server`.
|
||||
- `start()` — launch `llama-server` with the correct model path,
|
||||
`--port`, `--n-gpu-layers`, `--tensor-split` args. Track the
|
||||
child process.
|
||||
- `stop()` — send SIGTERM to the child process.
|
||||
- `list_models()` — llama-server serves one model per process, so
|
||||
return a single-element list.
|
||||
- `load_model()` — start a new llama-server process for this model.
|
||||
- `unload_model()` — stop the process.
|
||||
- `inference_endpoint()` — return `http://localhost:{assigned_port}`.
|
||||
2. Port allocation: neuron assigns ports from a range (e.g. 8100-8199)
|
||||
to llama-server instances.
|
||||
3. Register in `HarnessRegistry` when configured:
|
||||
```toml
|
||||
[[harnesses]]
|
||||
name = "llamacpp"
|
||||
binary = "/usr/local/bin/llama-server"
|
||||
port_range = [8100, 8199]
|
||||
```
|
||||
4. Tests: mock llama-server (simple HTTP server returning canned
|
||||
responses), test load/unload/endpoint lifecycle.
|
||||
The full staged plan for this pivot lives at
|
||||
`~/.claude/plans/create-a-more-aggressive-calm-naur.md`. Summary:
|
||||
|
||||
**Done when:** A model with `harness = "llamacpp"` in `models.toml` can
|
||||
be loaded and served through cortex. Tests pass with mock llama-server.
|
||||
- **Stage 1 (this commit):** delete `mistralrs.rs` and `llamacpp.rs`,
|
||||
scaffold inert `CandleHarness`, drop `endpoint`/`systemd_unit` from
|
||||
`HarnessConfig`, default no-op `start`/`stop` on the `Harness` trait.
|
||||
- **Stages 2–4:** wire up candle model load/unload (quantized Qwen3
|
||||
first), add OpenAI-compatible inference endpoint in neuron, then SSE
|
||||
streaming.
|
||||
- **Stages 5–6:** load-on-activation (default models in config) and
|
||||
unload-on-deactivation (graceful shutdown).
|
||||
- **Stages 7–8:** multi-GPU tensor parallelism and broader model/quant
|
||||
coverage.
|
||||
|
||||
### Phase 12 (lower priority): mistral.rs COPR packaging
|
||||
Sections of this document that describe mistral.rs HTTP behaviour
|
||||
("mistral.rs API gotchas") are retained as historical context for
|
||||
Phases 1–10 — they document what was true while the project depended
|
||||
on mistral.rs. They do not describe current behaviour.
|
||||
|
||||
**Goal:** Fedora RPMs for mistral.rs built against specific CUDA versions.
|
||||
---
|
||||
|
||||
**Steps:**
|
||||
1. `mistralrs-cuda.spec` — RPM spec that clones a pinned mistral.rs git
|
||||
tag, builds with `--features cuda`, links against the system CUDA
|
||||
toolkit. Produces `mistralrs-cuda13-server` (CUDA 13.x / sm_120) and
|
||||
`mistralrs-cuda12-server` (CUDA 12.x / sm_89). Install binary to
|
||||
`/usr/local/bin/mistralrs`.
|
||||
2. COPR build config: enable the NVIDIA CUDA repo as a build dependency.
|
||||
Pin the CUDA toolkit version in `BuildRequires`.
|
||||
3. Gitea Actions or manual workflow: bump the mistral.rs tag in the spec,
|
||||
trigger COPR rebuild.
|
||||
4. neuron's mistralrs harness config references which binary/package
|
||||
provides the mistral.rs binary. neuron could warn at startup if the
|
||||
installed mistral.rs CUDA version doesn't match the discovered driver.
|
||||
### Phase 11 (superseded): llama.cpp harness stub
|
||||
|
||||
**Done when:** `dnf install mistralrs-cuda13-server` on beast provides a
|
||||
working `mistralrs` binary built for Blackwell GPUs. `dnf install
|
||||
mistralrs-cuda12-server` on benjy provides one built for Ada GPUs.
|
||||
~~Originally planned as a second engine to prove the harness
|
||||
abstraction.~~ Replaced by the candle harness work in the 2026-05-18
|
||||
addendum above. llama.cpp's any-model/any-hardware breadth is no
|
||||
longer in scope for helexa.
|
||||
|
||||
This is a separate repo/spec — not part of the cortex workspace — but
|
||||
tightly coupled operationally. Track it as a sibling project.
|
||||
### Phase 12 (superseded): mistral.rs COPR packaging
|
||||
|
||||
~~Originally planned to ship CUDA-versioned mistral.rs RPMs.~~ Replaced
|
||||
by the candle harness work in the 2026-05-18 addendum above. With
|
||||
mistral.rs out of the dependency tree, there is nothing to package.
|
||||
|
||||
## 2026-05-27 addendum: per-device worker thread
|
||||
|
||||
Replaced the ad-hoc `tokio::task::spawn_blocking` pattern that drove
|
||||
every leader-side CUDA op with one dedicated OS thread per CUDA device,
|
||||
permanently bound to that device's `CudaContext`. All leader-side
|
||||
inference work (GGUF + dense + TP shard load, forward, kv-cache clear,
|
||||
NCCL init/sanity, NCCL all_reduce, VRAM query, model drop) routes
|
||||
through the worker via a `std::sync::mpsc` channel; tensors never
|
||||
escape the worker thread alive. See "Per-device worker thread (neuron)"
|
||||
above and `crates/neuron/src/harness/device_worker/mod.rs` for the
|
||||
canonical narrative.
|
||||
|
||||
Motivated by the 2026-05-26 silent-hang on beast: a CUDA OOM cascade
|
||||
poisoned the device context on whichever spawn_blocking thread caught
|
||||
it, and subsequent requests stalled invisibly on the pool lock. After
|
||||
the refactor, the same failure mode shows up in journalctl as
|
||||
`prefill sample failed; logits unhealthy nan: 248320/248320` followed
|
||||
by `failed, model marked poisoned`. The thread stays alive and rejects
|
||||
subsequent requests at the channel boundary.
|
||||
|
||||
Landed in four PRs:
|
||||
|
||||
- **Phase 1** (`081b532`) — device_worker module + 8 VRAM-query sites
|
||||
route through the worker. CPU build only; smoke on beast confirmed
|
||||
a persistent `cuda-dev-0` thread.
|
||||
- **Phase 2** (`b179204`) — single-GPU forward + clear_kv + drop via
|
||||
the worker. `LoadedModel.arch_handle: Option<ArchHandle>` replaces
|
||||
`Arc<Mutex<ModelArch>>` for CUDA loads. CPU keeps the legacy path.
|
||||
- **Phase 3** (`76ab24d`) — TP forward + NCCL init/sanity + leader
|
||||
KV-clear routed through the worker. `WorkerPool.leader_nccl` moves
|
||||
into the worker's state. `TpLoadedModel.leader_handle: TpHandle`
|
||||
replaces `Arc<Mutex<TpLeaderModel>>`. CUDA-only TP smoke deferred to
|
||||
next deploy.
|
||||
- **Phase 4** (`b4f3576`) — GGUF + dense + TP shard loads move onto
|
||||
the worker. The `Job::TransferIn` / `Job::CloneLeaderComm` bridges
|
||||
from Phases 2/3 deleted; `SendComm` newtype no longer needed in the
|
||||
load path. `grep -rn spawn_blocking crates/neuron/src/harness/`
|
||||
returns only deliberate CPU-fallback hits after this PR.
|
||||
|
||||
2420
Cargo.lock
generated
2420
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -5,10 +5,11 @@ members = [
|
||||
"crates/cortex-gateway",
|
||||
"crates/cortex-cli",
|
||||
"crates/neuron",
|
||||
"crates/helexa-acp",
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
version = "0.1.14"
|
||||
version = "0.1.16"
|
||||
edition = "2024"
|
||||
license = "GPL-3.0-or-later"
|
||||
repository = "https://git.lair.cafe/helexa/cortex"
|
||||
@@ -27,7 +28,7 @@ serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
toml = "0.8"
|
||||
|
||||
# http client (for proxying to mistralrs backends)
|
||||
# http client (for proxying to neuron backends)
|
||||
reqwest = { version = "0.12", features = ["json", "stream"] }
|
||||
|
||||
# observability
|
||||
|
||||
107
README.md
107
README.md
@@ -1,22 +1,23 @@
|
||||
# cortex
|
||||
|
||||
A Rust reverse-proxy and fleet management layer for multi-node
|
||||
[mistral.rs](https://github.com/EricLBuehler/mistral.rs) inference clusters.
|
||||
A Rust reverse-proxy and fleet management layer for multi-node GPU inference
|
||||
clusters. Cortex sits in front of one or more `neuron` daemons (each running
|
||||
candle-based inference on a local GPU host) and presents a unified OpenAI +
|
||||
Anthropic compatible API surface.
|
||||
|
||||
## Problem
|
||||
|
||||
Running local LLMs across multiple GPU nodes (different VRAM tiers, different
|
||||
model affinities) requires a unified API surface that:
|
||||
|
||||
- Presents a **single `/v1/models` catalogue** merging every model across every
|
||||
node.
|
||||
- **Routes requests** to the correct node based on where a model is loaded (or
|
||||
*can* be loaded).
|
||||
- Manages **model lifecycle** — unload cold models, reload on demand, pin
|
||||
critical ones — using the mistral.rs
|
||||
`/v1/models/{unload,reload,status}` HTTP API (PR #1828+).
|
||||
- Presents a **single `/v1/models` catalogue** merging every model that can be
|
||||
served by any neuron in the fleet.
|
||||
- **Routes requests** to the correct node based on where a model is loaded
|
||||
(or can be loaded), handling cold-load and eviction transparently.
|
||||
- Manages **model lifecycle** — load on demand, unload cold models, pin
|
||||
critical ones — by calling each neuron's `/models/{load,unload}` API.
|
||||
- Translates between **OpenAI and Anthropic** request/response envelopes so
|
||||
every client in the homelab speaks whichever dialect it prefers.
|
||||
every client speaks whichever dialect it prefers.
|
||||
- Captures **per-request metrics** (tokens, tok/s, TTFT, latency) and exposes
|
||||
them as Prometheus counters/histograms.
|
||||
|
||||
@@ -30,18 +31,17 @@ model affinities) requires a unified API surface that:
|
||||
└────────────────┴──────┬───────┴───────────────┘
|
||||
│
|
||||
┌──────────▼──────────┐
|
||||
│ cortex │
|
||||
│ (cortex-gateway) │
|
||||
│ cortex │
|
||||
│ (cortex-gateway) │
|
||||
│ │
|
||||
│ Router · Metrics │
|
||||
│ Evictor · Translate│
|
||||
└──┬──────┬────────┬──┘
|
||||
│ │ │
|
||||
┌──────────▼┐ ┌──▼─────┐ ┌▼──────────┐
|
||||
│ gpu-large │ │gpu-med │ │ gpu-small │
|
||||
│ mistralrs │ │mistral │ │ mistralrs │
|
||||
│ serve │ │rs serve│ │ serve │
|
||||
│ :8080 │ │ :8080 │ │ :8080 │
|
||||
│ neuron │ │ neuron │ │ neuron │
|
||||
│ :13131 │ │ :13131 │ │ :13131 │
|
||||
│ candle │ │ candle │ │ candle │
|
||||
└───────────┘ └────────┘ └───────────┘
|
||||
private network (.internal)
|
||||
```
|
||||
@@ -50,43 +50,39 @@ model affinities) requires a unified API surface that:
|
||||
|
||||
| Crate | Purpose |
|
||||
|---|---|
|
||||
| `cortex-core` | Shared types: config, node/model state, metrics, OpenAI/Anthropic request/response envelopes |
|
||||
| `cortex-gateway` | Axum HTTP server: proxy, router, evictor, metrics exporter |
|
||||
| `cortex-agent` | Per-node sidecar: polls local mistralrs, reports to gateway, handles restart/defrag |
|
||||
| `cortex-core` | Shared types: config, node/model state, metrics, OpenAI/Anthropic envelopes, harness trait, discovery types |
|
||||
| `cortex-gateway` | Axum HTTP server: proxy, router, evictor, poller, metrics exporter |
|
||||
| `neuron` | Per-node daemon: GPU discovery, in-process candle inference, model lifecycle API |
|
||||
| `cortex-cli` | CLI entrypoint (`cortex serve`, `cortex status`, etc.) |
|
||||
|
||||
## Node setup
|
||||
|
||||
Each GPU node runs `mistralrs serve` with a multi-model config. Models are
|
||||
declared but start **unloaded** — mistral.rs lazy-loads on first request and
|
||||
the gateway can explicitly unload/reload via the HTTP API.
|
||||
Each GPU node runs `neuron` (listening on `:13131`). Neuron uses
|
||||
huggingface/candle for in-process inference — there is no external
|
||||
inference subprocess to manage.
|
||||
|
||||
Example node systemd unit:
|
||||
Inside the daemon, every CUDA device gets one dedicated OS thread
|
||||
(named `cuda-dev-N`) that owns the device's CUDA context for the
|
||||
daemon's lifetime. Model loads, forward passes, KV-cache resets,
|
||||
NCCL collectives, VRAM queries, and unloads all route through that
|
||||
thread via a job channel; tensors never escape it alive. This pins
|
||||
context binding to a known thread, makes the CUDA Drop contract
|
||||
structurally safe, and isolates driver-error poisoning to one worker
|
||||
rather than the whole process. See `CLAUDE.md` for the design
|
||||
rationale and `crates/neuron/src/harness/device_worker/` for the code.
|
||||
|
||||
```ini
|
||||
# /etc/systemd/system/mistralrs.service
|
||||
[Unit]
|
||||
Description=mistral.rs inference server
|
||||
After=network-online.target
|
||||
Wants=network-online.target
|
||||
The neuron RPM (`helexa-neuron`) ships a systemd unit:
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
ExecStart=/usr/local/bin/mistralrs serve \
|
||||
--from-config /etc/mistralrs/config.toml \
|
||||
--port 8080
|
||||
Restart=on-failure
|
||||
RestartSec=5
|
||||
Environment=CUDA_VISIBLE_DEVICES=0,1
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
```sh
|
||||
dnf copr enable helexa/helexa
|
||||
dnf install helexa-neuron
|
||||
systemctl enable --now neuron
|
||||
```
|
||||
|
||||
## Gateway config
|
||||
|
||||
```toml
|
||||
# cortex.toml
|
||||
# /etc/cortex/cortex.toml
|
||||
[gateway]
|
||||
listen = "0.0.0.0:31313"
|
||||
metrics_listen = "0.0.0.0:31314"
|
||||
@@ -95,25 +91,17 @@ metrics_listen = "0.0.0.0:31314"
|
||||
strategy = "lru" # lru | priority
|
||||
defrag_after_cycles = 50
|
||||
|
||||
[[nodes]]
|
||||
name = "gpu-large"
|
||||
endpoint = "http://gpu-large.internal:8080"
|
||||
vram_mb = 49_152 # e.g. 2x RTX 4090
|
||||
pinned = ["your-org/large-model"]
|
||||
[[neurons]]
|
||||
name = "beast"
|
||||
endpoint = "http://beast.internal:13131"
|
||||
|
||||
[[nodes]]
|
||||
name = "gpu-medium"
|
||||
endpoint = "http://gpu-medium.internal:8080"
|
||||
vram_mb = 24_576 # e.g. RTX 4090
|
||||
pinned = ["your-org/medium-model"]
|
||||
|
||||
[[nodes]]
|
||||
name = "gpu-small"
|
||||
endpoint = "http://gpu-small.internal:8080"
|
||||
vram_mb = 12_288 # e.g. RTX 3060
|
||||
pinned = ["your-org/embedding-model"]
|
||||
[[neurons]]
|
||||
name = "benjy"
|
||||
endpoint = "http://benjy.internal:13131"
|
||||
```
|
||||
|
||||
Model placement profiles live in `models.toml` — see `models.example.toml`.
|
||||
|
||||
## Building
|
||||
|
||||
```sh
|
||||
@@ -131,13 +119,14 @@ cargo clippy --workspace -- -D warnings # warnings are errors
|
||||
cargo test --workspace # all tests must pass
|
||||
```
|
||||
|
||||
Tagged releases (`v*`) additionally build an SRPM and publish to COPR.
|
||||
Tagged releases (`v*`) additionally build SRPMs for both `cortex` and
|
||||
`helexa-neuron` and publish to COPR.
|
||||
|
||||
## Running
|
||||
|
||||
```sh
|
||||
# start the gateway
|
||||
cortex serve --config cortex.toml
|
||||
cortex serve --config /etc/cortex/cortex.toml
|
||||
|
||||
# check fleet status
|
||||
cortex status
|
||||
|
||||
30
asset/manifest.yml
Normal file
30
asset/manifest.yml
Normal file
@@ -0,0 +1,30 @@
|
||||
# Helexa fleet manifest.
|
||||
#
|
||||
# Drives rolling deploys via script/deploy.sh and serves as the source
|
||||
# of truth for which hosts run cortex vs neuron, and which CUDA
|
||||
# compute-capability flavour each neuron host needs.
|
||||
#
|
||||
# Flavour ↔ NVIDIA generation ↔ compute cap:
|
||||
# ampere sm_86 (RTX 30 series — e.g. 3060)
|
||||
# ada sm_89 (RTX 40 series — e.g. 4090)
|
||||
# blackwell sm_120 (RTX 50 series — e.g. 5090)
|
||||
#
|
||||
# The flavour determines which RPM is installed on a given neuron host:
|
||||
# helexa-neuron-<flavour>. Only one flavour may be installed at a time
|
||||
# (the packages Conflict: with each other).
|
||||
|
||||
cortex:
|
||||
host: hanzalova.internal
|
||||
|
||||
neurons:
|
||||
- host: beast.hanzalova.internal
|
||||
flavour: blackwell
|
||||
gpu: "2x RTX 5090"
|
||||
|
||||
- host: benjy.hanzalova.internal
|
||||
flavour: ada
|
||||
gpu: "RTX 4090"
|
||||
|
||||
- host: quadbrat.hanzalova.internal
|
||||
flavour: ampere
|
||||
gpu: "RTX 3060"
|
||||
24
asset/neuron/beast.toml
Normal file
24
asset/neuron/beast.toml
Normal file
@@ -0,0 +1,24 @@
|
||||
# neuron.toml for beast.hanzalova.internal
|
||||
#
|
||||
# 2x RTX 5090 (32 GB each) — TP-2 capable. Pre-warms Qwen3.6-27B with
|
||||
# q5k ISQ across both GPUs at activation, matching the validate-neuron
|
||||
# invocation: `validate-neuron.sh beast.hanzalova.internal
|
||||
# Qwen/Qwen3.6-27B q5k 2`.
|
||||
#
|
||||
# Synced by script/deploy.sh from asset/neuron/<short-host>.toml. Edits
|
||||
# take effect on the next deploy.sh run (which stops + restarts the
|
||||
# service so default_models is re-read at activation).
|
||||
|
||||
port = 13131
|
||||
|
||||
[[harnesses]]
|
||||
name = "candle"
|
||||
|
||||
[harness.candle]
|
||||
|
||||
[[default_models]]
|
||||
model_id = "Qwen/Qwen3.6-27B"
|
||||
harness = "candle"
|
||||
quant = "q6k"
|
||||
tensor_parallel = 2
|
||||
devices = [0, 1]
|
||||
19
asset/neuron/benjy.toml
Normal file
19
asset/neuron/benjy.toml
Normal file
@@ -0,0 +1,19 @@
|
||||
# neuron.toml for benjy.hanzalova.internal
|
||||
#
|
||||
# 1x RTX 4090 (24 GB) — largest single-GPU host on the fleet. Pre-warms
|
||||
# Qwen3-8B (bf16, ~18 GB), leaving ~6 GB for KV cache + activations on
|
||||
# moderate-length contexts.
|
||||
#
|
||||
# Synced by script/deploy.sh from asset/neuron/<short-host>.toml.
|
||||
|
||||
port = 13131
|
||||
|
||||
[[harnesses]]
|
||||
name = "candle"
|
||||
|
||||
[harness.candle]
|
||||
|
||||
[[default_models]]
|
||||
model_id = "Qwen/Qwen3-8B"
|
||||
harness = "candle"
|
||||
devices = [0]
|
||||
19
asset/neuron/quadbrat.toml
Normal file
19
asset/neuron/quadbrat.toml
Normal file
@@ -0,0 +1,19 @@
|
||||
# neuron.toml for quadbrat.hanzalova.internal
|
||||
#
|
||||
# 1x RTX 3060 (12 GB) — small / quantised tier. Pre-warms Qwen3-1.7B
|
||||
# (bf16, ~4 GB), leaving ~7 GB for KV cache so long contexts on a small
|
||||
# model still have plenty of room.
|
||||
#
|
||||
# Synced by script/deploy.sh from asset/neuron/<short-host>.toml.
|
||||
|
||||
port = 13131
|
||||
|
||||
[[harnesses]]
|
||||
name = "candle"
|
||||
|
||||
[harness.candle]
|
||||
|
||||
[[default_models]]
|
||||
model_id = "Qwen/Qwen3-1.7B"
|
||||
harness = "candle"
|
||||
devices = [0]
|
||||
@@ -11,14 +11,14 @@ metrics_listen = "0.0.0.0:31314"
|
||||
|
||||
[eviction]
|
||||
strategy = "lru"
|
||||
# Restart mistralrs after this many load/unload cycles to defragment VRAM.
|
||||
# Restart neurons after this many load/unload cycles to defragment VRAM.
|
||||
# Set to 0 to disable.
|
||||
defrag_after_cycles = 50
|
||||
|
||||
# -- Nodes ---------------------------------------------------------------
|
||||
# Each [[nodes]] entry declares a mistral.rs instance in the fleet.
|
||||
# Models are discovered by polling the node's /v1/models endpoint.
|
||||
# Pinned models are never evicted.
|
||||
# Each [[nodes]] entry declares a neuron daemon in the fleet.
|
||||
# Models are discovered by polling the neuron's /models endpoint.
|
||||
# Pinned models (see models.toml) are never evicted.
|
||||
|
||||
[[nodes]]
|
||||
name = "gpu-large"
|
||||
|
||||
36
cortex.spec
36
cortex.spec
@@ -1,5 +1,5 @@
|
||||
Name: cortex
|
||||
Version: 0.1.14
|
||||
Version: 0.1.16
|
||||
Release: 1%{?dist}
|
||||
Summary: Inference gateway for multi-node GPU clusters
|
||||
|
||||
@@ -21,6 +21,7 @@ BuildRequires: systemd-rpm-macros
|
||||
|
||||
Requires(pre): shadow-utils
|
||||
Requires: systemd
|
||||
Requires: firewalld-filesystem
|
||||
|
||||
# systemd-rpm-macros ships a unit dep generator that parses User=/Group=
|
||||
# from our .service file and emits Requires: user(cortex)/group(cortex).
|
||||
@@ -56,6 +57,7 @@ cargo build --release -p cortex-cli
|
||||
install -Dm755 target/release/cortex %{buildroot}%{_bindir}/cortex
|
||||
install -Dm644 data/cortex.service %{buildroot}%{_unitdir}/cortex.service
|
||||
install -Dm644 data/cortex-sysusers.conf %{buildroot}%{_sysusersdir}/cortex.conf
|
||||
install -Dm644 data/cortex-firewalld.xml %{buildroot}%{_prefix}/lib/firewalld/services/cortex.xml
|
||||
install -dm755 %{buildroot}%{_sysconfdir}/cortex
|
||||
install -Dm644 cortex.example.toml %{buildroot}%{_sysconfdir}/cortex/cortex.toml
|
||||
install -Dm644 models.example.toml %{buildroot}%{_sysconfdir}/cortex/models.toml
|
||||
@@ -72,17 +74,49 @@ install -Dm644 models.example.toml %{buildroot}%{_sysconfdir}/cortex/models.toml
|
||||
%postun
|
||||
%systemd_postun_with_restart cortex.service
|
||||
|
||||
%posttrans
|
||||
# Migration: older cortex packages shipped the firewalld service as
|
||||
# `helexa-cortex` and (in some build streams) with wrong port numbers
|
||||
# (9301/9302/9304). Operators who enabled that legacy service in their
|
||||
# zone end up with the wrong-port override taking precedence over the
|
||||
# vendor `cortex.xml` now in /usr/lib/firewalld/services/. Clean up the
|
||||
# stale /etc/ override here and migrate any zone bindings to the new
|
||||
# service name.
|
||||
if [ -f /etc/firewalld/services/helexa-cortex.xml ]; then
|
||||
rm -f /etc/firewalld/services/helexa-cortex.xml
|
||||
fi
|
||||
if [ -x /usr/bin/firewall-cmd ] && /usr/bin/firewall-cmd --state >/dev/null 2>&1; then
|
||||
# Drop the legacy service name from every zone where it was enabled
|
||||
# and add the new `cortex` service in its place. Operators who never
|
||||
# ran firewall-cmd against either name see no zone change.
|
||||
for zone in $(/usr/bin/firewall-cmd --get-active-zones 2>/dev/null \
|
||||
| awk '!/^[[:space:]]/ {print $1}'); do
|
||||
if /usr/bin/firewall-cmd --permanent --zone="$zone" --query-service=helexa-cortex >/dev/null 2>&1; then
|
||||
/usr/bin/firewall-cmd --permanent --zone="$zone" --remove-service=helexa-cortex >/dev/null 2>&1 || :
|
||||
/usr/bin/firewall-cmd --permanent --zone="$zone" --add-service=cortex >/dev/null 2>&1 || :
|
||||
fi
|
||||
done
|
||||
/usr/bin/firewall-cmd --reload >/dev/null 2>&1 || :
|
||||
fi
|
||||
:
|
||||
|
||||
%files
|
||||
%license LICENSE
|
||||
%doc README.md
|
||||
%{_bindir}/cortex
|
||||
%{_unitdir}/cortex.service
|
||||
%{_sysusersdir}/cortex.conf
|
||||
%{_prefix}/lib/firewalld/services/cortex.xml
|
||||
%dir %{_sysconfdir}/cortex
|
||||
%config(noreplace) %{_sysconfdir}/cortex/cortex.toml
|
||||
%config(noreplace) %{_sysconfdir}/cortex/models.toml
|
||||
|
||||
%changelog
|
||||
* Thu Apr 16 2026 Gitea Actions <actions@git.lair.cafe> - 0.1.16-1
|
||||
- chore: ignore local deploy script
|
||||
- chore: move default ports out of common-collision ranges
|
||||
- ci: drop actions/cache for cargo registry and target
|
||||
|
||||
* Thu Apr 16 2026 Gitea Actions <actions@git.lair.cafe> - 0.1.14-1
|
||||
- ci: publish both packages to a single helexa/helexa COPR project
|
||||
- fix(rpm): rename neuron package to helexa-neuron
|
||||
|
||||
@@ -5,7 +5,7 @@ use tracing_subscriber::EnvFilter;
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(name = "cortex")]
|
||||
#[command(about = "Unified inference gateway for multi-node mistral.rs clusters")]
|
||||
#[command(about = "Unified inference gateway for multi-node GPU clusters")]
|
||||
#[command(version)]
|
||||
struct Cli {
|
||||
#[command(subcommand)]
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
//!
|
||||
//! These mirror the `/v1/messages` format used by the Anthropic API.
|
||||
//! The gateway accepts these, translates to OpenAI format, proxies to
|
||||
//! mistral.rs, then translates the response back.
|
||||
//! the inference backend (neuron), then translates the response back.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
//! Model catalogue — profiles describing how to serve each model.
|
||||
|
||||
use crate::discovery::DeviceInfo;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
|
||||
/// A model serving profile loaded from models.toml.
|
||||
@@ -33,6 +35,14 @@ fn default_min_devices() -> u32 {
|
||||
pub struct ModelCatalogue {
|
||||
#[serde(default)]
|
||||
pub models: Vec<ModelProfile>,
|
||||
/// Tier aliases — clients can send a request with `model: "helexa/small"`
|
||||
/// and the gateway transparently rewrites + routes to the concrete
|
||||
/// model id this maps to. Lets operators define latency/quality
|
||||
/// tiers (`small`/`balanced`/`large`, `fast`/`thinking`, etc.)
|
||||
/// without imposing knowledge of specific model ids on clients.
|
||||
/// Loaded from the `[aliases]` table in models.toml.
|
||||
#[serde(default)]
|
||||
pub aliases: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl ModelCatalogue {
|
||||
@@ -64,4 +74,138 @@ impl ModelCatalogue {
|
||||
.iter()
|
||||
.any(|p| p.id == model_id && p.pinned_on.contains(&neuron_name.to_string()))
|
||||
}
|
||||
|
||||
/// Find a profile by model id.
|
||||
pub fn get(&self, model_id: &str) -> Option<&ModelProfile> {
|
||||
self.models.iter().find(|p| p.id == model_id)
|
||||
}
|
||||
|
||||
/// Resolve an alias to its concrete model id. Returns `id` verbatim
|
||||
/// when it isn't an alias. Aliases never chain — operator config
|
||||
/// is treated as flat — so this is a single lookup.
|
||||
pub fn resolve_alias<'a>(&'a self, id: &'a str) -> &'a str {
|
||||
self.aliases.get(id).map(String::as_str).unwrap_or(id)
|
||||
}
|
||||
}
|
||||
|
||||
impl ModelProfile {
|
||||
/// True iff this profile's placement constraints can be satisfied
|
||||
/// by the named neuron with the given device topology.
|
||||
///
|
||||
/// Constraints checked:
|
||||
/// - `pinned_on`: non-empty → neuron must be on the list.
|
||||
/// - `min_devices`: neuron must have at least this many devices.
|
||||
/// - `min_device_vram_mb`: at least `min_devices` of the neuron's
|
||||
/// devices must each meet this VRAM floor.
|
||||
pub fn is_feasible_on(&self, neuron_name: &str, devices: &[DeviceInfo]) -> bool {
|
||||
if !self.pinned_on.is_empty() && !self.pinned_on.iter().any(|n| n == neuron_name) {
|
||||
return false;
|
||||
}
|
||||
if (devices.len() as u32) < self.min_devices {
|
||||
return false;
|
||||
}
|
||||
if let Some(min_vram) = self.min_device_vram_mb {
|
||||
let big_enough = devices
|
||||
.iter()
|
||||
.filter(|d| d.vram_total_mb >= min_vram)
|
||||
.count() as u32;
|
||||
if big_enough < self.min_devices {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::discovery::DeviceInfo;
|
||||
|
||||
fn device(idx: u32, vram_mb: u64) -> DeviceInfo {
|
||||
DeviceInfo {
|
||||
index: idx,
|
||||
name: format!("DEV-{idx}"),
|
||||
vram_total_mb: vram_mb,
|
||||
compute_capability: "8.6".into(),
|
||||
}
|
||||
}
|
||||
|
||||
fn profile() -> ModelProfile {
|
||||
ModelProfile {
|
||||
id: "Qwen/Qwen3.6-27B".into(),
|
||||
harness: "candle".into(),
|
||||
quant: None,
|
||||
vram_mb: Some(45_000),
|
||||
min_devices: 2,
|
||||
min_device_vram_mb: Some(24_000),
|
||||
pinned_on: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn feasible_when_two_devices_meet_vram_floor() {
|
||||
let p = profile();
|
||||
let devices = [device(0, 32_000), device(1, 32_000)];
|
||||
assert!(p.is_feasible_on("beast", &devices));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn infeasible_when_only_one_device() {
|
||||
let p = profile();
|
||||
let devices = [device(0, 64_000)];
|
||||
assert!(!p.is_feasible_on("benjy", &devices));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn infeasible_when_one_device_underspec() {
|
||||
let p = profile();
|
||||
let devices = [device(0, 32_000), device(1, 12_000)];
|
||||
assert!(!p.is_feasible_on("mixed", &devices));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pinned_on_excludes_other_neurons() {
|
||||
let mut p = profile();
|
||||
p.pinned_on = vec!["beast".into()];
|
||||
let devices = [device(0, 32_000), device(1, 32_000)];
|
||||
assert!(p.is_feasible_on("beast", &devices));
|
||||
assert!(!p.is_feasible_on("benjy", &devices));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_vram_floor_just_needs_min_devices() {
|
||||
let mut p = profile();
|
||||
p.min_device_vram_mb = None;
|
||||
let devices = [device(0, 1_000), device(1, 1_000)];
|
||||
assert!(p.is_feasible_on("anywhere", &devices));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_alias_returns_target_when_alias_present() {
|
||||
let mut cat = ModelCatalogue::default();
|
||||
cat.aliases
|
||||
.insert("helexa/small".into(), "Qwen/Qwen3-1.7B".into());
|
||||
assert_eq!(cat.resolve_alias("helexa/small"), "Qwen/Qwen3-1.7B");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_alias_passes_through_when_not_an_alias() {
|
||||
let mut cat = ModelCatalogue::default();
|
||||
cat.aliases
|
||||
.insert("helexa/small".into(), "Qwen/Qwen3-1.7B".into());
|
||||
assert_eq!(cat.resolve_alias("Qwen/Qwen3-8B"), "Qwen/Qwen3-8B");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn aliases_table_round_trips_through_toml() {
|
||||
let src = r#"
|
||||
[aliases]
|
||||
"helexa/small" = "Qwen/Qwen3-1.7B"
|
||||
"helexa/large" = "Qwen/Qwen3.6-27B"
|
||||
"#;
|
||||
let cat: ModelCatalogue = toml::from_str(src).expect("parse aliases table");
|
||||
assert_eq!(cat.resolve_alias("helexa/small"), "Qwen/Qwen3-1.7B");
|
||||
assert_eq!(cat.resolve_alias("helexa/large"), "Qwen/Qwen3.6-27B");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,8 +36,72 @@ pub struct DeviceHealth {
|
||||
|
||||
/// Runtime health response from a neuron endpoint.
|
||||
/// Returned by `GET /health`.
|
||||
///
|
||||
/// `activation` was added in 2026-05-26 to distinguish "process is up
|
||||
/// and reachable" from "process is ready to serve traffic". A `Type=simple`
|
||||
/// systemd unit reports `active` the moment the binary starts — but a
|
||||
/// neuron whose `default_models` list takes minutes to materialise
|
||||
/// won't bind its listener (or, in the new flow, won't have any models
|
||||
/// loaded) until pre-warm completes. The new field is `#[serde(default)]`
|
||||
/// so a pre-2026-05-26 gateway polling a new neuron — or vice versa —
|
||||
/// keeps working.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HealthResponse {
|
||||
pub uptime_secs: u64,
|
||||
pub devices: Vec<DeviceHealth>,
|
||||
#[serde(default)]
|
||||
pub activation: ActivationStatus,
|
||||
}
|
||||
|
||||
/// High-level activation state of the neuron daemon. The HTTP listener
|
||||
/// is bound during both states; what differs is whether the configured
|
||||
/// `default_models` have finished loading.
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ActivationState {
|
||||
/// At least one `default_models` entry is still loading. The
|
||||
/// neuron's other endpoints work, but inference against
|
||||
/// not-yet-loaded models will 404.
|
||||
PreWarming,
|
||||
/// Every `default_models` entry has either loaded or failed; the
|
||||
/// neuron is steady-state. Subsequent on-demand loads via
|
||||
/// `/models/load` don't flip back to PreWarming — that field
|
||||
/// reflects the activation-time set only.
|
||||
#[default]
|
||||
Ready,
|
||||
}
|
||||
|
||||
/// Per-model failure record surfaced in [`ActivationStatus::failed`].
|
||||
/// The error string is the rendered anyhow chain at the time of the
|
||||
/// failure; operators read it from `/health` to decide whether to
|
||||
/// retry, edit the spec, or unload+reload.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PreWarmFailure {
|
||||
pub model_id: String,
|
||||
pub error: String,
|
||||
}
|
||||
|
||||
/// Activation-time progress snapshot. All four lists are populated by
|
||||
/// the neuron's pre-warm task and read by the `/health` handler. The
|
||||
/// snapshot is consistent: a model id appears in exactly one of
|
||||
/// `pending`, `in_progress` (as `Option<String>`), `completed`, or
|
||||
/// `failed` at any point in time.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct ActivationStatus {
|
||||
pub state: ActivationState,
|
||||
/// Model ids queued but not yet started. Empty in `Ready` state.
|
||||
#[serde(default)]
|
||||
pub pending: Vec<String>,
|
||||
/// Model id currently materialising. None when between models or
|
||||
/// in `Ready` state.
|
||||
#[serde(default)]
|
||||
pub in_progress: Option<String>,
|
||||
/// Model ids that finished loading successfully during this
|
||||
/// activation. Cleared on process restart.
|
||||
#[serde(default)]
|
||||
pub completed: Vec<String>,
|
||||
/// Model ids that failed during this activation, with the rendered
|
||||
/// error chain. Cleared on process restart.
|
||||
#[serde(default)]
|
||||
pub failed: Vec<PreWarmFailure>,
|
||||
}
|
||||
|
||||
@@ -9,13 +9,13 @@ use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Configuration for a harness instance on a neuron.
|
||||
///
|
||||
/// All current harnesses are in-process (candle); per-harness tuning
|
||||
/// (cache paths, device policies, etc.) lives in dedicated config
|
||||
/// blocks rather than on this struct.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HarnessConfig {
|
||||
pub name: String,
|
||||
/// Base URL of the harness (e.g. "http://localhost:8080" for mistral.rs).
|
||||
pub endpoint: Option<String>,
|
||||
/// Systemd unit name, if the harness is managed via systemd.
|
||||
pub systemd_unit: Option<String>,
|
||||
}
|
||||
|
||||
/// Health status of a harness process.
|
||||
@@ -47,16 +47,24 @@ pub struct ModelInfo {
|
||||
}
|
||||
|
||||
/// What an inference harness must do, from neuron's perspective.
|
||||
///
|
||||
/// All current harnesses are in-process — they share neuron's address
|
||||
/// space and lifecycle. `start`/`stop` therefore default to no-ops; a
|
||||
/// future process-supervising harness would override them.
|
||||
#[async_trait]
|
||||
pub trait Harness: Send + Sync {
|
||||
/// Human-readable name (e.g. "mistralrs", "llamacpp", "comfyui").
|
||||
/// Human-readable name (e.g. "candle").
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// Start the harness process if it is not already running.
|
||||
async fn start(&self, config: &HarnessConfig) -> Result<()>;
|
||||
/// Start the harness. Default no-op for in-process harnesses.
|
||||
async fn start(&self, _config: &HarnessConfig) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Stop the harness process gracefully.
|
||||
async fn stop(&self) -> Result<()>;
|
||||
/// Stop the harness. Default no-op for in-process harnesses.
|
||||
async fn stop(&self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Health check. Returns the harness process status.
|
||||
async fn health(&self) -> HarnessHealth;
|
||||
|
||||
@@ -6,4 +6,5 @@ pub mod harness;
|
||||
pub mod metrics;
|
||||
pub mod node;
|
||||
pub mod openai;
|
||||
pub mod responses;
|
||||
pub mod translate;
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use crate::discovery::{ActivationStatus, DiscoveryResponse};
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
@@ -13,6 +14,18 @@ pub struct NodeState {
|
||||
/// Number of load/unload cycles since last process restart.
|
||||
pub lifecycle_cycles: u32,
|
||||
pub last_poll: Option<DateTime<Utc>>,
|
||||
/// Result of the most recent successful `GET /discovery` against
|
||||
/// this neuron. Cached forever once obtained — device topology is
|
||||
/// invariant for a given neuron process. `None` until the first
|
||||
/// successful poll. Used by the router and `/v1/models` to do
|
||||
/// catalogue × topology feasibility checks.
|
||||
pub discovery: Option<DiscoveryResponse>,
|
||||
/// Last-seen pre-warm progress from this neuron's `/health`
|
||||
/// endpoint. `None` until the first /health poll succeeds. The
|
||||
/// `/v1/models` handler reads `in_progress` + `pending` from here
|
||||
/// to synthesize `Loading` locations so clients see a catalogued
|
||||
/// model that's mid-prewarm as "loading", not "missing".
|
||||
pub activation: Option<ActivationStatus>,
|
||||
}
|
||||
|
||||
/// A model registered on a node, with its runtime status.
|
||||
@@ -27,21 +40,50 @@ pub struct ModelEntry {
|
||||
}
|
||||
|
||||
/// Model lifecycle status.
|
||||
///
|
||||
/// `Loading` is a gateway-side synthetic status: neurons never emit it
|
||||
/// on `/models` (that endpoint only knows about already-loaded handles).
|
||||
/// The gateway populates it from a neuron's `/health` activation
|
||||
/// snapshot so the unified `/v1/models` can distinguish "model is
|
||||
/// catalogued but no one has it" from "model is materialising on
|
||||
/// neuron N right now". Other status values are reported verbatim by
|
||||
/// neurons.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ModelStatus {
|
||||
Loaded,
|
||||
Unloaded,
|
||||
Reloading,
|
||||
Loading,
|
||||
}
|
||||
|
||||
/// Unified model entry as exposed by the gateway's `/v1/models` endpoint.
|
||||
/// Includes which node(s) host this model and their status.
|
||||
///
|
||||
/// The first four fields (`id`, `object`, `created`, `owned_by`) match
|
||||
/// OpenAI's `/v1/models` shape verbatim, so existing OpenAI-aware
|
||||
/// tooling deserialises this without custom code. The remaining fields
|
||||
/// are helexa-specific extensions — OpenAI clients ignore unknown
|
||||
/// fields and other consumers can read them for placement / debugging.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CortexModelEntry {
|
||||
pub id: String,
|
||||
/// Always `"model"` per OpenAI's contract.
|
||||
pub object: String,
|
||||
/// Which nodes have this model (and their status).
|
||||
/// Unix-second timestamp; cortex stamps this at response time.
|
||||
pub created: u64,
|
||||
/// OpenAI's "publisher" field — `"helexa"` for everything we serve.
|
||||
pub owned_by: String,
|
||||
/// True if any neuron currently has this model loaded. False for
|
||||
/// catalogue entries that are feasible but not yet loaded.
|
||||
pub loaded: bool,
|
||||
/// Neurons whose discovered topology can satisfy this model's
|
||||
/// catalogue placement constraints. Empty for models that are
|
||||
/// loaded somewhere but not present in the catalogue (cortex has
|
||||
/// no feasibility opinion on those).
|
||||
pub feasible_on: Vec<String>,
|
||||
/// Where this model is actually loaded right now. Subset of (or
|
||||
/// disjoint from) `feasible_on` depending on whether the catalogue
|
||||
/// covers this model.
|
||||
pub locations: Vec<ModelLocation>,
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
//! These are a subset sufficient for chat completions (streaming + non-streaming).
|
||||
//! Fields not relevant to proxying are captured as `serde_json::Value` via
|
||||
//! `#[serde(flatten)]` so we forward them without needing to enumerate every
|
||||
//! extension field mistral.rs supports.
|
||||
//! extension field a backend might support.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
@@ -22,7 +22,7 @@ pub struct ChatCompletionRequest {
|
||||
pub max_tokens: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stream: Option<bool>,
|
||||
/// All other fields (tools, response_format, mistral.rs extensions, etc.)
|
||||
/// All other fields (tools, response_format, backend extensions, etc.)
|
||||
#[serde(flatten)]
|
||||
pub extra: Value,
|
||||
}
|
||||
|
||||
346
crates/cortex-core/src/responses.rs
Normal file
346
crates/cortex-core/src/responses.rs
Normal file
@@ -0,0 +1,346 @@
|
||||
//! OpenAI Responses API (`POST /v1/responses`) envelope types.
|
||||
//!
|
||||
//! This is OpenAI's newer chat surface, distinct from
|
||||
//! `/v1/chat/completions` in three ways that matter for us:
|
||||
//!
|
||||
//! 1. **Input shape**. Instead of a `messages` array, the request
|
||||
//! carries `input` — either a plain string (single user turn)
|
||||
//! or an array of typed items (messages, function calls,
|
||||
//! function-call outputs, reasoning blocks, …).
|
||||
//! 2. **Output shape**. The response carries a single `output`
|
||||
//! array of items, each typed. We always emit one
|
||||
//! `OutputItem::Message` containing the assistant's reply (plus,
|
||||
//! when we get there, separate `function_call` items).
|
||||
//! 3. **Streaming events**. Where chat completions stream
|
||||
//! structurally-identical `chat.completion.chunk` frames over
|
||||
//! `data:` lines, Responses streams *named* events
|
||||
//! (`response.created`, `response.output_text.delta`,
|
||||
//! `response.completed`, …) over `event:` + `data:` SSE pairs.
|
||||
//! The wire projector in `neuron::wire::openai_responses` builds
|
||||
//! these from the same [`crate::openai`]-shaped
|
||||
//! `InferenceEvent` stream the chat projector consumes.
|
||||
//!
|
||||
//! Scope cuts for this first cut:
|
||||
//!
|
||||
//! - **`previous_response_id` is rejected at parse time**. Stateful
|
||||
//! chained conversations need a persistence layer we don't have.
|
||||
//! - **Reasoning items are accepted-and-ignored** (no Qwen3
|
||||
//! `<think>` routing yet). Audio and embedded resources are
|
||||
//! rejected as unsupported.
|
||||
//! - **Tool calls** (function_call / function_call_output) are
|
||||
//! carried as round-trip types but the candle harness doesn't
|
||||
//! emit them yet — wired so the surface is in place for the
|
||||
//! day we add proper tool-call extraction.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
// ── Request ──────────────────────────────────────────────────────────
|
||||
|
||||
/// Body of a `POST /v1/responses` request.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ResponsesRequest {
|
||||
pub model: String,
|
||||
pub input: ResponsesInput,
|
||||
/// System-prompt-style instructions. The Responses API
|
||||
/// separates these from input so a caller doesn't have to
|
||||
/// build a `system` message item by hand.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub instructions: Option<String>,
|
||||
#[serde(default)]
|
||||
pub stream: bool,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub max_output_tokens: Option<u64>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f64>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f64>,
|
||||
/// Chained-conversation identifier. We don't store responses
|
||||
/// server-side yet; if this is `Some`, the handler returns 400.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub previous_response_id: Option<String>,
|
||||
/// Catch-all for anything we don't model yet (tools, tool_choice,
|
||||
/// reasoning, response_format, …). Lets a client send a
|
||||
/// forward-compatible request without our parser rejecting it.
|
||||
#[serde(flatten)]
|
||||
pub extra: Value,
|
||||
}
|
||||
|
||||
/// `input` is either a single string or an array of typed items.
|
||||
/// `#[serde(untagged)]` so the wire shape `"input": "hi"` and
|
||||
/// `"input": [{...}]` both deserialize.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ResponsesInput {
|
||||
Text(String),
|
||||
Items(Vec<ResponsesInputItem>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ResponsesInputItem {
|
||||
/// A user / assistant / system turn.
|
||||
Message {
|
||||
role: String,
|
||||
content: ResponsesMessageContent,
|
||||
},
|
||||
/// Assistant emitted a tool call. Round-trip only — neuron
|
||||
/// doesn't synthesise these yet.
|
||||
FunctionCall {
|
||||
call_id: String,
|
||||
name: String,
|
||||
arguments: String,
|
||||
},
|
||||
/// User is feeding a tool result back into the model.
|
||||
FunctionCallOutput { call_id: String, output: String },
|
||||
/// Reasoning items emitted by o-series models. Accepted but
|
||||
/// not forwarded to the model — neuron's candle path doesn't
|
||||
/// surface reasoning separately yet.
|
||||
Reasoning {
|
||||
#[serde(default)]
|
||||
content: Vec<Value>,
|
||||
},
|
||||
}
|
||||
|
||||
/// Inside a `Message` item, content is either a plain string or an
|
||||
/// array of typed parts. Mirrors the chat-completions Parts shape.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ResponsesMessageContent {
|
||||
Text(String),
|
||||
Parts(Vec<ResponsesContentPart>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ResponsesContentPart {
|
||||
/// Plain text inside a user / system turn.
|
||||
InputText { text: String },
|
||||
/// An image. `image_url` is either a remote URL or a
|
||||
/// `data:image/png;base64,…` URI; the request translator just
|
||||
/// forwards the string.
|
||||
InputImage {
|
||||
image_url: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
detail: Option<String>,
|
||||
},
|
||||
/// Returned text inside an assistant turn — only relevant when
|
||||
/// the caller is feeding an assistant turn back in to continue
|
||||
/// a conversation manually (no `previous_response_id`).
|
||||
OutputText {
|
||||
text: String,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
annotations: Vec<Value>,
|
||||
},
|
||||
}
|
||||
|
||||
// ── Response (non-streaming) ─────────────────────────────────────────
|
||||
|
||||
/// Body of a `POST /v1/responses` response.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ResponsesResponse {
|
||||
pub id: String,
|
||||
/// Always `"response"`.
|
||||
pub object: String,
|
||||
pub created_at: u64,
|
||||
/// `"completed"`, `"incomplete"`, or — for the initial event of
|
||||
/// a streaming response — `"in_progress"`.
|
||||
pub status: String,
|
||||
pub model: String,
|
||||
pub output: Vec<ResponsesOutputItem>,
|
||||
/// Populated on completion; `None` while streaming.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub usage: Option<ResponsesUsage>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ResponsesOutputItem {
|
||||
Message {
|
||||
id: String,
|
||||
/// Always `"assistant"` for model output.
|
||||
role: String,
|
||||
/// Output content parts. We always emit a single
|
||||
/// `OutputText` today; multi-part output would land here
|
||||
/// once we have e.g. image generation.
|
||||
content: Vec<ResponsesOutputContent>,
|
||||
/// Item-level status. `"in_progress"` while streaming the
|
||||
/// content parts, `"completed"` when done.
|
||||
#[serde(default = "default_item_status")]
|
||||
status: String,
|
||||
},
|
||||
/// Reserved for the day tool-call extraction lands. The wire
|
||||
/// shape mirrors `ResponsesInputItem::FunctionCall`.
|
||||
FunctionCall {
|
||||
id: String,
|
||||
call_id: String,
|
||||
name: String,
|
||||
arguments: String,
|
||||
#[serde(default = "default_item_status")]
|
||||
status: String,
|
||||
},
|
||||
}
|
||||
|
||||
fn default_item_status() -> String {
|
||||
"completed".into()
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ResponsesOutputContent {
|
||||
OutputText {
|
||||
text: String,
|
||||
/// Citations / inline annotations. Empty today; reserved
|
||||
/// for the day we wire in web search / file search.
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
annotations: Vec<Value>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ResponsesUsage {
|
||||
pub input_tokens: u64,
|
||||
pub output_tokens: u64,
|
||||
pub total_tokens: u64,
|
||||
}
|
||||
|
||||
// ── Streaming event names ────────────────────────────────────────────
|
||||
|
||||
/// Event names the SSE projector emits, hoisted as constants so
|
||||
/// the projector and the wire shape stay in sync without
|
||||
/// string-typos. The strings are dictated by OpenAI's published
|
||||
/// Responses API.
|
||||
pub mod events {
|
||||
pub const CREATED: &str = "response.created";
|
||||
/// Fired between `response.created` and the first output-item
|
||||
/// event. Marks "request validated, model is generating" —
|
||||
/// some clients use it to differentiate the "warming up" state
|
||||
/// from "streaming tokens" in their UI.
|
||||
pub const IN_PROGRESS: &str = "response.in_progress";
|
||||
pub const OUTPUT_ITEM_ADDED: &str = "response.output_item.added";
|
||||
pub const CONTENT_PART_ADDED: &str = "response.content_part.added";
|
||||
pub const OUTPUT_TEXT_DELTA: &str = "response.output_text.delta";
|
||||
pub const OUTPUT_TEXT_DONE: &str = "response.output_text.done";
|
||||
pub const CONTENT_PART_DONE: &str = "response.content_part.done";
|
||||
pub const OUTPUT_ITEM_DONE: &str = "response.output_item.done";
|
||||
pub const COMPLETED: &str = "response.completed";
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn deserialises_input_string_form() {
|
||||
let raw = r#"{"model": "m", "input": "hello"}"#;
|
||||
let req: ResponsesRequest = serde_json::from_str(raw).unwrap();
|
||||
match req.input {
|
||||
ResponsesInput::Text(s) => assert_eq!(s, "hello"),
|
||||
other => panic!("expected Text, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialises_input_items_form() {
|
||||
let raw = r#"{
|
||||
"model": "m",
|
||||
"input": [
|
||||
{"type": "message", "role": "user", "content": "hi"}
|
||||
]
|
||||
}"#;
|
||||
let req: ResponsesRequest = serde_json::from_str(raw).unwrap();
|
||||
match req.input {
|
||||
ResponsesInput::Items(items) => {
|
||||
assert_eq!(items.len(), 1);
|
||||
match &items[0] {
|
||||
ResponsesInputItem::Message { role, content } => {
|
||||
assert_eq!(role, "user");
|
||||
match content {
|
||||
ResponsesMessageContent::Text(t) => assert_eq!(t, "hi"),
|
||||
other => panic!("expected Text content, got {other:?}"),
|
||||
}
|
||||
}
|
||||
other => panic!("expected Message item, got {other:?}"),
|
||||
}
|
||||
}
|
||||
other => panic!("expected Items, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialises_input_with_image() {
|
||||
let raw = r#"{
|
||||
"model": "m",
|
||||
"input": [
|
||||
{"type": "message", "role": "user", "content": [
|
||||
{"type": "input_text", "text": "what is this"},
|
||||
{"type": "input_image", "image_url": "data:image/png;base64,AAA="}
|
||||
]}
|
||||
]
|
||||
}"#;
|
||||
let req: ResponsesRequest = serde_json::from_str(raw).unwrap();
|
||||
let items = match req.input {
|
||||
ResponsesInput::Items(i) => i,
|
||||
other => panic!("expected Items, got {other:?}"),
|
||||
};
|
||||
let parts = match &items[0] {
|
||||
ResponsesInputItem::Message {
|
||||
content: ResponsesMessageContent::Parts(p),
|
||||
..
|
||||
} => p,
|
||||
other => panic!("expected Parts, got {other:?}"),
|
||||
};
|
||||
assert_eq!(parts.len(), 2);
|
||||
assert!(matches!(
|
||||
&parts[0],
|
||||
ResponsesContentPart::InputText { text } if text == "what is this"
|
||||
));
|
||||
assert!(matches!(
|
||||
&parts[1],
|
||||
ResponsesContentPart::InputImage { image_url, .. }
|
||||
if image_url == "data:image/png;base64,AAA="
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unknown_fields_round_trip_via_extra() {
|
||||
let raw = r#"{
|
||||
"model": "m",
|
||||
"input": "hi",
|
||||
"tools": [{"type": "web_search"}],
|
||||
"reasoning": {"effort": "medium"}
|
||||
}"#;
|
||||
let req: ResponsesRequest = serde_json::from_str(raw).unwrap();
|
||||
assert!(req.extra.get("tools").is_some());
|
||||
assert!(req.extra.get("reasoning").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_round_trips_through_serde() {
|
||||
let r = ResponsesResponse {
|
||||
id: "resp_1".into(),
|
||||
object: "response".into(),
|
||||
created_at: 1700,
|
||||
status: "completed".into(),
|
||||
model: "m".into(),
|
||||
output: vec![ResponsesOutputItem::Message {
|
||||
id: "msg_1".into(),
|
||||
role: "assistant".into(),
|
||||
content: vec![ResponsesOutputContent::OutputText {
|
||||
text: "hi there".into(),
|
||||
annotations: vec![],
|
||||
}],
|
||||
status: "completed".into(),
|
||||
}],
|
||||
usage: Some(ResponsesUsage {
|
||||
input_tokens: 5,
|
||||
output_tokens: 3,
|
||||
total_tokens: 8,
|
||||
}),
|
||||
};
|
||||
let json = serde_json::to_string(&r).unwrap();
|
||||
let parsed: ResponsesResponse = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed.id, "resp_1");
|
||||
assert_eq!(parsed.output.len(), 1);
|
||||
}
|
||||
}
|
||||
@@ -24,6 +24,7 @@ tokio-stream.workspace = true
|
||||
eventsource-stream.workspace = true
|
||||
bytes = "1"
|
||||
urlencoding = "2"
|
||||
url = "2"
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { workspace = true, features = ["test-util"] }
|
||||
|
||||
@@ -20,6 +20,7 @@ pub fn api_routes() -> Router<Arc<CortexState>> {
|
||||
Router::new()
|
||||
.route("/v1/chat/completions", post(chat_completions))
|
||||
.route("/v1/completions", post(completions))
|
||||
.route("/v1/responses", post(responses))
|
||||
.route("/v1/models", get(list_models))
|
||||
.route("/v1/messages", post(anthropic_messages))
|
||||
.route("/health", get(health))
|
||||
@@ -34,23 +35,94 @@ async fn chat_completions(
|
||||
) -> Response {
|
||||
let model_id = match extract_model(&body) {
|
||||
Some(m) => m,
|
||||
None => return error_response(400, "missing 'model' field in request body"),
|
||||
None => {
|
||||
tracing::warn!(
|
||||
handler = "chat_completions",
|
||||
"rejected: missing 'model' field in request body"
|
||||
);
|
||||
return error_response(400, "missing 'model' field in request body");
|
||||
}
|
||||
};
|
||||
|
||||
let route = match router::resolve(&fleet, &model_id).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => return error_response(404, &e.to_string()),
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
handler = "chat_completions",
|
||||
model = %model_id,
|
||||
error = %e,
|
||||
"route resolve failed"
|
||||
);
|
||||
// RouteError's Display strings are short and informative
|
||||
// ("model 'X' not found...", "no healthy nodes available")
|
||||
// — fine to surface to the caller. The warn above carries
|
||||
// any extra context for operators.
|
||||
return error_response(404, &e.to_string());
|
||||
}
|
||||
};
|
||||
|
||||
touch_model(&fleet, &route.node_name, &model_id).await;
|
||||
touch_model(&fleet, &route.node_name, &route.resolved_model_id).await;
|
||||
|
||||
let body = rewrite_model_in_body(body, &route.resolved_model_id);
|
||||
proxy_with_metrics(
|
||||
&fleet,
|
||||
&route,
|
||||
"/v1/chat/completions",
|
||||
headers,
|
||||
body,
|
||||
&model_id,
|
||||
&route.resolved_model_id,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// `POST /v1/responses` — proxy to the appropriate backend node.
|
||||
///
|
||||
/// Same routing shape as [`chat_completions`]: extract `model` from
|
||||
/// the body, resolve to a node, forward verbatim. No translation —
|
||||
/// neuron speaks the Responses API natively (see
|
||||
/// `crates/neuron/src/wire/openai_responses.rs`), so the gateway is
|
||||
/// a pass-through. Streaming and non-streaming are handled
|
||||
/// identically; the upstream `Content-Type` (text/event-stream vs.
|
||||
/// application/json) propagates through the proxy.
|
||||
async fn responses(
|
||||
State(fleet): State<Arc<CortexState>>,
|
||||
headers: HeaderMap,
|
||||
body: Bytes,
|
||||
) -> Response {
|
||||
let model_id = match extract_model(&body) {
|
||||
Some(m) => m,
|
||||
None => {
|
||||
tracing::warn!(
|
||||
handler = "responses",
|
||||
"rejected: missing 'model' field in request body"
|
||||
);
|
||||
return error_response(400, "missing 'model' field in request body");
|
||||
}
|
||||
};
|
||||
|
||||
let route = match router::resolve(&fleet, &model_id).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
handler = "responses",
|
||||
model = %model_id,
|
||||
error = %e,
|
||||
"route resolve failed"
|
||||
);
|
||||
return error_response(404, &e.to_string());
|
||||
}
|
||||
};
|
||||
|
||||
touch_model(&fleet, &route.node_name, &route.resolved_model_id).await;
|
||||
|
||||
let body = rewrite_model_in_body(body, &route.resolved_model_id);
|
||||
proxy_with_metrics(
|
||||
&fleet,
|
||||
&route,
|
||||
"/v1/responses",
|
||||
headers,
|
||||
body,
|
||||
&route.resolved_model_id,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -63,17 +135,44 @@ async fn completions(
|
||||
) -> Response {
|
||||
let model_id = match extract_model(&body) {
|
||||
Some(m) => m,
|
||||
None => return error_response(400, "missing 'model' field in request body"),
|
||||
None => {
|
||||
tracing::warn!(
|
||||
handler = "completions",
|
||||
"rejected: missing 'model' field in request body"
|
||||
);
|
||||
return error_response(400, "missing 'model' field in request body");
|
||||
}
|
||||
};
|
||||
|
||||
let route = match router::resolve(&fleet, &model_id).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => return error_response(404, &e.to_string()),
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
handler = "completions",
|
||||
model = %model_id,
|
||||
error = %e,
|
||||
"route resolve failed"
|
||||
);
|
||||
// RouteError's Display strings are short and informative
|
||||
// ("model 'X' not found...", "no healthy nodes available")
|
||||
// — fine to surface to the caller. The warn above carries
|
||||
// any extra context for operators.
|
||||
return error_response(404, &e.to_string());
|
||||
}
|
||||
};
|
||||
|
||||
touch_model(&fleet, &route.node_name, &model_id).await;
|
||||
touch_model(&fleet, &route.node_name, &route.resolved_model_id).await;
|
||||
|
||||
proxy_with_metrics(&fleet, &route, "/v1/completions", headers, body, &model_id).await
|
||||
let body = rewrite_model_in_body(body, &route.resolved_model_id);
|
||||
proxy_with_metrics(
|
||||
&fleet,
|
||||
&route,
|
||||
"/v1/completions",
|
||||
headers,
|
||||
body,
|
||||
&route.resolved_model_id,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// `POST /v1/messages` — accept Anthropic format, translate, proxy, translate back.
|
||||
@@ -85,7 +184,14 @@ async fn anthropic_messages(
|
||||
// Parse as Anthropic request.
|
||||
let anth_req: cortex_core::anthropic::MessagesRequest = match serde_json::from_slice(&body) {
|
||||
Ok(r) => r,
|
||||
Err(e) => return error_response(400, &format!("invalid Anthropic request: {e}")),
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
handler = "anthropic_messages",
|
||||
error = %e,
|
||||
"rejected: invalid Anthropic request body"
|
||||
);
|
||||
return error_response(400, "invalid Anthropic request body");
|
||||
}
|
||||
};
|
||||
|
||||
let model_id = anth_req.model.clone();
|
||||
@@ -95,18 +201,43 @@ async fn anthropic_messages(
|
||||
let openai_req = cortex_core::translate::anthropic_to_openai(anth_req);
|
||||
let openai_body = match serde_json::to_vec(&openai_req) {
|
||||
Ok(b) => Bytes::from(b),
|
||||
Err(e) => return error_response(500, &format!("translation error: {e}")),
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
handler = "anthropic_messages",
|
||||
model = %model_id,
|
||||
error = %e,
|
||||
"internal: failed to serialise translated OpenAI request"
|
||||
);
|
||||
return error_response(500, "internal translation error");
|
||||
}
|
||||
};
|
||||
|
||||
let route = match router::resolve(&fleet, &model_id).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => return error_response(404, &e.to_string()),
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
handler = "anthropic_messages",
|
||||
model = %model_id,
|
||||
error = %e,
|
||||
"route resolve failed"
|
||||
);
|
||||
// RouteError's Display strings are short and informative
|
||||
// ("model 'X' not found...", "no healthy nodes available")
|
||||
// — fine to surface to the caller. The warn above carries
|
||||
// any extra context for operators.
|
||||
return error_response(404, &e.to_string());
|
||||
}
|
||||
};
|
||||
|
||||
touch_model(&fleet, &route.node_name, &model_id).await;
|
||||
touch_model(&fleet, &route.node_name, &route.resolved_model_id).await;
|
||||
|
||||
// Swap the alias for the concrete id in the translated body so
|
||||
// neuron's harness sees a model name that matches what it has
|
||||
// loaded.
|
||||
let openai_body = rewrite_model_in_body(openai_body, &route.resolved_model_id);
|
||||
|
||||
let labels = [
|
||||
("model", model_id.clone()),
|
||||
("model", route.resolved_model_id.clone()),
|
||||
("node", route.node_name.clone()),
|
||||
];
|
||||
metrics::counter!("cortex_requests_total", &labels).increment(1);
|
||||
@@ -133,14 +264,25 @@ async fn anthropic_messages(
|
||||
Ok(resp) => resp,
|
||||
Err(e) => {
|
||||
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
||||
// forward_request already warn'd with the wire-level
|
||||
// detail; no need to log again here.
|
||||
e.into_response()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Non-streaming: proxy, buffer full response, translate back to Anthropic.
|
||||
let target_url = format!("{}/v1/chat/completions", route.endpoint);
|
||||
tracing::info!(
|
||||
handler = "anthropic_messages",
|
||||
model = %model_id,
|
||||
node = %route.node_name,
|
||||
url = %target_url,
|
||||
cold_start = route.cold_start,
|
||||
"proxying request"
|
||||
);
|
||||
let upstream_resp = fleet
|
||||
.http_client
|
||||
.post(format!("{}/v1/chat/completions", route.endpoint))
|
||||
.post(&target_url)
|
||||
.body(openai_body)
|
||||
.header("content-type", "application/json")
|
||||
.send()
|
||||
@@ -150,22 +292,49 @@ async fn anthropic_messages(
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
||||
return error_response(502, &format!("upstream request failed: {e}"));
|
||||
tracing::warn!(
|
||||
handler = "anthropic_messages",
|
||||
model = %model_id,
|
||||
node = %route.node_name,
|
||||
url = %target_url,
|
||||
error = %e,
|
||||
"upstream request failed (network)"
|
||||
);
|
||||
return error_response(502, "upstream request failed");
|
||||
}
|
||||
};
|
||||
|
||||
if !upstream_resp.status().is_success() {
|
||||
let upstream_status = upstream_resp.status();
|
||||
if !upstream_status.is_success() {
|
||||
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
||||
let status = upstream_resp.status().as_u16();
|
||||
let status = upstream_status.as_u16();
|
||||
let body = upstream_resp.text().await.unwrap_or_default();
|
||||
return error_response(status, &format!("upstream error: {body}"));
|
||||
let body_snippet = body.chars().take(512).collect::<String>();
|
||||
tracing::warn!(
|
||||
handler = "anthropic_messages",
|
||||
model = %model_id,
|
||||
node = %route.node_name,
|
||||
url = %target_url,
|
||||
status,
|
||||
body = %body_snippet,
|
||||
"upstream returned non-2xx"
|
||||
);
|
||||
return error_response(status, &format!("upstream returned {status}"));
|
||||
}
|
||||
|
||||
let body_bytes = match upstream_resp.bytes().await {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
||||
return error_response(502, &format!("failed to read upstream response: {e}"));
|
||||
tracing::warn!(
|
||||
handler = "anthropic_messages",
|
||||
model = %model_id,
|
||||
node = %route.node_name,
|
||||
url = %target_url,
|
||||
error = %e,
|
||||
"failed to read upstream response body"
|
||||
);
|
||||
return error_response(502, "failed to read upstream response");
|
||||
}
|
||||
};
|
||||
|
||||
@@ -174,7 +343,20 @@ async fn anthropic_messages(
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
||||
return error_response(502, &format!("failed to parse upstream response: {e}"));
|
||||
let body_snippet = String::from_utf8_lossy(&body_bytes)
|
||||
.chars()
|
||||
.take(512)
|
||||
.collect::<String>();
|
||||
tracing::warn!(
|
||||
handler = "anthropic_messages",
|
||||
model = %model_id,
|
||||
node = %route.node_name,
|
||||
url = %target_url,
|
||||
error = %e,
|
||||
body = %body_snippet,
|
||||
"failed to parse upstream response as OpenAI ChatCompletionResponse"
|
||||
);
|
||||
return error_response(502, "malformed upstream response");
|
||||
}
|
||||
};
|
||||
|
||||
@@ -185,12 +367,62 @@ async fn anthropic_messages(
|
||||
}
|
||||
}
|
||||
|
||||
/// `GET /v1/models` — aggregate models from all nodes.
|
||||
/// `GET /v1/models` — union of (catalogue × topology feasibility) and
|
||||
/// (currently loaded somewhere). The result is what the fleet *could*
|
||||
/// serve, not just what's already loaded — so OpenAI-compatible tools
|
||||
/// see every model the operator has provisioned, and cortex
|
||||
/// transparently cold-loads the first time one is requested.
|
||||
async fn list_models(State(fleet): State<Arc<CortexState>>) -> Json<Value> {
|
||||
use std::collections::HashMap;
|
||||
let now = Utc::now().timestamp() as u64;
|
||||
let nodes = fleet.nodes.read().await;
|
||||
let mut model_map: std::collections::HashMap<String, CortexModelEntry> =
|
||||
std::collections::HashMap::new();
|
||||
let catalogue = &fleet.catalogue;
|
||||
|
||||
let mut entries: HashMap<String, CortexModelEntry> = HashMap::new();
|
||||
|
||||
// Pass 1: catalogue × topology. For every catalogue profile, find
|
||||
// healthy neurons whose discovered devices satisfy the profile.
|
||||
// Catalogue-defined models surface here even if nothing has loaded
|
||||
// them yet — that's the point of the unified endpoint.
|
||||
for profile in &catalogue.models {
|
||||
let mut feasible_on = Vec::new();
|
||||
for node in nodes.values() {
|
||||
if !node.healthy {
|
||||
continue;
|
||||
}
|
||||
let Some(disc) = node.discovery.as_ref() else {
|
||||
continue;
|
||||
};
|
||||
if profile.is_feasible_on(&node.name, &disc.devices) {
|
||||
feasible_on.push(node.name.clone());
|
||||
}
|
||||
}
|
||||
if feasible_on.is_empty() {
|
||||
// The catalogue lists this model but no neuron's topology
|
||||
// matches — surface it as not-loaded with no feasible
|
||||
// location. Hides nothing; lets operators see why a
|
||||
// configured model isn't reachable.
|
||||
feasible_on.clear();
|
||||
}
|
||||
entries.insert(
|
||||
profile.id.clone(),
|
||||
CortexModelEntry {
|
||||
id: profile.id.clone(),
|
||||
object: "model".into(),
|
||||
created: now,
|
||||
owned_by: "helexa".into(),
|
||||
loaded: false,
|
||||
feasible_on,
|
||||
locations: Vec::new(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// Pass 2: layer the actually-loaded state on top. For each
|
||||
// (node, model) entry, attach a ModelLocation. If the model isn't
|
||||
// in the catalogue, create a new CortexModelEntry from scratch —
|
||||
// cortex doesn't refuse to surface a manually-loaded model just
|
||||
// because the operator didn't enumerate it in models.toml.
|
||||
for node in nodes.values() {
|
||||
for (model_id, entry) in &node.models {
|
||||
let location = ModelLocation {
|
||||
@@ -198,19 +430,108 @@ async fn list_models(State(fleet): State<Arc<CortexState>>) -> Json<Value> {
|
||||
status: entry.status,
|
||||
vram_estimate_mb: entry.vram_estimate_mb,
|
||||
};
|
||||
model_map
|
||||
let was_loaded = matches!(entry.status, cortex_core::node::ModelStatus::Loaded);
|
||||
entries
|
||||
.entry(model_id.clone())
|
||||
.and_modify(|e| e.locations.push(location.clone()))
|
||||
.and_modify(|e| {
|
||||
e.locations.push(location.clone());
|
||||
if was_loaded {
|
||||
e.loaded = true;
|
||||
}
|
||||
})
|
||||
.or_insert_with(|| CortexModelEntry {
|
||||
id: model_id.clone(),
|
||||
object: "model".into(),
|
||||
created: now,
|
||||
owned_by: "helexa".into(),
|
||||
loaded: was_loaded,
|
||||
// Not in catalogue — cortex has no opinion on
|
||||
// feasibility; leave empty.
|
||||
feasible_on: Vec::new(),
|
||||
locations: vec![location],
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let data: Vec<Value> = model_map.values().map(|e| json!(e)).collect();
|
||||
// Pass 3: surface pre-warming models. Each neuron's `/health`
|
||||
// activation snapshot (polled separately from /models) reports
|
||||
// `in_progress` (the model currently materialising) and `pending`
|
||||
// (queued behind it). Neither appears on the neuron's `/models`
|
||||
// yet — that endpoint only knows about fully-loaded handles — so
|
||||
// without this pass a client polling `/v1/models` during pre-warm
|
||||
// sees Qwen3.6-27B with no location and concludes "not there".
|
||||
// Synthesising a Loading location instead tells clients the model
|
||||
// is on its way. Idempotent against Pass 2: if a Loading location
|
||||
// for this node already exists (shouldn't, but be safe) we skip.
|
||||
for node in nodes.values() {
|
||||
let Some(activation) = node.activation.as_ref() else {
|
||||
continue;
|
||||
};
|
||||
let mut loading_ids: Vec<&str> = Vec::new();
|
||||
if let Some(id) = activation.in_progress.as_deref() {
|
||||
loading_ids.push(id);
|
||||
}
|
||||
for id in &activation.pending {
|
||||
loading_ids.push(id.as_str());
|
||||
}
|
||||
for model_id in loading_ids {
|
||||
let location = ModelLocation {
|
||||
node: node.name.clone(),
|
||||
status: cortex_core::node::ModelStatus::Loading,
|
||||
vram_estimate_mb: None,
|
||||
};
|
||||
entries
|
||||
.entry(model_id.to_string())
|
||||
.and_modify(|e| {
|
||||
let already = e.locations.iter().any(|l| {
|
||||
l.node == node.name && l.status == cortex_core::node::ModelStatus::Loading
|
||||
});
|
||||
if !already {
|
||||
e.locations.push(location.clone());
|
||||
}
|
||||
})
|
||||
.or_insert_with(|| CortexModelEntry {
|
||||
id: model_id.to_string(),
|
||||
object: "model".into(),
|
||||
created: now,
|
||||
owned_by: "helexa".into(),
|
||||
loaded: false,
|
||||
feasible_on: Vec::new(),
|
||||
locations: vec![location],
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Pass 4: surface aliases as their own entries pointing at the
|
||||
// same locations as the target id, so a client browsing /v1/models
|
||||
// sees "helexa/small" / "helexa/balanced" / "helexa/large" (or
|
||||
// whatever the operator defined) and can request inference
|
||||
// against them directly. Aliases that point at unknown targets
|
||||
// are skipped — surfacing a dead alias would be misleading.
|
||||
for (alias, target) in &catalogue.aliases {
|
||||
let Some(target_entry) = entries.get(target).cloned() else {
|
||||
tracing::warn!(
|
||||
alias = alias,
|
||||
target = target,
|
||||
"alias points at a model not present in catalogue or fleet; skipping"
|
||||
);
|
||||
continue;
|
||||
};
|
||||
entries.insert(
|
||||
alias.clone(),
|
||||
CortexModelEntry {
|
||||
id: alias.clone(),
|
||||
object: "model".into(),
|
||||
created: now,
|
||||
owned_by: "helexa".into(),
|
||||
loaded: target_entry.loaded,
|
||||
feasible_on: target_entry.feasible_on,
|
||||
locations: target_entry.locations,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
let data: Vec<Value> = entries.values().map(|e| json!(e)).collect();
|
||||
Json(json!({
|
||||
"object": "list",
|
||||
"data": data,
|
||||
@@ -265,6 +586,9 @@ async fn proxy_with_metrics(
|
||||
}
|
||||
Err(e) => {
|
||||
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
||||
// proxy::forward_request already warn'd with wire-level
|
||||
// detail (target URL, error, status). ProxyError::into_response
|
||||
// now returns a generic message — no body leak.
|
||||
e.into_response()
|
||||
}
|
||||
}
|
||||
@@ -285,6 +609,38 @@ fn extract_model(body: &[u8]) -> Option<String> {
|
||||
v.get("model")?.as_str().map(|s| s.to_string())
|
||||
}
|
||||
|
||||
/// Rewrite the `model` field of an OpenAI-style JSON request body to
|
||||
/// the resolved concrete id. Returns the original bytes if `new_model`
|
||||
/// matches what's already there or the body fails to parse — the
|
||||
/// caller has already extracted `model` via `extract_model`, so a
|
||||
/// parse failure here would only happen on a body the client crafted
|
||||
/// to defeat us, and we'd rather proxy it unchanged than 500.
|
||||
///
|
||||
/// Needed because neuron rejects requests whose `model` field doesn't
|
||||
/// match a loaded model, so a client that sends `model: "helexa/small"`
|
||||
/// would hit a 404 at the harness unless we swap it for the concrete
|
||||
/// id the alias resolved to.
|
||||
fn rewrite_model_in_body(body: Bytes, new_model: &str) -> Bytes {
|
||||
let Ok(mut v) = serde_json::from_slice::<Value>(&body) else {
|
||||
return body;
|
||||
};
|
||||
let needs_rewrite = v
|
||||
.get("model")
|
||||
.and_then(|m| m.as_str())
|
||||
.map(|m| m != new_model)
|
||||
.unwrap_or(false);
|
||||
if !needs_rewrite {
|
||||
return body;
|
||||
}
|
||||
if let Value::Object(obj) = &mut v {
|
||||
obj.insert("model".into(), Value::String(new_model.to_string()));
|
||||
}
|
||||
match serde_json::to_vec(&v) {
|
||||
Ok(bytes) => Bytes::from(bytes),
|
||||
Err(_) => body,
|
||||
}
|
||||
}
|
||||
|
||||
fn error_response(status: u16, message: &str) -> Response {
|
||||
let code = axum::http::StatusCode::from_u16(status)
|
||||
.unwrap_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
use crate::state::CortexState;
|
||||
use chrono::Utc;
|
||||
use cortex_core::discovery::{DiscoveryResponse, HealthResponse};
|
||||
use cortex_core::harness::ModelInfo;
|
||||
use cortex_core::node::{ModelEntry, ModelStatus};
|
||||
use std::sync::Arc;
|
||||
@@ -25,7 +26,59 @@ pub async fn poll_once(fleet: &CortexState) {
|
||||
}
|
||||
}
|
||||
|
||||
/// One-shot fetch of `GET /discovery`. Cached on the NodeState forever
|
||||
/// after the first success — topology is invariant for a given neuron
|
||||
/// process. Skipped when the cache is already populated.
|
||||
async fn maybe_poll_discovery(fleet: &CortexState, name: &str, endpoint: &str) {
|
||||
{
|
||||
let nodes = fleet.nodes.read().await;
|
||||
match nodes.get(name) {
|
||||
Some(n) if n.discovery.is_some() => return,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
let url = format!("{endpoint}/discovery");
|
||||
let resp = match fleet
|
||||
.http_client
|
||||
.get(&url)
|
||||
.timeout(Duration::from_secs(5))
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(r) if r.status().is_success() => r,
|
||||
Ok(r) => {
|
||||
tracing::debug!(node = name, status = %r.status(), "discovery probe non-success");
|
||||
return;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::debug!(node = name, error = %e, "discovery probe unreachable");
|
||||
return;
|
||||
}
|
||||
};
|
||||
match resp.json::<DiscoveryResponse>().await {
|
||||
Ok(d) => {
|
||||
let mut nodes = fleet.nodes.write().await;
|
||||
if let Some(node) = nodes.get_mut(name) {
|
||||
tracing::info!(
|
||||
node = name,
|
||||
hostname = %d.hostname,
|
||||
devices = d.devices.len(),
|
||||
"discovery cached"
|
||||
);
|
||||
node.discovery = Some(d);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(node = name, error = %e, "failed to parse /discovery response");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn poll_neuron(fleet: &CortexState, name: &str, endpoint: &str) {
|
||||
// Topology first — cheap once cached, and the router needs it to
|
||||
// route requests against catalogue entries that aren't loaded yet.
|
||||
maybe_poll_discovery(fleet, name, endpoint).await;
|
||||
|
||||
let url = format!("{endpoint}/models");
|
||||
|
||||
let result = fleet
|
||||
@@ -89,6 +142,51 @@ async fn poll_neuron(fleet: &CortexState, name: &str, endpoint: &str) {
|
||||
node.healthy = false;
|
||||
}
|
||||
}
|
||||
|
||||
// Release the write lock before the next HTTP call.
|
||||
drop(nodes);
|
||||
|
||||
// Poll /health for the activation snapshot. We don't want this to
|
||||
// flip the node to unhealthy on its own — a neuron that's serving
|
||||
// /models fine is still operational even if /health is briefly
|
||||
// unavailable — so failures are debug-level and leave the existing
|
||||
// activation reading in place.
|
||||
poll_health(fleet, name, endpoint).await;
|
||||
}
|
||||
|
||||
/// Fetch `/health` and stash the activation snapshot on NodeState.
|
||||
/// Decoupled from the /models poll so a /health glitch doesn't mark
|
||||
/// the neuron unhealthy or evict the model list.
|
||||
async fn poll_health(fleet: &CortexState, name: &str, endpoint: &str) {
|
||||
let url = format!("{endpoint}/health");
|
||||
let resp = match fleet
|
||||
.http_client
|
||||
.get(&url)
|
||||
.timeout(Duration::from_secs(5))
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(r) if r.status().is_success() => r,
|
||||
Ok(r) => {
|
||||
tracing::debug!(node = name, status = %r.status(), "/health probe non-success");
|
||||
return;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::debug!(node = name, error = %e, "/health probe failed");
|
||||
return;
|
||||
}
|
||||
};
|
||||
match resp.json::<HealthResponse>().await {
|
||||
Ok(h) => {
|
||||
let mut nodes = fleet.nodes.write().await;
|
||||
if let Some(node) = nodes.get_mut(name) {
|
||||
node.activation = Some(h.activation);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::debug!(node = name, error = %e, "failed to parse /health response");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_status(s: &str) -> ModelStatus {
|
||||
@@ -96,6 +194,7 @@ fn parse_status(s: &str) -> ModelStatus {
|
||||
"loaded" => ModelStatus::Loaded,
|
||||
"unloaded" => ModelStatus::Unloaded,
|
||||
"reloading" => ModelStatus::Reloading,
|
||||
"loading" => ModelStatus::Loading,
|
||||
_ => ModelStatus::Loaded,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//! Streaming HTTP reverse proxy to mistral.rs backends.
|
||||
//! Streaming HTTP reverse proxy to neuron backends.
|
||||
//!
|
||||
//! For streaming requests, SSE chunks are forwarded as they arrive.
|
||||
//! The proxy captures timing information for metrics but does not
|
||||
@@ -12,6 +12,13 @@ use axum::response::{IntoResponse, Response};
|
||||
use reqwest::Client;
|
||||
|
||||
/// Proxy a request body to the resolved backend node and stream the response.
|
||||
///
|
||||
/// Logging contract: every call emits exactly one structured event at
|
||||
/// info / warn level for operator visibility, regardless of outcome.
|
||||
/// Network-level failures and non-2xx upstream statuses are warn'd here
|
||||
/// (closest to the wire); the user-facing response carries only the
|
||||
/// status code and a generic message — implementation detail (body,
|
||||
/// error chain) lives in the log, never in the API surface.
|
||||
pub async fn forward_request(
|
||||
client: &Client,
|
||||
route: &RouteDecision,
|
||||
@@ -37,10 +44,33 @@ pub async fn forward_request(
|
||||
req_builder = req_builder.header(key, value);
|
||||
}
|
||||
|
||||
let upstream_resp = req_builder.send().await.map_err(ProxyError::Upstream)?;
|
||||
let upstream_resp = match req_builder.send().await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
node = %route.node_name,
|
||||
url = %url,
|
||||
error = %e,
|
||||
"proxy: upstream request failed (network)"
|
||||
);
|
||||
return Err(ProxyError::Upstream(e));
|
||||
}
|
||||
};
|
||||
|
||||
let status =
|
||||
StatusCode::from_u16(upstream_resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
|
||||
let upstream_status = upstream_resp.status();
|
||||
if !upstream_status.is_success() {
|
||||
// Streaming body — can't snippet without breaking the stream
|
||||
// pass-through. Log status + URL; the client still gets the
|
||||
// upstream status, just without the leaked body.
|
||||
tracing::warn!(
|
||||
node = %route.node_name,
|
||||
url = %url,
|
||||
status = upstream_status.as_u16(),
|
||||
"proxy: upstream returned non-2xx"
|
||||
);
|
||||
}
|
||||
|
||||
let status = StatusCode::from_u16(upstream_status.as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
|
||||
|
||||
let resp_headers = upstream_resp.headers().clone();
|
||||
let stream = upstream_resp.bytes_stream();
|
||||
@@ -52,28 +82,37 @@ pub async fn forward_request(
|
||||
response = response.header(key, value);
|
||||
}
|
||||
|
||||
response
|
||||
.body(body)
|
||||
.map_err(|e| ProxyError::ResponseBuild(e.to_string()))
|
||||
response.body(body).map_err(|e| {
|
||||
tracing::warn!(
|
||||
node = %route.node_name,
|
||||
url = %url,
|
||||
error = %e,
|
||||
"proxy: failed to build response"
|
||||
);
|
||||
ProxyError::ResponseBuild(e.to_string())
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ProxyError {
|
||||
#[error("upstream request failed: {0}")]
|
||||
#[error("upstream request failed")]
|
||||
Upstream(reqwest::Error),
|
||||
#[error("failed to build response: {0}")]
|
||||
#[error("failed to build response")]
|
||||
ResponseBuild(String),
|
||||
}
|
||||
|
||||
impl IntoResponse for ProxyError {
|
||||
fn into_response(self) -> Response {
|
||||
let status = match &self {
|
||||
ProxyError::Upstream(_) => StatusCode::BAD_GATEWAY,
|
||||
ProxyError::ResponseBuild(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
let (status, message) = match &self {
|
||||
ProxyError::Upstream(_) => (StatusCode::BAD_GATEWAY, "upstream request failed"),
|
||||
ProxyError::ResponseBuild(_) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"failed to build response",
|
||||
),
|
||||
};
|
||||
let body = serde_json::json!({
|
||||
"error": {
|
||||
"message": self.to_string(),
|
||||
"message": message,
|
||||
"type": "proxy_error",
|
||||
}
|
||||
});
|
||||
|
||||
@@ -2,13 +2,21 @@
|
||||
//!
|
||||
//! Given a model ID from an inbound request, determine which node should
|
||||
//! handle it. Priority:
|
||||
//! 1. Node where the model is currently `Loaded`
|
||||
//! 2. Node where the model is `Unloaded` (will lazy-load on request)
|
||||
//! 3. Error: model not found on any node
|
||||
//! 1. Node where the model is currently `Loaded` → use it.
|
||||
//! 2. Node where the model is `Unloaded` → use it; neuron's existing
|
||||
//! lazy-load behaviour will reload before serving the request.
|
||||
//! 3. Model is in the catalogue → pick a feasible neuron, call
|
||||
//! `POST /models/load`, wait for the load to complete, then
|
||||
//! proxy. First-request cold-load latency is acceptable per the
|
||||
//! unified-endpoint contract.
|
||||
//! 4. Not in catalogue, not loaded anywhere → 404.
|
||||
|
||||
use crate::state::CortexState;
|
||||
use cortex_core::catalogue::ModelProfile;
|
||||
use cortex_core::harness::ModelSpec;
|
||||
use cortex_core::node::ModelStatus;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
/// The routing decision: which node endpoint to proxy the request to.
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -16,62 +24,292 @@ pub struct RouteDecision {
|
||||
pub node_name: String,
|
||||
/// The inference endpoint to proxy to (from neuron's /models/{id}/endpoint).
|
||||
pub endpoint: String,
|
||||
/// Whether the model will need to load (cold start).
|
||||
/// Whether the model will need to load (cold start). Set to true
|
||||
/// when we proxied to an `Unloaded` node (lazy load on neuron) or
|
||||
/// when we just triggered an explicit cold-load via the catalogue
|
||||
/// path.
|
||||
pub cold_start: bool,
|
||||
/// The concrete model id we actually routed to. Equal to the
|
||||
/// caller's requested id unless an alias was resolved (e.g. caller
|
||||
/// asked for `helexa/small`, this carries `Qwen/Qwen3-1.7B`). The
|
||||
/// handler uses this to rewrite the request body's `model` field
|
||||
/// before proxying — neurons reject requests where the body's
|
||||
/// model name doesn't match a loaded model.
|
||||
pub resolved_model_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum RouteError {
|
||||
#[error("model '{0}' not found on any node")]
|
||||
#[error("model '{0}' not found on any node and not in catalogue")]
|
||||
ModelNotFound(String),
|
||||
#[error("no healthy nodes available")]
|
||||
NoHealthyNodes,
|
||||
#[error("failed to resolve inference endpoint for model '{0}' on node '{1}'")]
|
||||
EndpointResolveFailed(String, String),
|
||||
#[error(
|
||||
"model '{model_id}' is in the catalogue but no healthy neuron's topology satisfies its constraints"
|
||||
)]
|
||||
NoFeasibleNeuron { model_id: String },
|
||||
#[error("cold-load of '{model_id}' on '{node}' failed: {message}")]
|
||||
ColdLoadFailed {
|
||||
model_id: String,
|
||||
node: String,
|
||||
message: String,
|
||||
},
|
||||
}
|
||||
|
||||
/// Resolve which node should serve a request for the given model.
|
||||
/// Asks the neuron for the inference endpoint after selecting a node.
|
||||
pub async fn resolve(
|
||||
fleet: &Arc<CortexState>,
|
||||
model_id: &str,
|
||||
requested_model_id: &str,
|
||||
) -> Result<RouteDecision, RouteError> {
|
||||
let (node_name, neuron_endpoint, cold_start) = {
|
||||
// Alias resolution first — swap `helexa/small` (etc.) for the
|
||||
// concrete id before any node lookups so the rest of routing,
|
||||
// loading, and metrics deal in concrete ids only. `resolve_alias`
|
||||
// returns the input verbatim when it isn't an alias.
|
||||
let model_id = fleet.catalogue.resolve_alias(requested_model_id);
|
||||
if model_id != requested_model_id {
|
||||
tracing::debug!(
|
||||
requested = requested_model_id,
|
||||
resolved = model_id,
|
||||
"alias resolved"
|
||||
);
|
||||
}
|
||||
// Snapshot loaded / unloaded state from the poller cache.
|
||||
let (loaded_route, unloaded_route, any_healthy) = {
|
||||
let nodes = fleet.nodes.read().await;
|
||||
|
||||
let mut loaded_candidate = None;
|
||||
let mut unloaded_candidate = None;
|
||||
|
||||
let mut loaded_route = None;
|
||||
let mut unloaded_route = None;
|
||||
let mut any_healthy = false;
|
||||
for node in nodes.values() {
|
||||
if !node.healthy {
|
||||
continue;
|
||||
}
|
||||
any_healthy = true;
|
||||
if let Some(entry) = node.models.get(model_id) {
|
||||
match entry.status {
|
||||
ModelStatus::Loaded | ModelStatus::Reloading => {
|
||||
loaded_candidate = Some((node.name.clone(), node.endpoint.clone(), false));
|
||||
loaded_route = Some((node.name.clone(), node.endpoint.clone(), false));
|
||||
break;
|
||||
}
|
||||
ModelStatus::Unloaded => {
|
||||
if unloaded_candidate.is_none() {
|
||||
unloaded_candidate =
|
||||
Some((node.name.clone(), node.endpoint.clone(), true));
|
||||
if unloaded_route.is_none() {
|
||||
unloaded_route = Some((node.name.clone(), node.endpoint.clone(), true));
|
||||
}
|
||||
}
|
||||
// Loading is gateway-synthesised from neuron's
|
||||
// activation snapshot; it never appears on the
|
||||
// wire from neuron's `/models`. Skip — the model
|
||||
// isn't actually servable yet. The pre-existing
|
||||
// race (catalogue cold_load fires a parallel
|
||||
// /models/load against the in-flight load) is no
|
||||
// worse than before; fixing it needs neuron-side
|
||||
// in-flight tracking on /models/load itself.
|
||||
ModelStatus::Loading => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
(loaded_route, unloaded_route, any_healthy)
|
||||
};
|
||||
|
||||
if !any_healthy {
|
||||
return Err(RouteError::NoHealthyNodes);
|
||||
}
|
||||
|
||||
// Priority 1: already loaded.
|
||||
if let Some((node_name, neuron_endpoint, cold_start)) = loaded_route {
|
||||
return finish(fleet, &node_name, &neuron_endpoint, model_id, cold_start).await;
|
||||
}
|
||||
|
||||
// Priority 2: known to neuron but unloaded (neuron's lazy load).
|
||||
if let Some((node_name, neuron_endpoint, cold_start)) = unloaded_route {
|
||||
return finish(fleet, &node_name, &neuron_endpoint, model_id, cold_start).await;
|
||||
}
|
||||
|
||||
// Priority 3: catalogue × topology cold-load.
|
||||
if let Some(profile) = fleet.catalogue.get(model_id) {
|
||||
let (node_name, neuron_endpoint) = pick_feasible_neuron(fleet, profile).await?;
|
||||
cold_load(fleet, &node_name, &neuron_endpoint, profile).await?;
|
||||
return finish(fleet, &node_name, &neuron_endpoint, model_id, true).await;
|
||||
}
|
||||
|
||||
Err(RouteError::ModelNotFound(model_id.to_string()))
|
||||
}
|
||||
|
||||
/// Pick a healthy neuron whose discovered topology satisfies the
|
||||
/// profile. Preference order:
|
||||
/// 1. A neuron from `profile.pinned_on` that is healthy + feasible.
|
||||
/// 2. Otherwise, any healthy + feasible neuron, stable by name.
|
||||
async fn pick_feasible_neuron(
|
||||
fleet: &Arc<CortexState>,
|
||||
profile: &ModelProfile,
|
||||
) -> Result<(String, String), RouteError> {
|
||||
let nodes = fleet.nodes.read().await;
|
||||
let mut candidates: Vec<(String, String, bool)> = Vec::new();
|
||||
for node in nodes.values() {
|
||||
if !node.healthy {
|
||||
continue;
|
||||
}
|
||||
let Some(disc) = node.discovery.as_ref() else {
|
||||
continue;
|
||||
};
|
||||
if !profile.is_feasible_on(&node.name, &disc.devices) {
|
||||
continue;
|
||||
}
|
||||
let pinned = profile.pinned_on.iter().any(|n| n == &node.name);
|
||||
candidates.push((node.name.clone(), node.endpoint.clone(), pinned));
|
||||
}
|
||||
candidates.sort_by(|a, b| {
|
||||
b.2.cmp(&a.2) // pinned first (true > false)
|
||||
.then(a.0.cmp(&b.0))
|
||||
});
|
||||
let pick = candidates.into_iter().next();
|
||||
pick.map(|(n, e, _)| (n, e))
|
||||
.ok_or_else(|| RouteError::NoFeasibleNeuron {
|
||||
model_id: profile.id.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Issue `POST {endpoint}/models/load` for this profile on this neuron,
|
||||
/// blocking until the load completes (neuron's load endpoint is
|
||||
/// synchronous — it returns 200 once VRAM is materialised). On success
|
||||
/// also inserts a `Loaded` entry into the local NodeState cache so the
|
||||
/// caller's subsequent endpoint lookup sees the new model without
|
||||
/// waiting for the next poll cycle.
|
||||
async fn cold_load(
|
||||
fleet: &Arc<CortexState>,
|
||||
node_name: &str,
|
||||
neuron_endpoint: &str,
|
||||
profile: &ModelProfile,
|
||||
) -> Result<(), RouteError> {
|
||||
let spec = profile_to_spec(fleet, node_name, profile).await;
|
||||
let url = format!("{neuron_endpoint}/models/load");
|
||||
tracing::info!(model = %profile.id, node = node_name, "cold-loading via /models/load");
|
||||
|
||||
// Generous timeout: a fresh download + safetensors mmap + device
|
||||
// copy for a 30B-class dense model can comfortably exceed 5 min on
|
||||
// a slow link. The HTTP client's own default already covers most
|
||||
// of this; pin a longer per-request bound just here.
|
||||
let resp = match fleet
|
||||
.http_client
|
||||
.post(&url)
|
||||
.timeout(Duration::from_secs(1800))
|
||||
.json(&spec)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
return Err(RouteError::ColdLoadFailed {
|
||||
model_id: profile.id.clone(),
|
||||
node: node_name.to_string(),
|
||||
message: format!("HTTP request failed: {e}"),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let status = resp.status();
|
||||
if !status.is_success() {
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
// Neuron returns 400 "already loaded" when two concurrent
|
||||
// requests race the same model. Treat that as success — both
|
||||
// requests effectively achieved the same end state.
|
||||
if body.contains("already loaded") {
|
||||
tracing::info!(
|
||||
model = %profile.id,
|
||||
node = node_name,
|
||||
"cold-load saw 'already loaded' — treating as success"
|
||||
);
|
||||
} else {
|
||||
return Err(RouteError::ColdLoadFailed {
|
||||
model_id: profile.id.clone(),
|
||||
node: node_name.to_string(),
|
||||
message: format!("HTTP {status}: {body}"),
|
||||
});
|
||||
}
|
||||
} else {
|
||||
tracing::info!(model = %profile.id, node = node_name, "cold-load returned 200");
|
||||
}
|
||||
|
||||
// Warm the cache: insert a Loaded ModelEntry so the next
|
||||
// resolve() finds the model without waiting for the poll loop.
|
||||
{
|
||||
let mut nodes = fleet.nodes.write().await;
|
||||
if let Some(node) = nodes.get_mut(node_name) {
|
||||
node.models.insert(
|
||||
profile.id.clone(),
|
||||
cortex_core::node::ModelEntry {
|
||||
id: profile.id.clone(),
|
||||
status: ModelStatus::Loaded,
|
||||
last_accessed: Some(chrono::Utc::now()),
|
||||
vram_estimate_mb: profile.vram_mb,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Translate a `ModelProfile` to a `ModelSpec` neuron's /models/load
|
||||
/// accepts. Devices are picked from the neuron's discovered topology —
|
||||
/// the first `min_devices` indices that meet `min_device_vram_mb`.
|
||||
async fn profile_to_spec(
|
||||
fleet: &Arc<CortexState>,
|
||||
node_name: &str,
|
||||
profile: &ModelProfile,
|
||||
) -> ModelSpec {
|
||||
let devices = {
|
||||
let nodes = fleet.nodes.read().await;
|
||||
let mut picked: Vec<u32> = Vec::new();
|
||||
if let Some(node) = nodes.get(node_name)
|
||||
&& let Some(disc) = &node.discovery
|
||||
{
|
||||
let min_vram = profile.min_device_vram_mb.unwrap_or(0);
|
||||
for d in &disc.devices {
|
||||
if d.vram_total_mb >= min_vram {
|
||||
picked.push(d.index);
|
||||
if picked.len() as u32 >= profile.min_devices {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
loaded_candidate.or(unloaded_candidate).ok_or_else(|| {
|
||||
if nodes.values().any(|n| n.healthy) {
|
||||
RouteError::ModelNotFound(model_id.to_string())
|
||||
} else {
|
||||
RouteError::NoHealthyNodes
|
||||
}
|
||||
})?
|
||||
if picked.is_empty() {
|
||||
// Fall back to a 0..min_devices default; pick_feasible_neuron
|
||||
// already verified the topology satisfies the constraints,
|
||||
// so this only fires if discovery raced or was lost.
|
||||
(0..profile.min_devices).collect()
|
||||
} else {
|
||||
picked
|
||||
}
|
||||
};
|
||||
|
||||
// Ask the neuron for the inference endpoint for this model.
|
||||
let tensor_parallel = if profile.min_devices > 1 {
|
||||
Some(profile.min_devices)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
ModelSpec {
|
||||
model_id: profile.id.clone(),
|
||||
harness: profile.harness.clone(),
|
||||
quant: profile.quant.clone(),
|
||||
tensor_parallel,
|
||||
devices: Some(devices),
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve neuron's `/models/{id}/endpoint` to its inference URL and
|
||||
/// build the final `RouteDecision`. Shared by all three priority
|
||||
/// branches above.
|
||||
async fn finish(
|
||||
fleet: &Arc<CortexState>,
|
||||
node_name: &str,
|
||||
neuron_endpoint: &str,
|
||||
model_id: &str,
|
||||
cold_start: bool,
|
||||
) -> Result<RouteDecision, RouteError> {
|
||||
let endpoint_url = format!(
|
||||
"{}/models/{}/endpoint",
|
||||
neuron_endpoint,
|
||||
@@ -89,13 +327,83 @@ pub async fn resolve(
|
||||
_ => None,
|
||||
};
|
||||
|
||||
let endpoint = inference_endpoint.ok_or_else(|| {
|
||||
RouteError::EndpointResolveFailed(model_id.to_string(), node_name.clone())
|
||||
let raw = inference_endpoint.ok_or_else(|| {
|
||||
RouteError::EndpointResolveFailed(model_id.to_string(), node_name.to_string())
|
||||
})?;
|
||||
|
||||
// Rewrite loopback inference URLs to use the configured neuron host.
|
||||
// Neuron's default bind_url is `http://localhost:13131` (it can't
|
||||
// reliably know its own externally-resolvable name). Cortex sees a
|
||||
// URL that's only meaningful from the neuron host's own perspective;
|
||||
// proxying directly to localhost from a different cortex host would
|
||||
// hit nothing. Keep neuron's port and path (a future harness could
|
||||
// serve inference on a different port than the management API), but
|
||||
// swap the host for the one in cortex.toml.
|
||||
let endpoint = rewrite_loopback_host(&raw, neuron_endpoint).unwrap_or(raw);
|
||||
|
||||
Ok(RouteDecision {
|
||||
node_name,
|
||||
node_name: node_name.to_string(),
|
||||
endpoint,
|
||||
cold_start,
|
||||
resolved_model_id: model_id.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
/// If `inference_url`'s host is a loopback name (localhost / 127.0.0.1 /
|
||||
/// 0.0.0.0 / ::1), return a copy with the host replaced by
|
||||
/// `neuron_endpoint`'s host. Otherwise return None and the caller falls
|
||||
/// back to the inference URL as-is.
|
||||
fn rewrite_loopback_host(inference_url: &str, neuron_endpoint: &str) -> Option<String> {
|
||||
let inf = url::Url::parse(inference_url).ok()?;
|
||||
let inf_host = inf.host_str()?;
|
||||
let is_loopback = matches!(inf_host, "localhost" | "127.0.0.1" | "0.0.0.0" | "::1");
|
||||
if !is_loopback {
|
||||
return None;
|
||||
}
|
||||
let neuron = url::Url::parse(neuron_endpoint).ok()?;
|
||||
let new_host = neuron.host_str()?;
|
||||
let mut out = inf.clone();
|
||||
out.set_host(Some(new_host)).ok()?;
|
||||
// url::Url::to_string normalises an empty path to "/", which then
|
||||
// breaks downstream callers that do format!("{endpoint}/v1/...")
|
||||
// and produce a double slash. The proxy URL is treated as a base
|
||||
// string that the caller appends paths to, so strip the trailing
|
||||
// slash here.
|
||||
let s = out.to_string();
|
||||
Some(s.trim_end_matches('/').to_string())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::rewrite_loopback_host;
|
||||
|
||||
#[test]
|
||||
fn rewrites_localhost_keeps_port_and_path() {
|
||||
let out = rewrite_loopback_host(
|
||||
"http://localhost:13131",
|
||||
"http://beast.hanzalova.internal:13131",
|
||||
);
|
||||
assert_eq!(
|
||||
out.as_deref(),
|
||||
Some("http://beast.hanzalova.internal:13131")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rewrites_loopback_with_distinct_inference_port() {
|
||||
let out = rewrite_loopback_host("http://127.0.0.1:8080", "http://beast.lan:13131");
|
||||
assert_eq!(out.as_deref(), Some("http://beast.lan:8080"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn leaves_non_loopback_alone() {
|
||||
let out = rewrite_loopback_host("http://other.host:1234", "http://beast.lan:13131");
|
||||
assert_eq!(out, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn malformed_inference_url_returns_none() {
|
||||
let out = rewrite_loopback_host("not a url", "http://beast.lan:13131");
|
||||
assert_eq!(out, None);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,6 +26,8 @@ impl CortexState {
|
||||
models: HashMap::new(),
|
||||
lifecycle_cycles: 0,
|
||||
last_poll: None,
|
||||
discovery: None,
|
||||
activation: None,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
265
crates/cortex-gateway/tests/aliases.rs
Normal file
265
crates/cortex-gateway/tests/aliases.rs
Normal file
@@ -0,0 +1,265 @@
|
||||
//! Alias resolution: a client request with `model: "helexa/small"`
|
||||
//! routes to the concrete model id (e.g. `Qwen/Qwen3-1.7B`), with the
|
||||
//! proxied request body rewritten so the upstream neuron sees a model
|
||||
//! name that matches its loaded handle.
|
||||
|
||||
mod common;
|
||||
|
||||
use cortex_core::config::{
|
||||
EvictionSettings, EvictionStrategy, GatewayConfig, GatewaySettings, NeuronEndpoint,
|
||||
};
|
||||
use cortex_core::node::{ModelEntry, ModelStatus};
|
||||
use cortex_gateway::state::CortexState;
|
||||
use serde_json::json;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
/// Write a `models.toml` with one alias to a unique temp path. Returns
|
||||
/// the path; the file persists for the test process and gets reaped by
|
||||
/// the OS at exit. Using $XDG_RUNTIME_DIR fallback for the temp dir
|
||||
/// keeps the file off shared /tmp on CI without pulling in tempfile.
|
||||
fn write_models_toml(alias: &str, target: &str) -> PathBuf {
|
||||
let contents = format!(
|
||||
r#"
|
||||
[aliases]
|
||||
"{alias}" = "{target}"
|
||||
"#
|
||||
);
|
||||
let mut path = std::env::temp_dir();
|
||||
let pid = std::process::id();
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_nanos();
|
||||
path.push(format!("cortex-test-models-{pid}-{now}.toml"));
|
||||
std::fs::write(&path, contents).expect("write temp models.toml");
|
||||
path
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_alias_resolves_in_chat_completions() {
|
||||
let mock_url = common::spawn_mock_neuron().await;
|
||||
let models_path = write_models_toml("helexa/small", "test-model");
|
||||
|
||||
let config = GatewayConfig {
|
||||
gateway: GatewaySettings {
|
||||
listen: "127.0.0.1:0".into(),
|
||||
metrics_listen: "127.0.0.1:0".into(),
|
||||
},
|
||||
eviction: EvictionSettings {
|
||||
strategy: EvictionStrategy::Lru,
|
||||
defrag_after_cycles: 0,
|
||||
},
|
||||
neurons: vec![NeuronEndpoint {
|
||||
name: "mock-node".into(),
|
||||
endpoint: mock_url,
|
||||
}],
|
||||
models_config: models_path.to_string_lossy().to_string(),
|
||||
};
|
||||
|
||||
let fleet = Arc::new(CortexState::from_config(&config));
|
||||
|
||||
// Seed the node as healthy with the concrete model loaded under
|
||||
// the target id. The poller doesn't run in this test; we just
|
||||
// populate state manually.
|
||||
{
|
||||
let mut nodes = fleet.nodes.write().await;
|
||||
let node = nodes.get_mut("mock-node").expect("node must exist");
|
||||
node.healthy = true;
|
||||
node.models.insert(
|
||||
"test-model".into(),
|
||||
ModelEntry {
|
||||
id: "test-model".into(),
|
||||
status: ModelStatus::Loaded,
|
||||
last_accessed: None,
|
||||
vram_estimate_mb: None,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// Sanity: the catalogue actually picked up the alias.
|
||||
assert_eq!(
|
||||
fleet.catalogue.resolve_alias("helexa/small"),
|
||||
"test-model",
|
||||
"alias should resolve to target id"
|
||||
);
|
||||
|
||||
// Spawn the gateway against this fleet.
|
||||
let app = cortex_gateway::build_app(Arc::clone(&fleet));
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let gateway_addr = listener.local_addr().unwrap();
|
||||
tokio::spawn(async move {
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
});
|
||||
let gateway_url = format!("http://{gateway_addr}");
|
||||
|
||||
// Send a chat completion against the alias. The mock backend
|
||||
// echoes back the `model` field it received — so a body whose
|
||||
// model wasn't rewritten would come back as "helexa/small", and a
|
||||
// properly-rewritten one as "test-model".
|
||||
let client = reqwest::Client::new();
|
||||
let resp = client
|
||||
.post(format!("{gateway_url}/v1/chat/completions"))
|
||||
.json(&json!({
|
||||
"model": "helexa/small",
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.expect("gateway should respond");
|
||||
|
||||
assert!(resp.status().is_success(), "gateway returned non-2xx");
|
||||
let body: serde_json::Value = resp.json().await.expect("response is JSON");
|
||||
assert_eq!(
|
||||
body.get("model").and_then(|m| m.as_str()),
|
||||
Some("test-model"),
|
||||
"mock backend should have seen the resolved model id, not the alias"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_aliases_surface_in_v1_models() {
|
||||
let mock_url = common::spawn_mock_neuron().await;
|
||||
let models_path = write_models_toml("helexa/small", "test-model");
|
||||
|
||||
let config = GatewayConfig {
|
||||
gateway: GatewaySettings {
|
||||
listen: "127.0.0.1:0".into(),
|
||||
metrics_listen: "127.0.0.1:0".into(),
|
||||
},
|
||||
eviction: EvictionSettings {
|
||||
strategy: EvictionStrategy::Lru,
|
||||
defrag_after_cycles: 0,
|
||||
},
|
||||
neurons: vec![NeuronEndpoint {
|
||||
name: "mock-node".into(),
|
||||
endpoint: mock_url,
|
||||
}],
|
||||
models_config: models_path.to_string_lossy().to_string(),
|
||||
};
|
||||
|
||||
let fleet = Arc::new(CortexState::from_config(&config));
|
||||
|
||||
// Seed the target as loaded so the alias's mirrored entry shows
|
||||
// loaded=true.
|
||||
{
|
||||
let mut nodes = fleet.nodes.write().await;
|
||||
let node = nodes.get_mut("mock-node").expect("node must exist");
|
||||
node.healthy = true;
|
||||
node.models.insert(
|
||||
"test-model".into(),
|
||||
ModelEntry {
|
||||
id: "test-model".into(),
|
||||
status: ModelStatus::Loaded,
|
||||
last_accessed: None,
|
||||
vram_estimate_mb: Some(2000),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
let app = cortex_gateway::build_app(Arc::clone(&fleet));
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let gateway_addr = listener.local_addr().unwrap();
|
||||
tokio::spawn(async move {
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
});
|
||||
let gateway_url = format!("http://{gateway_addr}");
|
||||
|
||||
let resp = reqwest::get(format!("{gateway_url}/v1/models"))
|
||||
.await
|
||||
.expect("gateway should respond");
|
||||
let body: serde_json::Value = resp.json().await.unwrap();
|
||||
let entries = body
|
||||
.get("data")
|
||||
.and_then(|d| d.as_array())
|
||||
.expect("data array");
|
||||
|
||||
// Both the alias and the target should be present.
|
||||
let ids: Vec<&str> = entries
|
||||
.iter()
|
||||
.filter_map(|e| e.get("id").and_then(|v| v.as_str()))
|
||||
.collect();
|
||||
assert!(ids.contains(&"test-model"), "target should be listed");
|
||||
assert!(ids.contains(&"helexa/small"), "alias should be listed");
|
||||
|
||||
// The alias's `loaded` flag and locations should mirror the target.
|
||||
let alias_entry = entries
|
||||
.iter()
|
||||
.find(|e| e.get("id").and_then(|v| v.as_str()) == Some("helexa/small"))
|
||||
.expect("alias entry");
|
||||
assert_eq!(alias_entry.get("loaded"), Some(&json!(true)));
|
||||
let locations = alias_entry
|
||||
.get("locations")
|
||||
.and_then(|l| l.as_array())
|
||||
.expect("locations array");
|
||||
assert_eq!(locations.len(), 1);
|
||||
assert_eq!(
|
||||
locations[0].get("node").and_then(|n| n.as_str()),
|
||||
Some("mock-node")
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_alias_falls_through_for_unmapped_model() {
|
||||
// Catalogue has an alias for some-other-thing but the request
|
||||
// model "test-model" isn't an alias; resolution should be a no-op.
|
||||
let mock_url = common::spawn_mock_neuron().await;
|
||||
let models_path = write_models_toml("helexa/large", "definitely-not-loaded");
|
||||
|
||||
let config = GatewayConfig {
|
||||
gateway: GatewaySettings {
|
||||
listen: "127.0.0.1:0".into(),
|
||||
metrics_listen: "127.0.0.1:0".into(),
|
||||
},
|
||||
eviction: EvictionSettings {
|
||||
strategy: EvictionStrategy::Lru,
|
||||
defrag_after_cycles: 0,
|
||||
},
|
||||
neurons: vec![NeuronEndpoint {
|
||||
name: "mock-node".into(),
|
||||
endpoint: mock_url,
|
||||
}],
|
||||
models_config: models_path.to_string_lossy().to_string(),
|
||||
};
|
||||
|
||||
let fleet = Arc::new(CortexState::from_config(&config));
|
||||
{
|
||||
let mut nodes = fleet.nodes.write().await;
|
||||
let node = nodes.get_mut("mock-node").expect("node must exist");
|
||||
node.healthy = true;
|
||||
node.models.insert(
|
||||
"test-model".into(),
|
||||
ModelEntry {
|
||||
id: "test-model".into(),
|
||||
status: ModelStatus::Loaded,
|
||||
last_accessed: None,
|
||||
vram_estimate_mb: None,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
let app = cortex_gateway::build_app(Arc::clone(&fleet));
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let gateway_addr = listener.local_addr().unwrap();
|
||||
tokio::spawn(async move {
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
});
|
||||
let gateway_url = format!("http://{gateway_addr}");
|
||||
|
||||
let resp = reqwest::Client::new()
|
||||
.post(format!("{gateway_url}/v1/chat/completions"))
|
||||
.json(&json!({
|
||||
"model": "test-model",
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(resp.status().is_success());
|
||||
let body: serde_json::Value = resp.json().await.unwrap();
|
||||
assert_eq!(
|
||||
body.get("model").and_then(|m| m.as_str()),
|
||||
Some("test-model")
|
||||
);
|
||||
}
|
||||
@@ -22,6 +22,7 @@ use tokio::net::TcpListener;
|
||||
/// - GET /models/:id/endpoint (returns the inference URL)
|
||||
/// - POST /models/unload (accepts unload requests)
|
||||
/// - GET /v1/chat/completions + POST /v1/chat/completions (inference)
|
||||
///
|
||||
/// Returns the neuron base URL.
|
||||
pub async fn spawn_mock_neuron() -> String {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
@@ -43,6 +44,7 @@ pub async fn spawn_mock_neuron() -> String {
|
||||
post(|Json(_body): Json<Value>| async { Json(json!({"status": "unloaded"})) }),
|
||||
)
|
||||
.route("/v1/chat/completions", post(mock_chat_completions))
|
||||
.route("/v1/responses", post(mock_responses))
|
||||
.route("/v1/models", get(mock_v1_models));
|
||||
|
||||
tokio::spawn(async move {
|
||||
@@ -54,7 +56,7 @@ pub async fn spawn_mock_neuron() -> String {
|
||||
|
||||
async fn mock_neuron_list_models() -> Json<Value> {
|
||||
Json(json!([
|
||||
{"id": "test-model", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": 8000}
|
||||
{"id": "test-model", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": 8000}
|
||||
]))
|
||||
}
|
||||
|
||||
@@ -92,6 +94,39 @@ async fn mock_chat_completions(Json(body): Json<Value>) -> Json<Value> {
|
||||
}))
|
||||
}
|
||||
|
||||
async fn mock_responses(Json(body): Json<Value>) -> Json<Value> {
|
||||
let model = body
|
||||
.get("model")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("unknown");
|
||||
// Echo the model field back and synthesise a tiny ResponsesResponse.
|
||||
// Mirrors the shape neuron's /v1/responses handler emits so the
|
||||
// gateway test only needs to assert the proxy round-tripped it.
|
||||
Json(json!({
|
||||
"id": "resp-test-001",
|
||||
"object": "response",
|
||||
"created_at": 1700000000_u64,
|
||||
"status": "completed",
|
||||
"model": model,
|
||||
"output": [{
|
||||
"type": "message",
|
||||
"id": "msg-test-001",
|
||||
"role": "assistant",
|
||||
"content": [{
|
||||
"type": "output_text",
|
||||
"text": "Hello from mock backend",
|
||||
"annotations": []
|
||||
}],
|
||||
"status": "completed"
|
||||
}],
|
||||
"usage": {
|
||||
"input_tokens": 5,
|
||||
"output_tokens": 5,
|
||||
"total_tokens": 10
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
/// Spawns a mock neuron that returns SSE streaming responses for chat completions.
|
||||
pub async fn spawn_streaming_mock_neuron(chunk_count: usize, chunk_delay: Duration) -> String {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
@@ -163,6 +198,33 @@ pub async fn spawn_streaming_mock_neuron(chunk_count: usize, chunk_delay: Durati
|
||||
|
||||
/// Spawns a mock neuron with a custom models list.
|
||||
pub async fn spawn_mock_neuron_with_models(models_response: Value) -> String {
|
||||
spawn_mock_neuron_with_models_and_health(models_response, default_health_response()).await
|
||||
}
|
||||
|
||||
/// Default `/health` response used by mocks that don't care about the
|
||||
/// activation field — empty devices, no in-flight pre-warm, state=ready.
|
||||
pub fn default_health_response() -> Value {
|
||||
json!({
|
||||
"uptime_secs": 0,
|
||||
"devices": [],
|
||||
"activation": {
|
||||
"state": "ready",
|
||||
"pending": [],
|
||||
"in_progress": null,
|
||||
"completed": [],
|
||||
"failed": []
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Variant of `spawn_mock_neuron_with_models` that also serves a
|
||||
/// `/health` body. Used by tests that drive the gateway's activation
|
||||
/// surface (poller reading /health, /v1/models synthesising Loading
|
||||
/// locations from in_progress / pending).
|
||||
pub async fn spawn_mock_neuron_with_models_and_health(
|
||||
models_response: Value,
|
||||
health_response: Value,
|
||||
) -> String {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
let base_url = format!("http://{addr}");
|
||||
@@ -176,6 +238,13 @@ pub async fn spawn_mock_neuron_with_models(models_response: Value) -> String {
|
||||
async move { Json(resp) }
|
||||
}),
|
||||
)
|
||||
.route(
|
||||
"/health",
|
||||
get(move || {
|
||||
let resp = health_response.clone();
|
||||
async move { Json(resp) }
|
||||
}),
|
||||
)
|
||||
.route(
|
||||
"/models/{model_id}/endpoint",
|
||||
get(move |Path(_model_id): Path<String>| {
|
||||
|
||||
@@ -12,8 +12,8 @@ use std::sync::Arc;
|
||||
async fn test_poller_discovers_models() {
|
||||
// Mock neuron reports 2 models via /models endpoint (neuron format).
|
||||
let mock_url = common::spawn_mock_neuron_with_models(json!([
|
||||
{"id": "model-a", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": 8000},
|
||||
{"id": "model-b", "harness": "mistralrs", "status": "unloaded", "devices": [], "vram_used_mb": null}
|
||||
{"id": "model-a", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": 8000},
|
||||
{"id": "model-b", "harness": "candle", "status": "unloaded", "devices": [], "vram_used_mb": null}
|
||||
]))
|
||||
.await;
|
||||
|
||||
@@ -63,8 +63,8 @@ async fn test_poller_discovers_models() {
|
||||
#[tokio::test]
|
||||
async fn test_poller_updates_gateway_models_endpoint() {
|
||||
let mock_url = common::spawn_mock_neuron_with_models(json!([
|
||||
{"id": "model-x", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": null},
|
||||
{"id": "model-y", "harness": "mistralrs", "status": "loaded", "devices": [1], "vram_used_mb": null}
|
||||
{"id": "model-x", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": null},
|
||||
{"id": "model-y", "harness": "candle", "status": "loaded", "devices": [1], "vram_used_mb": null}
|
||||
]))
|
||||
.await;
|
||||
|
||||
@@ -152,8 +152,8 @@ async fn test_poller_marks_unreachable_node_unhealthy() {
|
||||
#[tokio::test]
|
||||
async fn test_poller_removes_stale_models() {
|
||||
let mock_url = common::spawn_mock_neuron_with_models(json!([
|
||||
{"id": "keep-me", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": null},
|
||||
{"id": "drop-me", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": null}
|
||||
{"id": "keep-me", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": null},
|
||||
{"id": "drop-me", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": null}
|
||||
]))
|
||||
.await;
|
||||
|
||||
@@ -183,7 +183,7 @@ async fn test_poller_removes_stale_models() {
|
||||
|
||||
// New mock with only one model.
|
||||
let new_mock_url = common::spawn_mock_neuron_with_models(json!([
|
||||
{"id": "keep-me", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": null}
|
||||
{"id": "keep-me", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": null}
|
||||
]))
|
||||
.await;
|
||||
|
||||
@@ -237,3 +237,58 @@ async fn test_poller_removes_stale_models() {
|
||||
assert!(node.models.contains_key("keep-me"));
|
||||
assert!(!node.models.contains_key("drop-me"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_poller_captures_activation_from_health() {
|
||||
// Mock neuron is mid-prewarm: /models reports nothing (the loading
|
||||
// model hasn't been inserted into the harness map yet), but
|
||||
// /health's activation says model-x is in_progress and model-y is
|
||||
// queued behind it.
|
||||
let mock_url = common::spawn_mock_neuron_with_models_and_health(
|
||||
json!([]),
|
||||
json!({
|
||||
"uptime_secs": 30,
|
||||
"devices": [],
|
||||
"activation": {
|
||||
"state": "pre_warming",
|
||||
"pending": ["Qwen/model-y"],
|
||||
"in_progress": "Qwen/model-x",
|
||||
"completed": [],
|
||||
"failed": []
|
||||
}
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let config = GatewayConfig {
|
||||
gateway: GatewaySettings {
|
||||
listen: "127.0.0.1:0".into(),
|
||||
metrics_listen: "127.0.0.1:0".into(),
|
||||
},
|
||||
eviction: EvictionSettings {
|
||||
strategy: EvictionStrategy::Lru,
|
||||
defrag_after_cycles: 0,
|
||||
},
|
||||
neurons: vec![NeuronEndpoint {
|
||||
name: "prewarm-node".into(),
|
||||
endpoint: mock_url,
|
||||
}],
|
||||
models_config: "/dev/null".into(),
|
||||
};
|
||||
|
||||
let fleet = Arc::new(CortexState::from_config(&config));
|
||||
cortex_gateway::poller::poll_once(&fleet).await;
|
||||
|
||||
let nodes = fleet.nodes.read().await;
|
||||
let node = nodes.get("prewarm-node").unwrap();
|
||||
assert!(node.healthy);
|
||||
// /models was empty — no entries in the per-node model map.
|
||||
assert!(node.models.is_empty());
|
||||
// But /health's activation should be captured.
|
||||
let activation = node
|
||||
.activation
|
||||
.as_ref()
|
||||
.expect("activation should be populated after /health poll");
|
||||
assert_eq!(activation.in_progress.as_deref(), Some("Qwen/model-x"));
|
||||
assert_eq!(activation.pending, vec!["Qwen/model-y".to_string()]);
|
||||
}
|
||||
|
||||
91
crates/cortex-gateway/tests/responses.rs
Normal file
91
crates/cortex-gateway/tests/responses.rs
Normal file
@@ -0,0 +1,91 @@
|
||||
//! Integration tests for the `/v1/responses` proxy route.
|
||||
//!
|
||||
//! The gateway forwards the request body to whichever neuron has the
|
||||
//! model loaded. These tests exercise the routing decision (200 on a
|
||||
//! known model, 404 on an unknown model, 400 on a missing model
|
||||
//! field) and confirm the response body round-trips verbatim.
|
||||
|
||||
mod common;
|
||||
|
||||
use serde_json::json;
|
||||
|
||||
/// Happy path: gateway routes a `/v1/responses` request to the neuron
|
||||
/// that has the model loaded, and the neuron's response body
|
||||
/// arrives at the client unchanged.
|
||||
#[tokio::test]
|
||||
async fn test_responses_proxy() {
|
||||
let mock_url = common::spawn_mock_neuron().await;
|
||||
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let resp = client
|
||||
.post(format!("{gw_url}/v1/responses"))
|
||||
.header("content-type", "application/json")
|
||||
.json(&json!({
|
||||
"model": "test-model",
|
||||
"input": "Hi"
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.expect("request should succeed");
|
||||
|
||||
assert_eq!(resp.status(), 200);
|
||||
|
||||
let body: serde_json::Value = resp.json().await.expect("valid JSON response");
|
||||
assert_eq!(body["id"], "resp-test-001");
|
||||
assert_eq!(body["object"], "response");
|
||||
assert_eq!(body["model"], "test-model");
|
||||
assert_eq!(body["status"], "completed");
|
||||
assert_eq!(
|
||||
body["output"][0]["content"][0]["text"],
|
||||
"Hello from mock backend"
|
||||
);
|
||||
// Usage shape is the Responses-specific (input/output_tokens),
|
||||
// not the chat-completions one (prompt/completion_tokens). Asserts
|
||||
// the proxy didn't accidentally route through the wrong handler.
|
||||
assert_eq!(body["usage"]["total_tokens"], 10);
|
||||
assert!(body["usage"].get("input_tokens").is_some());
|
||||
}
|
||||
|
||||
/// A request that targets a model not present in the catalogue gets
|
||||
/// 404 from the router. This matches the chat-completions handler's
|
||||
/// behaviour — same error path, same status code, so a client can
|
||||
/// share retry logic across the two routes.
|
||||
#[tokio::test]
|
||||
async fn test_responses_model_not_found() {
|
||||
let mock_url = common::spawn_mock_neuron().await;
|
||||
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let resp = client
|
||||
.post(format!("{gw_url}/v1/responses"))
|
||||
.json(&json!({
|
||||
"model": "not-in-catalogue",
|
||||
"input": "Hi"
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 404);
|
||||
}
|
||||
|
||||
/// A request body without a `model` field can't be routed; the
|
||||
/// gateway returns 400 before reaching a backend. Same as the
|
||||
/// chat-completions handler — extracted via the same `extract_model`
|
||||
/// helper.
|
||||
#[tokio::test]
|
||||
async fn test_responses_missing_model_field() {
|
||||
let mock_url = common::spawn_mock_neuron().await;
|
||||
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let resp = client
|
||||
.post(format!("{gw_url}/v1/responses"))
|
||||
.json(&json!({
|
||||
"input": "Hi"
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 400);
|
||||
}
|
||||
@@ -51,18 +51,18 @@ async fn test_streaming_sse_passthrough() {
|
||||
}
|
||||
|
||||
assert!(
|
||||
chunks.len() >= chunk_count + 1,
|
||||
"expected at least {} chunks (got {}): {:?}",
|
||||
chunk_count + 1,
|
||||
chunks.len() > chunk_count,
|
||||
"expected more than {} chunks (got {}): {:?}",
|
||||
chunk_count,
|
||||
chunks.len(),
|
||||
chunks,
|
||||
);
|
||||
|
||||
assert_eq!(chunks.last().unwrap(), "[DONE]");
|
||||
|
||||
for i in 0..chunk_count {
|
||||
for (i, chunk) in chunks.iter().enumerate().take(chunk_count) {
|
||||
let chunk_json: serde_json::Value =
|
||||
serde_json::from_str(&chunks[i]).expect("chunk should be valid JSON");
|
||||
serde_json::from_str(chunk).expect("chunk should be valid JSON");
|
||||
assert_eq!(
|
||||
chunk_json["choices"][0]["delta"]["content"],
|
||||
format!("token{i}")
|
||||
|
||||
48
crates/helexa-acp/Cargo.toml
Normal file
48
crates/helexa-acp/Cargo.toml
Normal file
@@ -0,0 +1,48 @@
|
||||
[package]
|
||||
name = "helexa-acp"
|
||||
version = "0.1.16"
|
||||
edition = "2024"
|
||||
license = "Apache-2.0"
|
||||
repository = "https://git.lair.cafe/helexa/cortex"
|
||||
description = """
|
||||
Agent Client Protocol bridge for the helexa self-hosted LLM stack.
|
||||
Speaks ACP to ACP-compatible editor clients (Zed, etc.) and forwards
|
||||
the conversation to any OpenAI-compatible HTTP endpoint — defaulting
|
||||
to cortex (helexa's reverse-proxy / fleet gateway).
|
||||
"""
|
||||
|
||||
# This crate is intentionally self-contained — no dependencies on other
|
||||
# workspace crates (cortex-core, cortex-gateway, neuron). The goal is
|
||||
# a painless migration to a dedicated GitHub repo in the future if the
|
||||
# project grows beyond helexa's needs. All deps are crates.io.
|
||||
[dependencies]
|
||||
# `unstable_session_model` flips on the SessionModelState type and the
|
||||
# session/set_model RPC the model-picker dropdown in Zed needs. The
|
||||
# feature is upstream-marked unstable; we accept that risk because the
|
||||
# model picker is core UX and the alternative (rolling our own
|
||||
# extension method) drifts further from spec each time it moves.
|
||||
agent-client-protocol = { version = "0.12", features = ["unstable_session_model"] }
|
||||
tokio = { version = "1", features = ["rt-multi-thread", "macros", "sync", "io-util", "process", "signal"] }
|
||||
reqwest = { version = "0.12", features = ["json", "stream", "rustls-tls"], default-features = false }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
toml = "0.8"
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
anyhow = "1"
|
||||
thiserror = "2"
|
||||
async-trait = "0.1"
|
||||
futures = "0.3"
|
||||
tokio-stream = "0.1"
|
||||
tokio-util = { version = "0.7", features = ["rt"] }
|
||||
eventsource-stream = "0.2"
|
||||
async-stream = "0.3"
|
||||
url = { version = "2", features = ["serde"] }
|
||||
# Already transitively pulled via the ACP SDK; declared directly so we
|
||||
# can format ISO 8601 timestamps for `SessionInfo.updated_at` in the
|
||||
# session/list response.
|
||||
chrono = { version = "0.4", default-features = false, features = ["std"] }
|
||||
|
||||
[[bin]]
|
||||
name = "helexa-acp"
|
||||
path = "src/main.rs"
|
||||
546
crates/helexa-acp/README.md
Normal file
546
crates/helexa-acp/README.md
Normal file
@@ -0,0 +1,546 @@
|
||||
# helexa-acp
|
||||
|
||||
ACP (Agent Client Protocol) bridge for editors like
|
||||
[Zed](https://zed.dev). Lets you point your editor's agent panel at
|
||||
**any combination** of OpenAI-compatible, OpenAI Responses, and
|
||||
Anthropic Messages endpoints — public APIs, private LAN deployments,
|
||||
local Ollama / LM Studio — and switch between them per session via a
|
||||
model dropdown.
|
||||
|
||||
The "missing ACP binary" for users who don't want to be locked into
|
||||
one vendor's agent client.
|
||||
|
||||
```
|
||||
┌───────────────────────────────────┐
|
||||
│ Zed (or any ACP editor client) │
|
||||
└────────────┬──────────────────────┘
|
||||
│ stdio JSON-RPC (ACP)
|
||||
▼
|
||||
┌─────────────────┐
|
||||
│ helexa-acp │ ← one binary, multi-endpoint
|
||||
└─────┬───────────┘
|
||||
│ HTTP / SSE
|
||||
┌────────┼─────────────┬──────────────┬──────────────┐
|
||||
▼ ▼ ▼ ▼ ▼
|
||||
cortex/ OpenAI Anthropic OpenRouter LM Studio
|
||||
neuron Responses Messages
|
||||
(self- (gpt-5,…) (Claude)
|
||||
hosted)
|
||||
```
|
||||
|
||||
## What it does
|
||||
|
||||
- **Speaks ACP** over stdio to editor clients (Zed today; any future
|
||||
ACP client tomorrow).
|
||||
- **Multi-endpoint** — one config file lists every LLM endpoint
|
||||
you want available; pick one per session via the model dropdown
|
||||
(`endpoint:model` selector).
|
||||
- **Three wire formats**: `openai-chat` (the broadly compatible
|
||||
default), `openai-responses` (newer OpenAI surface), and
|
||||
`anthropic-messages` (Claude). Each is a separate provider impl
|
||||
in `src/provider/`; adding a fourth (Gemini, Ollama native, …) is
|
||||
one file plus a `WireApi` enum variant.
|
||||
- **Built-in tools**: `read_file`, `write_file`, `edit_file`,
|
||||
`list_dir`, `bash`. Permission-gated by default; the editor user
|
||||
approves writes/shell per-call.
|
||||
- **Three session modes**: Default (gated), Bypass Permissions
|
||||
(auto-allow), and Plan (write-only-to-plan-dir, no shell).
|
||||
- **Vision** — drag-drop images into the agent panel against any
|
||||
vision-capable model.
|
||||
- **Session resume** — multi-day conversations survive editor
|
||||
restarts via on-disk transcript persistence.
|
||||
- **Context compaction** — rolling history stays inside the model's
|
||||
context window automatically so long sessions on small-context
|
||||
local models don't fall over.
|
||||
|
||||
## Install
|
||||
|
||||
### From source
|
||||
|
||||
```sh
|
||||
git clone https://git.lair.cafe/helexa/cortex.git
|
||||
cd cortex
|
||||
cargo install --path crates/helexa-acp
|
||||
# Binary lands at ~/.cargo/bin/helexa-acp
|
||||
```
|
||||
|
||||
### Pre-built RPM (Fedora 43)
|
||||
|
||||
```sh
|
||||
dnf copr enable helexa/helexa
|
||||
dnf install helexa-acp
|
||||
```
|
||||
|
||||
The COPR project bundles helexa-acp alongside the cortex gateway
|
||||
and helexa-neuron flavours; install only the package(s) you need.
|
||||
|
||||
## Quick start
|
||||
|
||||
The fastest path: env-var single-endpoint config.
|
||||
|
||||
```sh
|
||||
export HELEXA_ACP_BASE_URL=http://hanzalova.internal:31313/v1
|
||||
export HELEXA_ACP_MODEL=Qwen/Qwen3.6-27B
|
||||
helexa-acp # speaks ACP over stdin/stdout; not interactive
|
||||
```
|
||||
|
||||
Then in Zed (`~/.config/zed/settings.json`):
|
||||
|
||||
```jsonc
|
||||
{
|
||||
"agent_servers": {
|
||||
"helexa": {
|
||||
"command": "helexa-acp",
|
||||
"args": []
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Restart Zed → open the agent panel → pick "helexa" → start
|
||||
chatting. Tool calls (file reads, writes, bash) prompt for
|
||||
permission per-call in Default mode.
|
||||
|
||||
That's the minimum. The full config story below is what unlocks
|
||||
the multi-endpoint dropdown.
|
||||
|
||||
## Multi-endpoint config
|
||||
|
||||
Copy `helexa-acp.example.toml` from this repo to
|
||||
`$XDG_CONFIG_HOME/helexa-acp/config.toml` (typically
|
||||
`~/.config/helexa-acp/config.toml`) and edit:
|
||||
|
||||
```toml
|
||||
default_endpoint = "helexa"
|
||||
|
||||
[[endpoints]]
|
||||
name = "helexa"
|
||||
base_url = "http://hanzalova.internal:31313/v1"
|
||||
wire_api = "openai-chat"
|
||||
default_model = "Qwen/Qwen3.6-27B"
|
||||
max_tokens = 8192
|
||||
context_window = 32768
|
||||
|
||||
[[endpoints]]
|
||||
name = "openrouter"
|
||||
base_url = "https://openrouter.ai/api/v1"
|
||||
wire_api = "openai-chat"
|
||||
api_key_env = "OPENROUTER_API_KEY"
|
||||
default_model = "anthropic/claude-opus-4"
|
||||
|
||||
[[endpoints]]
|
||||
name = "anthropic"
|
||||
base_url = "https://api.anthropic.com/v1"
|
||||
wire_api = "anthropic-messages"
|
||||
api_key_env = "ANTHROPIC_API_KEY"
|
||||
default_model = "claude-opus-4"
|
||||
```
|
||||
|
||||
Restart Zed. The model dropdown lists every model from every
|
||||
configured endpoint with the `endpoint:model` selector
|
||||
(`helexa:Qwen/Qwen3.6-27B`, `openrouter:anthropic/claude-opus-4`,
|
||||
…). Switch mid-session; the next prompt routes to the new endpoint.
|
||||
|
||||
When only one endpoint is configured the prefix is dropped (model
|
||||
ids appear bare).
|
||||
|
||||
### Selector syntax
|
||||
|
||||
The `model` field on every internal request is parsed as
|
||||
`<endpoint>:<model>`:
|
||||
|
||||
- `openrouter:gpt-4o` → routes to the `openrouter` endpoint,
|
||||
model `gpt-4o`.
|
||||
- `helexa/large` → no colon → falls through to whichever endpoint
|
||||
is named in `default_endpoint`, model `helexa/large`.
|
||||
- `:gpt-5` → leading colon → also falls through to default.
|
||||
|
||||
## Endpoint cookbook
|
||||
|
||||
Copy-pasteable blocks. Mix and match.
|
||||
|
||||
### cortex / neuron (self-hosted)
|
||||
|
||||
```toml
|
||||
[[endpoints]]
|
||||
name = "helexa"
|
||||
base_url = "http://hanzalova.internal:31313/v1"
|
||||
wire_api = "openai-chat"
|
||||
default_model = "Qwen/Qwen3.6-27B"
|
||||
max_tokens = 8192
|
||||
context_window = 32768
|
||||
```
|
||||
|
||||
Use `openai-responses` instead of `openai-chat` once cortex 0.1.16+
|
||||
is deployed and you want the Responses API surface (vision item
|
||||
shape, structured reasoning items, etc.).
|
||||
|
||||
### OpenAI directly
|
||||
|
||||
```toml
|
||||
[[endpoints]]
|
||||
name = "openai"
|
||||
base_url = "https://api.openai.com/v1"
|
||||
wire_api = "openai-responses"
|
||||
api_key_env = "OPENAI_API_KEY"
|
||||
default_model = "gpt-5"
|
||||
```
|
||||
|
||||
`openai-responses` is the right choice for current OpenAI models;
|
||||
`openai-chat` works against legacy GPT-3.5/4 deployments and
|
||||
anything labelled "chat completions".
|
||||
|
||||
### Anthropic directly
|
||||
|
||||
```toml
|
||||
[[endpoints]]
|
||||
name = "anthropic"
|
||||
base_url = "https://api.anthropic.com/v1"
|
||||
wire_api = "anthropic-messages"
|
||||
api_key_env = "ANTHROPIC_API_KEY"
|
||||
default_model = "claude-opus-4"
|
||||
```
|
||||
|
||||
helexa-acp sends `x-api-key` + `anthropic-version: 2023-06-01`
|
||||
automatically. The `api_key_env` indirection keeps your key out of
|
||||
the config file.
|
||||
|
||||
### OpenRouter (multi-vendor proxy)
|
||||
|
||||
```toml
|
||||
[[endpoints]]
|
||||
name = "openrouter"
|
||||
base_url = "https://openrouter.ai/api/v1"
|
||||
wire_api = "openai-chat"
|
||||
api_key_env = "OPENROUTER_API_KEY"
|
||||
default_model = "anthropic/claude-opus-4"
|
||||
```
|
||||
|
||||
OpenRouter speaks OpenAI-compat for every model it fronts, so
|
||||
`openai-chat` is the right wire format regardless of the
|
||||
underlying vendor.
|
||||
|
||||
### LM Studio (local)
|
||||
|
||||
```toml
|
||||
[[endpoints]]
|
||||
name = "lmstudio"
|
||||
base_url = "http://localhost:1234/v1"
|
||||
wire_api = "openai-chat"
|
||||
default_model = "auto"
|
||||
```
|
||||
|
||||
LM Studio's "auto" model id picks whatever's loaded. Same shape
|
||||
works for Ollama in compat mode (`http://localhost:11434/v1`) and
|
||||
vLLM.
|
||||
|
||||
### Multiple cortex deployments
|
||||
|
||||
```toml
|
||||
[[endpoints]]
|
||||
name = "lan"
|
||||
base_url = "http://hanzalova.internal:31313/v1"
|
||||
wire_api = "openai-chat"
|
||||
default_model = "Qwen/Qwen3.6-27B"
|
||||
|
||||
[[endpoints]]
|
||||
name = "cloud"
|
||||
base_url = "https://cortex.example.com/v1"
|
||||
wire_api = "openai-chat"
|
||||
api_key_env = "CLOUD_CORTEX_KEY"
|
||||
default_model = "Qwen/Qwen3-VL-8B"
|
||||
```
|
||||
|
||||
Use the `endpoint:model` selector to switch between them mid-session.
|
||||
|
||||
## Zed setup
|
||||
|
||||
`~/.config/zed/settings.json`:
|
||||
|
||||
```jsonc
|
||||
{
|
||||
"agent_servers": {
|
||||
"helexa": {
|
||||
"command": "helexa-acp"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Optional environment overrides for the binary:
|
||||
|
||||
```jsonc
|
||||
{
|
||||
"agent_servers": {
|
||||
"helexa": {
|
||||
"command": "helexa-acp",
|
||||
"env": {
|
||||
"HELEXA_ACP_LOG_FILE": "/tmp/helexa-acp.log",
|
||||
"RUST_LOG": "helexa_acp=debug"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
`HELEXA_ACP_LOG_FILE` is the one you actually want — Zed doesn't
|
||||
surface the agent's stderr, so without that env var debug output is
|
||||
invisible. Point it at a file you can `tail -f`.
|
||||
|
||||
After restarting Zed: ⌘+? (or wherever your "Open Agent Panel"
|
||||
binding is) → select "helexa" → the model dropdown populates from
|
||||
your config → start prompting.
|
||||
|
||||
## Modes
|
||||
|
||||
Three session modes ship; the user picks via Zed's mode dropdown
|
||||
on the agent panel.
|
||||
|
||||
| Mode | Reads | Writes | Bash | Permission prompts |
|
||||
|------|-------|--------|------|--------------------|
|
||||
| **Default** | ✓ | with prompt | with prompt | per call |
|
||||
| **Bypass Permissions** | ✓ | ✓ | ✓ | never |
|
||||
| **Plan** | ✓ | only into plan dir | disabled | never (plan-dir writes auto-allow) |
|
||||
|
||||
### Default
|
||||
|
||||
Reads are always allowed (`read_file`, `list_dir` are
|
||||
unrestricted). Writes and shell commands prompt the user before
|
||||
running. The intended baseline for any session where the agent
|
||||
might do something you'd rather review first.
|
||||
|
||||
### Bypass Permissions
|
||||
|
||||
Auto-allow every tool call. Use for agentic loops you trust — bulk
|
||||
edits across many files, scripted workflows, prepared session
|
||||
templates. Never for code the agent hasn't seen before.
|
||||
|
||||
### Plan
|
||||
|
||||
The "draft an implementation plan before you write code" mode.
|
||||
Available tools:
|
||||
|
||||
- `read_file`, `list_dir`: unrestricted (read the codebase).
|
||||
- `write_file`, `edit_file`: allowed *only* under
|
||||
`$XDG_DATA_HOME/helexa-acp/plans/<project-id>/`. Any path
|
||||
outside that returns "plan mode: writes are restricted to …"
|
||||
back to the model so it self-corrects.
|
||||
- `bash`: disabled outright. Returns "plan mode: shell execution
|
||||
is disabled" if attempted.
|
||||
|
||||
When the plan is complete, the model presents a 3-option menu:
|
||||
|
||||
1. **Bypass Permissions** — implement the plan now, no prompts.
|
||||
2. **Default** — implement now with per-tool prompts.
|
||||
3. **Plan** (stay here) — refine the plan with more guidance.
|
||||
|
||||
Switch the mode dropdown to your preference and reply to proceed.
|
||||
|
||||
## Tools
|
||||
|
||||
Five tools, defined in `src/tools.rs`:
|
||||
|
||||
| Tool | Args | Gated in Default? |
|
||||
|------|------|-------------------|
|
||||
| `read_file` | `path`, `line?`, `limit?` | no |
|
||||
| `list_dir` | `path` | no |
|
||||
| `write_file` | `path`, `content` | yes |
|
||||
| `edit_file` | `path`, `old_text`, `new_text` | yes |
|
||||
| `bash` | `command`, `cwd?` | yes |
|
||||
|
||||
### Path handling
|
||||
|
||||
`~`, `~/`, `$HOME`, and `$HOME/` are expanded server-side before
|
||||
the path reaches ACP or local fs. Lets the model emit
|
||||
`~/git/repo/file.rs` and have it Just Work.
|
||||
|
||||
`read_file` first tries the editor's filesystem (ACP's
|
||||
`fs/read_text_file` — respects open buffers, workspace overlays,
|
||||
etc.). If that fails — typically because the path is outside Zed's
|
||||
workspace boundary — it falls back to `std::fs::read_to_string`.
|
||||
This lets the agent pull in shared material like
|
||||
`~/git/architecture/generic.md` from a different project's
|
||||
session.
|
||||
|
||||
The fallback is logged at warn level so you can see when it kicks
|
||||
in.
|
||||
|
||||
### Tool dispatch
|
||||
|
||||
Tool descriptions reach the model through a Qwen3 Hermes-format
|
||||
`# Tools` block injected into the system prompt — cortex/neuron
|
||||
pass the OpenAI `tools` request field through to the encoder
|
||||
unread, so we work the model into emitting `<tool_call>{json}</tool_call>`
|
||||
markers it then parses out of the content stream. This applies to
|
||||
the helexa wire format; OpenAI / Anthropic endpoints with native
|
||||
tool support would use their own paths once they're wired in.
|
||||
|
||||
The parser is tolerant: malformed JSON (trailing braces, missing
|
||||
`name`, name nested in `arguments`) gets a repair pass; if that
|
||||
fails the call surfaces as a "Malformed tool call" card in Zed and
|
||||
the model gets a synthetic error result so it can self-correct.
|
||||
|
||||
## Session resume
|
||||
|
||||
helexa-acp persists every session to
|
||||
`$XDG_DATA_HOME/helexa-acp/sessions/<id>.json`. Zed's `session/list`
|
||||
RPC asks helexa-acp to enumerate them on workspace open;
|
||||
`session/load` rehydrates and replays the transcript as
|
||||
`session/update` notifications so the agent panel renders the
|
||||
prior conversation.
|
||||
|
||||
Behaviour:
|
||||
|
||||
- Persisted per-round, so a mid-turn agent stall (long bash, wedged
|
||||
ACP roundtrip) doesn't lose earlier rounds.
|
||||
- Survives editor restart and the helexa-acp binary upgrading
|
||||
between versions.
|
||||
- Project-scoped: only sessions whose `cwd` matches the workspace
|
||||
are listed.
|
||||
|
||||
To wipe history: `rm -rf $XDG_DATA_HOME/helexa-acp/sessions/`.
|
||||
|
||||
## Context compaction
|
||||
|
||||
When an endpoint sets `context_window`, helexa-acp projects the
|
||||
rolling history into a token budget before each request — old
|
||||
`ToolResult` content (read_file payloads are the worst offenders)
|
||||
gets elided to one-line markers, preserving `tool_call_id` pairing
|
||||
so the wire schema stays valid.
|
||||
|
||||
System prompts, user turns, and the most recent ~4 messages are
|
||||
never elided. The full history stays on disk; compaction is a
|
||||
per-request projection, not a destructive edit.
|
||||
|
||||
Set `context_window = 32768` for a 32 K Qwen3, `131072` for a
|
||||
modern Claude, etc. With `max_tokens` also set, the budget is
|
||||
`context_window - max_tokens - 512_safety`.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "default endpoint 'helexa' has no usable provider — check config"
|
||||
|
||||
The named default endpoint failed to construct. Usually:
|
||||
|
||||
- `api_key_env` references a variable that isn't set in the env
|
||||
Zed launched helexa-acp with.
|
||||
- The TOML's `wire_api` is misspelled (only `openai-chat`,
|
||||
`openai-responses`, `anthropic-messages` are accepted).
|
||||
|
||||
Test by running `helexa-acp` directly from a shell — startup
|
||||
errors land on stderr.
|
||||
|
||||
### Model dropdown is empty
|
||||
|
||||
Each provider's `list_models` failed at startup. Look at
|
||||
`HELEXA_ACP_LOG_FILE` for "list_models failed; this endpoint's
|
||||
models won't appear in the picker". Likely the endpoint URL is
|
||||
wrong, the API key is invalid, or the upstream `/v1/models`
|
||||
endpoint isn't responding.
|
||||
|
||||
The agent still works against `default_model` even when the
|
||||
dropdown is empty — list-models is for picking, not routing.
|
||||
|
||||
### "prompt_too_long" / agent stalls mid-conversation
|
||||
|
||||
You hit the model's context window. Set `context_window` on the
|
||||
endpoint and helexa-acp will compact before sending. The log line
|
||||
`context compaction applied` confirms it's running; if it fires
|
||||
but the upstream still rejects, the compaction heuristic
|
||||
under-counted and the budget needs tuning down.
|
||||
|
||||
### Reading files outside the workspace returns "not found"
|
||||
|
||||
Zed's `fs/read_text_file` is workspace-scoped. helexa-acp falls
|
||||
back to local `std::fs` automatically when that fails — look for
|
||||
`fs/read_text_file failed; falling back to local std::fs` in the
|
||||
log. If even local read fails, the file genuinely doesn't exist
|
||||
or the user process lacks permissions.
|
||||
|
||||
### Tool calls render as text instead of structured cards
|
||||
|
||||
The model is emitting `<tool_call>` markers that the parser can't
|
||||
decode. Two common causes:
|
||||
|
||||
1. The system prompt isn't reaching the model (cortex/neuron's
|
||||
tool-block injection didn't fire). Confirm with
|
||||
`RUST_LOG=helexa_acp=debug` and look at the outgoing
|
||||
`POST /chat/completions` body.
|
||||
2. The model itself is too small / undertrained to follow the
|
||||
Hermes format reliably. helexa-acp has shape-based name
|
||||
inference and JSON repair, but there's a floor below which
|
||||
nothing helps.
|
||||
|
||||
### Plan-mode writes refused even inside the plan dir
|
||||
|
||||
The path comparison is byte-for-byte. If the model emits a path
|
||||
with `~` and the plan_dir has the expanded form, expansion runs
|
||||
*before* the comparison — but resolved-vs-symlinked-path
|
||||
mismatches can still bite. The error message names the attempted
|
||||
path and the expected prefix so you can compare directly.
|
||||
|
||||
## Architecture
|
||||
|
||||
Source layout under `crates/helexa-acp/src/`:
|
||||
|
||||
| File | Responsibility |
|
||||
|------|----------------|
|
||||
| `main.rs` | tokio + Stdio transport. Builds providers, hands off to `agent::Agent` |
|
||||
| `config.rs` | TOML + env-fallback config, endpoint resolver |
|
||||
| `agent.rs` | ACP handlers (initialize, session/new, session/prompt, session/cancel, session/set_mode, session/set_model, session/load, session/list), prompt loop with tool-call recursion |
|
||||
| `session.rs` | Per-session state map (Arc<RwLock<HashMap<…>>>) |
|
||||
| `store.rs` | On-disk session persistence, plan-dir resolution |
|
||||
| `prompt.rs` | System-prompt assembly, plan-mode addendum |
|
||||
| `tools.rs` | Tool schemas + shape-based name inference |
|
||||
| `tool_runner.rs` | Dispatch a single tool call through ACP client RPCs; permission gate |
|
||||
| `qwen3.rs` | Qwen3 Hermes tool-format parser (`<tool_call>` / `<think>` markers) |
|
||||
| `compaction.rs` | Token-budget compaction for the rolling history |
|
||||
| `path_util.rs` | `~` / `$HOME` expansion shared across every path-taking tool |
|
||||
| `provider/openai_chat.rs` | OpenAI chat completions provider |
|
||||
| `provider/openai_responses.rs` | OpenAI Responses API provider |
|
||||
| `provider/anthropic_messages.rs` | Anthropic Messages API provider |
|
||||
|
||||
### Adding a new wire format
|
||||
|
||||
1. New file under `src/provider/` implementing the `Provider`
|
||||
trait (encoder + SSE decoder).
|
||||
2. Add a `WireApi` variant in `config.rs`.
|
||||
3. Wire it into `build_provider` in `main.rs`.
|
||||
4. Done — every other module is wire-format-agnostic.
|
||||
|
||||
### Concurrency
|
||||
|
||||
- `Arc<RwLock<HashMap<SessionId, Arc<Mutex<SessionState>>>>>` —
|
||||
per-session mutex so concurrent requests across sessions don't
|
||||
contend; the map's RwLock is read-mostly.
|
||||
- Every tool call dispatched serially within a session (parallel
|
||||
dispatch would require Zed to handle interleaved permission
|
||||
prompts).
|
||||
- Provider streams are back-pressured by the consumer (bounded
|
||||
mpsc channels).
|
||||
|
||||
### Self-contained
|
||||
|
||||
The crate has no workspace-internal dependencies (no
|
||||
`cortex-core`, no `cortex-gateway`). Migration to a dedicated
|
||||
GitHub repo for cross-platform CI / cargo-dist binaries is
|
||||
Cargo.toml-only.
|
||||
|
||||
## Status
|
||||
|
||||
- Stages 1–6 shipped: scaffold, agent loop, tools, modes, session
|
||||
resume, image input, model picker, three wire formats.
|
||||
- Stage 8 (RPM + multi-platform CI) tracked in the canonical plan;
|
||||
Linux x86_64 RPM ships today via the cortex monorepo's Gitea
|
||||
Actions.
|
||||
|
||||
## Contributing
|
||||
|
||||
Repository: https://git.lair.cafe/helexa/cortex (`crates/helexa-acp/`).
|
||||
Issues / PRs welcome. The canonical staged plan is in
|
||||
`~/.claude/plans/plan-the-per-device-worker-abstract-micali.md` on
|
||||
the maintainer's machine; the substages 3a–3e and 6a/6b that the
|
||||
canonical plan didn't anticipate are documented in commit messages.
|
||||
|
||||
CI: `cargo fmt --check --all`, `cargo clippy --workspace -- -D
|
||||
warnings`, `cargo test --workspace` must all pass before merge.
|
||||
1820
crates/helexa-acp/src/agent.rs
Normal file
1820
crates/helexa-acp/src/agent.rs
Normal file
File diff suppressed because it is too large
Load Diff
425
crates/helexa-acp/src/compaction.rs
Normal file
425
crates/helexa-acp/src/compaction.rs
Normal file
@@ -0,0 +1,425 @@
|
||||
//! Rolling-conversation compaction for small-context local models.
|
||||
//!
|
||||
//! The tool-call loop in [`crate::agent`] grows the message vec it
|
||||
//! sends upstream every round. On a frontier model that's fine; on a
|
||||
//! 32 K Qwen3 the first few `read_file` results can push the prompt
|
||||
//! past the model's context window, at which point cortex/neuron
|
||||
//! refuses with `prompt_too_long` and the whole turn dies. Long-form
|
||||
//! local agents are unusable without something here.
|
||||
//!
|
||||
//! Strategy (intentionally simple — no LLM-summarization round-trip,
|
||||
//! no tokenizer dependency):
|
||||
//!
|
||||
//! 1. **Protect** the things the model cannot reason without:
|
||||
//! - The system prompt (idx 0).
|
||||
//! - Every `Role::User` turn (the user's intent — irreplaceable).
|
||||
//! - The last [`KEEP_TAIL`] messages (most recent rounds stay
|
||||
//! verbatim so the model can keep working on what it just
|
||||
//! observed).
|
||||
//! 2. **Elide** older `Role::Assistant` prose and older `Role::Tool`
|
||||
//! result content. The structure stays — `tool_call_id`s, tool
|
||||
//! names, and argument JSON survive intact — so OpenAI's strict
|
||||
//! `tool_calls` ↔ `tool` pairing schema remains satisfied. Only
|
||||
//! the *payload* shrinks to a one-line marker.
|
||||
//! 3. Walk oldest→newest, recomputing the budget after each elision.
|
||||
//! Stop as soon as we fit; we don't compact more than necessary.
|
||||
//! 4. If we still exceed budget after eliding everything we're
|
||||
//! allowed to, return what we have. The upstream will surface a
|
||||
//! `prompt_too_long` error and the user can intervene; that's
|
||||
//! better than silently dropping content the model needs.
|
||||
//!
|
||||
//! Token estimation uses a `chars / 3.5` heuristic — conservative
|
||||
//! (over-estimates tokens slightly) so we compact a touch early
|
||||
//! rather than a touch late.
|
||||
|
||||
use crate::provider::{Message, MessageContent, MessagePart, Role};
|
||||
|
||||
/// Most-recent N messages that are never elided. Roughly "the
|
||||
/// current tool round in flight" — assistant turn that called the
|
||||
/// tools + each tool result + a bit of slack.
|
||||
const KEEP_TAIL: usize = 4;
|
||||
|
||||
/// Below this content size we don't bother eliding — the savings
|
||||
/// don't outweigh the loss of detail. Roughly 60–80 tokens.
|
||||
const ELIDE_MIN_CHARS: usize = 256;
|
||||
|
||||
/// Roughly tokens-per-character for English + code mixed in. The
|
||||
/// actual per-tokenizer ratio varies (GPT-4o ≈ 4 chars/token on
|
||||
/// English prose, ≈ 3 chars/token on code-heavy text). We pick a
|
||||
/// value on the conservative end so the budget check fires *before*
|
||||
/// the upstream tokenizer says no.
|
||||
const CHARS_PER_TOKEN: f32 = 3.5;
|
||||
|
||||
/// Per-message envelope overhead (role + JSON framing). Comes out
|
||||
/// to a few tokens; tiny but it adds up across long histories.
|
||||
const ENVELOPE_TOKENS: usize = 8;
|
||||
|
||||
/// Rough per-image token cost used by the budget estimator. Real
|
||||
/// vision tokenizers vary widely (256–1024 tokens for typical
|
||||
/// resolutions on Qwen3-VL, OpenAI's `low`/`high` detail toggles
|
||||
/// pick between ~85 and ~1000+). 512 is a defensible middle that
|
||||
/// keeps compaction from treating images as free.
|
||||
const IMAGE_TOKENS_APPROX: usize = 512;
|
||||
|
||||
/// Stats reported back from [`compact_to_budget`] for the caller to
|
||||
/// log. The numbers are estimates (see [`estimate_tokens`]), so
|
||||
/// don't compare them to upstream-reported token counts as if they
|
||||
/// were exact.
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq)]
|
||||
pub struct CompactionStats {
|
||||
/// Estimated tokens in the input messages.
|
||||
pub original_tokens: usize,
|
||||
/// Estimated tokens after compaction. Equal to `original_tokens`
|
||||
/// when no compaction was needed.
|
||||
pub final_tokens: usize,
|
||||
/// Number of messages whose content was elided. Zero is the
|
||||
/// hot path (nothing to do).
|
||||
pub elided_messages: usize,
|
||||
}
|
||||
|
||||
impl CompactionStats {
|
||||
fn unchanged(tokens: usize) -> Self {
|
||||
Self {
|
||||
original_tokens: tokens,
|
||||
final_tokens: tokens,
|
||||
elided_messages: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Approximate token count for one message. Sums the textual
|
||||
/// payload's chars, divides by [`CHARS_PER_TOKEN`], and adds an
|
||||
/// envelope constant. Cheap (no allocation) so safe to call once per
|
||||
/// message per round.
|
||||
pub fn estimate_tokens(msg: &Message) -> usize {
|
||||
let chars = match &msg.content {
|
||||
MessageContent::Text { text } => text.len(),
|
||||
MessageContent::MultiPart { parts } => parts
|
||||
.iter()
|
||||
.map(|p| match p {
|
||||
MessagePart::Text { text } => text.len(),
|
||||
// Each image is one block in the context window; the
|
||||
// upstream tokenizer handles the real cost (and it
|
||||
// varies wildly by model — Qwen3-VL uses ~256-1024
|
||||
// tokens per image depending on size). Take a
|
||||
// middle estimate so the budget tracker doesn't
|
||||
// pretend images are free.
|
||||
MessagePart::Image(_) => IMAGE_TOKENS_APPROX * CHARS_PER_TOKEN as usize,
|
||||
})
|
||||
.sum(),
|
||||
MessageContent::ToolCalls { text, calls } => {
|
||||
let txt = text.as_deref().map(|s| s.len()).unwrap_or(0);
|
||||
let calls_size: usize = calls
|
||||
.iter()
|
||||
.map(|c| c.name.len() + c.arguments.len() + c.id.len())
|
||||
.sum();
|
||||
txt + calls_size
|
||||
}
|
||||
MessageContent::ToolResult {
|
||||
tool_call_id,
|
||||
content,
|
||||
} => tool_call_id.len() + content.len(),
|
||||
};
|
||||
((chars as f32 / CHARS_PER_TOKEN) as usize) + ENVELOPE_TOKENS
|
||||
}
|
||||
|
||||
/// Sum of [`estimate_tokens`] across all messages.
|
||||
pub fn total_tokens(messages: &[Message]) -> usize {
|
||||
messages.iter().map(estimate_tokens).sum()
|
||||
}
|
||||
|
||||
/// Project `messages` into a vec whose estimated token count fits in
|
||||
/// `budget` tokens. Returns the projection plus stats about what
|
||||
/// was done. When the input already fits, the projection is a clone
|
||||
/// of the input and stats report zero elisions.
|
||||
///
|
||||
/// See module docs for the strategy and protected set.
|
||||
pub fn compact_to_budget(messages: &[Message], budget: usize) -> (Vec<Message>, CompactionStats) {
|
||||
let original = total_tokens(messages);
|
||||
if original <= budget {
|
||||
return (messages.to_vec(), CompactionStats::unchanged(original));
|
||||
}
|
||||
|
||||
let mut out = messages.to_vec();
|
||||
let len = out.len();
|
||||
let tail_start = len.saturating_sub(KEEP_TAIL);
|
||||
let mut elided = 0usize;
|
||||
|
||||
// Two passes. First pass: ToolResult contents (largest savings
|
||||
// per elision — read_file payloads land here). Second pass: long
|
||||
// Assistant prose. We don't interleave because eliding a long
|
||||
// assistant turn before a really old read_file would do less
|
||||
// good per elision; oldest-first ordering is enforced *within*
|
||||
// each pass instead.
|
||||
for pass in 0..2 {
|
||||
for i in 1..tail_start {
|
||||
if matches!(out[i].role, Role::User) {
|
||||
continue;
|
||||
}
|
||||
let target_pass_2 = matches!(
|
||||
&out[i].content,
|
||||
MessageContent::Text { .. } | MessageContent::ToolCalls { .. }
|
||||
);
|
||||
let target_pass_1 = matches!(&out[i].content, MessageContent::ToolResult { .. });
|
||||
let in_pass = (pass == 0 && target_pass_1) || (pass == 1 && target_pass_2);
|
||||
if !in_pass {
|
||||
continue;
|
||||
}
|
||||
if elide_in_place(&mut out[i]) {
|
||||
elided += 1;
|
||||
if total_tokens(&out) <= budget {
|
||||
let final_tokens = total_tokens(&out);
|
||||
return (
|
||||
out,
|
||||
CompactionStats {
|
||||
original_tokens: original,
|
||||
final_tokens,
|
||||
elided_messages: elided,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let final_tokens = total_tokens(&out);
|
||||
(
|
||||
out,
|
||||
CompactionStats {
|
||||
original_tokens: original,
|
||||
final_tokens,
|
||||
elided_messages: elided,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/// Shrink one message's payload while keeping its structural role
|
||||
/// (so tool_call_id pairing survives). Returns `true` when the
|
||||
/// message changed.
|
||||
///
|
||||
/// - `ToolResult.content` → `(elided: N bytes of tool result)`
|
||||
/// - `ToolCalls.text` → `(elided: N bytes of assistant prose)`
|
||||
/// - `Text` (assistant) → `(elided: N bytes of assistant prose)`
|
||||
///
|
||||
/// Already-tiny payloads are skipped — eliding a 50-byte string
|
||||
/// would *grow* it once the marker is in place.
|
||||
fn elide_in_place(msg: &mut Message) -> bool {
|
||||
match &mut msg.content {
|
||||
MessageContent::ToolResult { content, .. } => {
|
||||
if content.len() < ELIDE_MIN_CHARS {
|
||||
return false;
|
||||
}
|
||||
*content = format!("(elided: {} bytes of tool result)", content.len());
|
||||
true
|
||||
}
|
||||
MessageContent::ToolCalls { text, .. } => match text {
|
||||
Some(t) if t.len() >= ELIDE_MIN_CHARS => {
|
||||
*text = Some(format!("(elided: {} bytes of assistant prose)", t.len()));
|
||||
true
|
||||
}
|
||||
_ => false,
|
||||
},
|
||||
MessageContent::Text { text } => {
|
||||
if text.len() < ELIDE_MIN_CHARS {
|
||||
return false;
|
||||
}
|
||||
*text = format!("(elided: {} bytes of assistant prose)", text.len());
|
||||
true
|
||||
}
|
||||
MessageContent::MultiPart { .. } => {
|
||||
// MultiPart messages today only exist as User turns,
|
||||
// and User turns are protected by the role check in
|
||||
// `compact_to_budget` — so this branch is unreachable
|
||||
// for current call sites. Returning false keeps the
|
||||
// unreachable path benign if a future stage starts
|
||||
// emitting MultiPart on other roles.
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::provider::ToolCall;
|
||||
|
||||
fn sys(text: &str) -> Message {
|
||||
Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text { text: text.into() },
|
||||
}
|
||||
}
|
||||
fn user(text: &str) -> Message {
|
||||
Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text { text: text.into() },
|
||||
}
|
||||
}
|
||||
fn assistant_text(text: &str) -> Message {
|
||||
Message {
|
||||
role: Role::Assistant,
|
||||
content: MessageContent::Text { text: text.into() },
|
||||
}
|
||||
}
|
||||
fn assistant_calls(text: Option<&str>, name: &str, args: &str, id: &str) -> Message {
|
||||
Message {
|
||||
role: Role::Assistant,
|
||||
content: MessageContent::ToolCalls {
|
||||
text: text.map(|s| s.to_string()),
|
||||
calls: vec![ToolCall {
|
||||
id: id.into(),
|
||||
name: name.into(),
|
||||
arguments: args.into(),
|
||||
}],
|
||||
},
|
||||
}
|
||||
}
|
||||
fn tool_result(id: &str, body: &str) -> Message {
|
||||
Message {
|
||||
role: Role::Tool,
|
||||
content: MessageContent::ToolResult {
|
||||
tool_call_id: id.into(),
|
||||
content: body.into(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn under_budget_is_a_no_op_clone() {
|
||||
let msgs = vec![sys("you are an agent"), user("hi"), assistant_text("hello")];
|
||||
let (out, stats) = compact_to_budget(&msgs, 10_000);
|
||||
assert_eq!(stats.elided_messages, 0);
|
||||
assert_eq!(stats.original_tokens, stats.final_tokens);
|
||||
assert_eq!(out.len(), msgs.len());
|
||||
// Strings unchanged.
|
||||
match &out[2].content {
|
||||
MessageContent::Text { text } => assert_eq!(text, "hello"),
|
||||
other => panic!("expected Text, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn elides_old_tool_result_before_old_assistant_prose() {
|
||||
// History: sys, user, assistant_calls, big_tool_result,
|
||||
// assistant_with_big_text, user, assistant_calls,
|
||||
// small_tool_result.
|
||||
// KEEP_TAIL=4 protects the last four; the big tool result
|
||||
// sits in the prunable range and should go first because
|
||||
// pass 0 (tool results) runs before pass 1 (prose).
|
||||
let big_result = "X".repeat(4096);
|
||||
let big_prose = "Y".repeat(2048);
|
||||
let msgs = vec![
|
||||
sys("preamble"),
|
||||
user("first ask"),
|
||||
assistant_calls(None, "read_file", r#"{"path":"/a"}"#, "c0"),
|
||||
tool_result("c0", &big_result),
|
||||
assistant_text(&big_prose),
|
||||
user("follow up"),
|
||||
assistant_calls(None, "read_file", r#"{"path":"/b"}"#, "c1"),
|
||||
tool_result("c1", "short result body"),
|
||||
];
|
||||
let before = total_tokens(&msgs);
|
||||
// Force compaction by setting budget well below current.
|
||||
let budget = before / 2;
|
||||
let (out, stats) = compact_to_budget(&msgs, budget);
|
||||
|
||||
assert!(
|
||||
stats.elided_messages >= 1,
|
||||
"expected at least one elision, got {stats:?}"
|
||||
);
|
||||
// The big tool result must be elided (oldest fat target).
|
||||
match &out[3].content {
|
||||
MessageContent::ToolResult { content, .. } => {
|
||||
assert!(
|
||||
content.starts_with("(elided:"),
|
||||
"tool result not elided: {content:?}"
|
||||
);
|
||||
}
|
||||
other => panic!("expected ToolResult, got {other:?}"),
|
||||
}
|
||||
// Last four messages must be untouched.
|
||||
assert!(matches!(
|
||||
&out[out.len() - 1].content,
|
||||
MessageContent::ToolResult { content, .. } if content == "short result body"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn never_elides_system_or_user_turns() {
|
||||
let big_user = "U".repeat(8192);
|
||||
let msgs = vec![sys("preamble"), user(&big_user), assistant_text("ok")];
|
||||
let budget = 10; // way below — forces all possible elision
|
||||
let (out, _stats) = compact_to_budget(&msgs, budget);
|
||||
// System unchanged.
|
||||
match &out[0].content {
|
||||
MessageContent::Text { text } => assert_eq!(text, "preamble"),
|
||||
other => panic!("expected Text, got {other:?}"),
|
||||
}
|
||||
// User unchanged even though it's huge.
|
||||
match &out[1].content {
|
||||
MessageContent::Text { text } => assert_eq!(text.len(), big_user.len()),
|
||||
other => panic!("expected Text, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn preserves_tool_call_id_pairing_after_elision() {
|
||||
// OpenAI strict mode rejects a tool-result whose tool_call_id
|
||||
// doesn't match a preceding assistant tool_call. Elision
|
||||
// must not break that linkage.
|
||||
let big = "Z".repeat(4096);
|
||||
let msgs = vec![
|
||||
sys("preamble"),
|
||||
user("first"),
|
||||
assistant_calls(None, "read_file", r#"{"path":"/a"}"#, "call_42"),
|
||||
tool_result("call_42", &big),
|
||||
// Tail messages.
|
||||
user("next"),
|
||||
assistant_calls(None, "read_file", r#"{"path":"/b"}"#, "call_43"),
|
||||
tool_result("call_43", "ok"),
|
||||
assistant_text("done"),
|
||||
];
|
||||
let budget = total_tokens(&msgs) / 3;
|
||||
let (out, _stats) = compact_to_budget(&msgs, budget);
|
||||
// The assistant call and its result both carry call_42.
|
||||
let call_id = match &out[2].content {
|
||||
MessageContent::ToolCalls { calls, .. } => calls[0].id.clone(),
|
||||
other => panic!("expected ToolCalls, got {other:?}"),
|
||||
};
|
||||
match &out[3].content {
|
||||
MessageContent::ToolResult { tool_call_id, .. } => {
|
||||
assert_eq!(tool_call_id, &call_id, "pairing broken");
|
||||
}
|
||||
other => panic!("expected ToolResult, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn estimate_tokens_grows_with_content() {
|
||||
let small = sys("hi");
|
||||
let large = sys(&"x".repeat(10_000));
|
||||
assert!(estimate_tokens(&large) > estimate_tokens(&small) * 100);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn elide_in_place_skips_short_content() {
|
||||
let mut m = tool_result("c0", "tiny");
|
||||
assert!(!elide_in_place(&mut m));
|
||||
match m.content {
|
||||
MessageContent::ToolResult { content, .. } => assert_eq!(content, "tiny"),
|
||||
other => panic!("expected ToolResult, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn returns_best_effort_when_budget_unmeetable() {
|
||||
// Single huge user message that cannot be elided. Budget 10.
|
||||
// We don't error — we return what we have and let upstream
|
||||
// refuse the prompt with its own error.
|
||||
let big_user = "U".repeat(100_000);
|
||||
let msgs = vec![sys("preamble"), user(&big_user)];
|
||||
let (out, stats) = compact_to_budget(&msgs, 10);
|
||||
assert_eq!(out.len(), msgs.len());
|
||||
assert!(stats.final_tokens > 10, "still over budget by design");
|
||||
}
|
||||
}
|
||||
424
crates/helexa-acp/src/config.rs
Normal file
424
crates/helexa-acp/src/config.rs
Normal file
@@ -0,0 +1,424 @@
|
||||
//! Configuration for the helexa-acp bridge.
|
||||
//!
|
||||
//! Loaded from `$XDG_CONFIG_HOME/helexa-acp/config.toml` (or
|
||||
//! `~/.config/helexa-acp/config.toml` as a fallback). If no config file
|
||||
//! exists, falls back to building a single anonymous endpoint from env
|
||||
//! vars — that keeps "just point at one cortex" frictionless without
|
||||
//! requiring a config file on disk.
|
||||
//!
|
||||
//! The design goal is "the missing ACP binary for users with multiple
|
||||
//! API endpoints (possibly on a private LAN, possibly mixing wire
|
||||
//! types)". Hence: every endpoint is named, has its own wire API, and
|
||||
//! has its own default model. The agent's selected model id can be
|
||||
//! prefixed `endpoint:model` to route across endpoints; a bare
|
||||
//! `model` falls through to the configured `default_endpoint`.
|
||||
//!
|
||||
//! ### Example TOML
|
||||
//!
|
||||
//! ```toml
|
||||
//! default_endpoint = "helexa"
|
||||
//!
|
||||
//! [[endpoints]]
|
||||
//! name = "helexa"
|
||||
//! base_url = "http://hanzalova.internal:31313/v1"
|
||||
//! wire_api = "openai-chat"
|
||||
//! default_model = "helexa/large"
|
||||
//!
|
||||
//! [[endpoints]]
|
||||
//! name = "openrouter"
|
||||
//! base_url = "https://openrouter.ai/api/v1"
|
||||
//! wire_api = "openai-chat"
|
||||
//! api_key_env = "OPENROUTER_API_KEY"
|
||||
//! default_model = "anthropic/claude-opus-4"
|
||||
//!
|
||||
//! [[endpoints]]
|
||||
//! name = "lmstudio"
|
||||
//! base_url = "http://localhost:1234/v1"
|
||||
//! wire_api = "openai-chat"
|
||||
//! default_model = "auto"
|
||||
//! ```
|
||||
|
||||
use anyhow::{Context, anyhow};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::{Path, PathBuf};
|
||||
use url::Url;
|
||||
|
||||
const DEFAULT_BASE_URL: &str = "http://hanzalova.internal:31313/v1";
|
||||
const DEFAULT_MODEL: &str = "helexa/large";
|
||||
const DEFAULT_ENDPOINT_NAME: &str = "default";
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Config {
|
||||
/// Name of the endpoint used when a request doesn't pick one
|
||||
/// explicitly. Must reference an entry in `endpoints`. Defaults to
|
||||
/// the first endpoint declared if unset.
|
||||
#[serde(default)]
|
||||
pub default_endpoint: Option<String>,
|
||||
/// Per-endpoint configuration. At least one entry is required.
|
||||
#[serde(default)]
|
||||
pub endpoints: Vec<EndpointConfig>,
|
||||
/// Optional path to a system-prompt file. When unset, the built-in
|
||||
/// default prompt from `prompt.rs` is used.
|
||||
#[serde(default)]
|
||||
pub system_prompt_path: Option<PathBuf>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EndpointConfig {
|
||||
/// Short identifier used in `endpoint:model` routing and in logs.
|
||||
pub name: String,
|
||||
/// Base URL of the OpenAI-compatible API. Must include the `/v1`
|
||||
/// (or equivalent) suffix — paths like `chat/completions` and
|
||||
/// `models` are joined onto this.
|
||||
pub base_url: Url,
|
||||
/// Wire protocol the endpoint speaks. Phase 1 supports
|
||||
/// [`WireApi::OpenAiChat`] only; `openai-responses` and
|
||||
/// `anthropic-messages` land later behind their own providers.
|
||||
#[serde(default)]
|
||||
pub wire_api: WireApi,
|
||||
/// Model to use when the client hasn't picked one via
|
||||
/// `session/set_model`.
|
||||
#[serde(default)]
|
||||
pub default_model: Option<String>,
|
||||
/// Static API key to send as `Authorization: Bearer …`. Prefer
|
||||
/// `api_key_env` for anything sensitive — keys in plain TOML are a
|
||||
/// liability.
|
||||
#[serde(default)]
|
||||
pub api_key: Option<String>,
|
||||
/// Env var name to read for the API key. Resolved at startup so a
|
||||
/// missing env var yields a clear error rather than silent
|
||||
/// unauthenticated calls.
|
||||
#[serde(default)]
|
||||
pub api_key_env: Option<String>,
|
||||
/// Cap on the model's output tokens per turn. `None` lets the
|
||||
/// upstream pick its own default (cortex/neuron's default is
|
||||
/// often small enough to trip Zed's "Output Limit Reached" on
|
||||
/// long responses). Set to e.g. `32768` to let the model
|
||||
/// produce longer turns. Goes into the OpenAI `max_tokens`
|
||||
/// request field.
|
||||
#[serde(default)]
|
||||
pub max_tokens: Option<u64>,
|
||||
/// Model context window in tokens (prompt + response). When set,
|
||||
/// the agent compacts conversation history before each completion
|
||||
/// so the prompt fits within `context_window - max_tokens - safety`
|
||||
/// tokens — long sessions on small-context local models (Qwen3 at
|
||||
/// 32 K) survive past the first few tool-call rounds rather than
|
||||
/// dying with `prompt_too_long`. `None` disables compaction.
|
||||
#[serde(default)]
|
||||
pub context_window: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
|
||||
pub enum WireApi {
|
||||
/// `POST {base}/chat/completions` returning OpenAI-format SSE.
|
||||
/// Compatible with cortex, LM Studio, Ollama (compat mode),
|
||||
/// OpenRouter, OpenAI itself.
|
||||
#[default]
|
||||
#[serde(rename = "openai-chat")]
|
||||
OpenAiChat,
|
||||
/// `POST {base}/responses` — OpenAI's newer Responses API. Not
|
||||
/// implemented yet; the variant is reserved so endpoint configs
|
||||
/// can be authored ahead of provider support.
|
||||
#[serde(rename = "openai-responses")]
|
||||
OpenAiResponses,
|
||||
/// `POST {base}/messages` — Anthropic format. Reserved.
|
||||
#[serde(rename = "anthropic-messages")]
|
||||
AnthropicMessages,
|
||||
}
|
||||
|
||||
impl EndpointConfig {
|
||||
/// Resolve the API key from `api_key` (literal) or `api_key_env`
|
||||
/// (env-var lookup). Returns `Ok(None)` when neither is set;
|
||||
/// `Err` when `api_key_env` references a missing variable.
|
||||
pub fn resolve_api_key(&self) -> anyhow::Result<Option<String>> {
|
||||
if let Some(literal) = &self.api_key {
|
||||
return Ok(Some(literal.clone()));
|
||||
}
|
||||
if let Some(var) = &self.api_key_env {
|
||||
return Ok(Some(std::env::var(var).with_context(|| {
|
||||
format!(
|
||||
"endpoint '{}' references missing env var {}",
|
||||
self.name, var
|
||||
)
|
||||
})?));
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
/// `{base_url}/chat/completions`.
|
||||
pub fn chat_completions_url(&self) -> Url {
|
||||
join_segments(&self.base_url, &["chat", "completions"])
|
||||
}
|
||||
|
||||
/// `{base_url}/responses` — OpenAI Responses API endpoint.
|
||||
pub fn responses_url(&self) -> Url {
|
||||
join_segments(&self.base_url, &["responses"])
|
||||
}
|
||||
|
||||
/// `{base_url}/models`. Called from `Provider::list_models`, which
|
||||
/// Stage 4 wires into the model-picker dropdown; until then it's
|
||||
/// reachable code with no in-tree callers.
|
||||
#[allow(dead_code)]
|
||||
pub fn models_url(&self) -> Url {
|
||||
join_segments(&self.base_url, &["models"])
|
||||
}
|
||||
}
|
||||
|
||||
impl Config {
|
||||
/// Load from TOML at the standard config path, or build from env
|
||||
/// vars if no file exists. Env-fallback yields a single endpoint
|
||||
/// named `"default"`.
|
||||
pub fn load() -> anyhow::Result<Self> {
|
||||
let path = config_path();
|
||||
if let Some(path) = &path
|
||||
&& path.exists()
|
||||
{
|
||||
return Self::from_file(path);
|
||||
}
|
||||
Self::from_env()
|
||||
}
|
||||
|
||||
/// Single-endpoint config constructed from `HELEXA_ACP_BASE_URL`,
|
||||
/// `HELEXA_ACP_MODEL`, `HELEXA_ACP_API_KEY`,
|
||||
/// `HELEXA_ACP_SYSTEM_PROMPT_PATH`, `HELEXA_ACP_MAX_TOKENS`.
|
||||
pub fn from_env() -> anyhow::Result<Self> {
|
||||
let base_url = std::env::var("HELEXA_ACP_BASE_URL")
|
||||
.ok()
|
||||
.unwrap_or_else(|| DEFAULT_BASE_URL.into());
|
||||
let base_url = Url::parse(&base_url)
|
||||
.with_context(|| format!("HELEXA_ACP_BASE_URL is not a valid URL ({base_url})"))?;
|
||||
let default_model = std::env::var("HELEXA_ACP_MODEL")
|
||||
.ok()
|
||||
.unwrap_or_else(|| DEFAULT_MODEL.into());
|
||||
let api_key = std::env::var("HELEXA_ACP_API_KEY")
|
||||
.ok()
|
||||
.filter(|s| !s.is_empty());
|
||||
let system_prompt_path = std::env::var("HELEXA_ACP_SYSTEM_PROMPT_PATH")
|
||||
.ok()
|
||||
.filter(|s| !s.is_empty())
|
||||
.map(PathBuf::from);
|
||||
let max_tokens = std::env::var("HELEXA_ACP_MAX_TOKENS")
|
||||
.ok()
|
||||
.filter(|s| !s.is_empty())
|
||||
.map(|s| {
|
||||
s.parse::<u64>().with_context(|| {
|
||||
format!("HELEXA_ACP_MAX_TOKENS is not a positive integer ({s})")
|
||||
})
|
||||
})
|
||||
.transpose()?;
|
||||
let context_window = std::env::var("HELEXA_ACP_CONTEXT_WINDOW")
|
||||
.ok()
|
||||
.filter(|s| !s.is_empty())
|
||||
.map(|s| {
|
||||
s.parse::<usize>().with_context(|| {
|
||||
format!("HELEXA_ACP_CONTEXT_WINDOW is not a positive integer ({s})")
|
||||
})
|
||||
})
|
||||
.transpose()?;
|
||||
Ok(Self {
|
||||
default_endpoint: Some(DEFAULT_ENDPOINT_NAME.into()),
|
||||
endpoints: vec![EndpointConfig {
|
||||
name: DEFAULT_ENDPOINT_NAME.into(),
|
||||
base_url,
|
||||
wire_api: WireApi::OpenAiChat,
|
||||
default_model: Some(default_model),
|
||||
api_key,
|
||||
api_key_env: None,
|
||||
max_tokens,
|
||||
context_window,
|
||||
}],
|
||||
system_prompt_path,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn from_file(path: &Path) -> anyhow::Result<Self> {
|
||||
let text = std::fs::read_to_string(path)
|
||||
.with_context(|| format!("read config {}", path.display()))?;
|
||||
let mut cfg: Self =
|
||||
toml::from_str(&text).with_context(|| format!("parse config {}", path.display()))?;
|
||||
cfg.validate()?;
|
||||
Ok(cfg)
|
||||
}
|
||||
|
||||
fn validate(&mut self) -> anyhow::Result<()> {
|
||||
if self.endpoints.is_empty() {
|
||||
return Err(anyhow!("config has no [[endpoints]] entries"));
|
||||
}
|
||||
for (i, ep) in self.endpoints.iter().enumerate() {
|
||||
if ep.name.is_empty() {
|
||||
return Err(anyhow!("endpoints[{i}] has empty name"));
|
||||
}
|
||||
if ep.name.contains(':') {
|
||||
return Err(anyhow!(
|
||||
"endpoints[{i}].name '{}' contains ':' which would clash \
|
||||
with the endpoint:model selector syntax",
|
||||
ep.name
|
||||
));
|
||||
}
|
||||
}
|
||||
// Pick a default endpoint if none was named.
|
||||
if self.default_endpoint.is_none() {
|
||||
self.default_endpoint = Some(self.endpoints[0].name.clone());
|
||||
}
|
||||
let default_name = self.default_endpoint.as_deref().unwrap();
|
||||
if !self.endpoints.iter().any(|e| e.name == default_name) {
|
||||
return Err(anyhow!(
|
||||
"default_endpoint '{default_name}' is not declared in [[endpoints]]"
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Look up an endpoint by name. Returns `None` if not configured.
|
||||
pub fn endpoint(&self, name: &str) -> Option<&EndpointConfig> {
|
||||
self.endpoints.iter().find(|e| e.name == name)
|
||||
}
|
||||
|
||||
/// The default endpoint (guaranteed to exist after `validate`).
|
||||
pub fn default_endpoint(&self) -> &EndpointConfig {
|
||||
let name = self
|
||||
.default_endpoint
|
||||
.as_deref()
|
||||
.expect("default_endpoint set by validate");
|
||||
self.endpoint(name)
|
||||
.expect("default_endpoint resolves after validate")
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse an ACP-side `model` field into (endpoint name, raw model id).
|
||||
///
|
||||
/// `helexa:helexa/large` → (`Some("helexa")`, `"helexa/large"`).
|
||||
/// `helexa/large` → (`None`, `"helexa/large"`).
|
||||
///
|
||||
/// The split happens at the FIRST colon. Model ids commonly contain
|
||||
/// `/` (HuggingFace style) but rarely `:`; if a model id ever does, the
|
||||
/// user can quote-prefix with the default endpoint name.
|
||||
pub fn parse_model_selector(input: &str) -> (Option<&str>, &str) {
|
||||
match input.split_once(':') {
|
||||
Some((endpoint, model)) if !endpoint.is_empty() && !model.is_empty() => {
|
||||
(Some(endpoint), model)
|
||||
}
|
||||
_ => (None, input),
|
||||
}
|
||||
}
|
||||
|
||||
fn config_path() -> Option<PathBuf> {
|
||||
if let Ok(override_path) = std::env::var("HELEXA_ACP_CONFIG_PATH") {
|
||||
return Some(PathBuf::from(override_path));
|
||||
}
|
||||
let xdg = std::env::var("XDG_CONFIG_HOME")
|
||||
.ok()
|
||||
.filter(|s| !s.is_empty());
|
||||
let base = xdg.map(PathBuf::from).or_else(|| {
|
||||
std::env::var("HOME")
|
||||
.ok()
|
||||
.map(|h| PathBuf::from(h).join(".config"))
|
||||
})?;
|
||||
Some(base.join("helexa-acp").join("config.toml"))
|
||||
}
|
||||
|
||||
fn join_segments(base: &Url, segments: &[&str]) -> Url {
|
||||
let mut out = base.clone();
|
||||
if let Ok(mut path) = out.path_segments_mut() {
|
||||
path.pop_if_empty().extend(segments.iter().copied());
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn url_join_handles_trailing_slash() {
|
||||
let ep = EndpointConfig {
|
||||
name: "x".into(),
|
||||
base_url: Url::parse("http://h.internal:31313/v1").unwrap(),
|
||||
wire_api: WireApi::OpenAiChat,
|
||||
default_model: None,
|
||||
api_key: None,
|
||||
api_key_env: None,
|
||||
max_tokens: None,
|
||||
context_window: None,
|
||||
};
|
||||
assert_eq!(
|
||||
ep.chat_completions_url().as_str(),
|
||||
"http://h.internal:31313/v1/chat/completions"
|
||||
);
|
||||
assert_eq!(
|
||||
ep.models_url().as_str(),
|
||||
"http://h.internal:31313/v1/models"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_model_selector() {
|
||||
assert_eq!(
|
||||
parse_model_selector("helexa:helexa/large"),
|
||||
(Some("helexa"), "helexa/large")
|
||||
);
|
||||
assert_eq!(parse_model_selector("helexa/large"), (None, "helexa/large"));
|
||||
assert_eq!(parse_model_selector("gpt-5"), (None, "gpt-5"));
|
||||
// Edge case: a leading colon → no endpoint.
|
||||
assert_eq!(parse_model_selector(":gpt-5"), (None, ":gpt-5"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn env_fallback_builds_single_endpoint() {
|
||||
// Don't actually set env vars (would race with other tests);
|
||||
// just confirm the default path constructs cleanly.
|
||||
unsafe {
|
||||
std::env::remove_var("HELEXA_ACP_BASE_URL");
|
||||
std::env::remove_var("HELEXA_ACP_MODEL");
|
||||
std::env::remove_var("HELEXA_ACP_API_KEY");
|
||||
}
|
||||
let cfg = Config::from_env().unwrap();
|
||||
assert_eq!(cfg.endpoints.len(), 1);
|
||||
assert_eq!(cfg.endpoints[0].name, "default");
|
||||
assert_eq!(cfg.endpoints[0].base_url.as_str(), DEFAULT_BASE_URL);
|
||||
assert_eq!(
|
||||
cfg.endpoints[0].default_model.as_deref(),
|
||||
Some(DEFAULT_MODEL)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn toml_parses_multi_endpoint() {
|
||||
let toml_text = r#"
|
||||
default_endpoint = "helexa"
|
||||
|
||||
[[endpoints]]
|
||||
name = "helexa"
|
||||
base_url = "http://hanzalova.internal:31313/v1"
|
||||
default_model = "helexa/large"
|
||||
|
||||
[[endpoints]]
|
||||
name = "openrouter"
|
||||
base_url = "https://openrouter.ai/api/v1"
|
||||
wire_api = "openai-chat"
|
||||
api_key_env = "OPENROUTER_API_KEY"
|
||||
default_model = "anthropic/claude-opus-4"
|
||||
"#;
|
||||
let mut cfg: Config = toml::from_str(toml_text).unwrap();
|
||||
cfg.validate().unwrap();
|
||||
assert_eq!(cfg.endpoints.len(), 2);
|
||||
assert_eq!(cfg.default_endpoint().name, "helexa");
|
||||
assert_eq!(cfg.endpoints[0].wire_api, WireApi::OpenAiChat);
|
||||
assert_eq!(
|
||||
cfg.endpoints[1].api_key_env.as_deref(),
|
||||
Some("OPENROUTER_API_KEY")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_rejects_colon_in_endpoint_name() {
|
||||
let toml_text = r#"
|
||||
[[endpoints]]
|
||||
name = "bad:name"
|
||||
base_url = "http://x/v1"
|
||||
"#;
|
||||
let mut cfg: Config = toml::from_str(toml_text).unwrap();
|
||||
let err = cfg.validate().unwrap_err();
|
||||
assert!(format!("{err}").contains("clash"));
|
||||
}
|
||||
}
|
||||
145
crates/helexa-acp/src/main.rs
Normal file
145
crates/helexa-acp/src/main.rs
Normal file
@@ -0,0 +1,145 @@
|
||||
//! helexa-acp — Agent Client Protocol bridge for multi-endpoint LLM
|
||||
//! setups (helexa, LM Studio, Ollama, OpenRouter, OpenAI, Anthropic,
|
||||
//! …) with a clean per-endpoint wire-format selector.
|
||||
//!
|
||||
//! Speaks ACP over stdio to an editor client (Zed today). Every
|
||||
//! configured endpoint produces a wire-format-specific
|
||||
//! [`provider::Provider`] implementation; the agent loop in
|
||||
//! [`agent::Agent`] is provider-agnostic, so adding e.g. an Anthropic
|
||||
//! /v1/messages provider doesn't touch `agent.rs`.
|
||||
//!
|
||||
//! Config: `$XDG_CONFIG_HOME/helexa-acp/config.toml` for the multi-
|
||||
//! endpoint case; env vars (`HELEXA_ACP_BASE_URL`, etc.) for the
|
||||
//! single-endpoint case when no config file exists.
|
||||
|
||||
use agent_client_protocol::{Result, Stdio};
|
||||
use std::sync::Arc;
|
||||
|
||||
mod agent;
|
||||
mod compaction;
|
||||
mod config;
|
||||
mod path_util;
|
||||
mod prompt;
|
||||
mod provider;
|
||||
mod qwen3;
|
||||
mod session;
|
||||
mod store;
|
||||
mod tool_runner;
|
||||
mod tools;
|
||||
|
||||
use agent::Agent;
|
||||
use config::{Config, EndpointConfig, WireApi};
|
||||
use provider::{
|
||||
Provider, anthropic_messages::AnthropicMessagesProvider, openai_chat::OpenAIChatProvider,
|
||||
openai_responses::OpenAIResponsesProvider,
|
||||
};
|
||||
|
||||
/// Set up tracing. Logs go to stderr by default — stdout is
|
||||
/// reserved for the JSON-RPC stream. Setting `HELEXA_ACP_LOG_FILE`
|
||||
/// to an absolute path appends logs to that file instead, which is
|
||||
/// the practical way to capture debug output when the agent runs
|
||||
/// under an editor (Zed, etc.) that doesn't surface stderr.
|
||||
///
|
||||
/// `RUST_LOG` still controls levels (e.g. `helexa_acp=debug`).
|
||||
/// ANSI colours are auto-stripped when writing to a file so the log
|
||||
/// is plain text.
|
||||
fn init_tracing() {
|
||||
let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info"));
|
||||
|
||||
let log_file = std::env::var("HELEXA_ACP_LOG_FILE")
|
||||
.ok()
|
||||
.filter(|s| !s.is_empty());
|
||||
|
||||
match log_file {
|
||||
Some(path) => match std::fs::OpenOptions::new()
|
||||
.create(true)
|
||||
.append(true)
|
||||
.open(&path)
|
||||
{
|
||||
Ok(file) => {
|
||||
tracing_subscriber::fmt()
|
||||
.with_writer(std::sync::Mutex::new(file))
|
||||
.with_env_filter(env_filter)
|
||||
.with_ansi(false)
|
||||
.init();
|
||||
}
|
||||
Err(e) => {
|
||||
// Fall back to stderr and shout. We don't want a
|
||||
// typo'd log path to silence the agent entirely.
|
||||
tracing_subscriber::fmt()
|
||||
.with_writer(std::io::stderr)
|
||||
.with_env_filter(env_filter)
|
||||
.init();
|
||||
tracing::warn!(
|
||||
path = %path,
|
||||
error = %e,
|
||||
"HELEXA_ACP_LOG_FILE could not be opened; using stderr"
|
||||
);
|
||||
}
|
||||
},
|
||||
None => {
|
||||
tracing_subscriber::fmt()
|
||||
.with_writer(std::io::stderr)
|
||||
.with_env_filter(env_filter)
|
||||
.init();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a provider for `endpoint` according to its declared
|
||||
/// `wire_api`. Future wire types (OpenAI Responses, Anthropic
|
||||
/// /v1/messages, Ollama native) slot in here without changing the
|
||||
/// caller.
|
||||
fn build_provider(endpoint: EndpointConfig) -> anyhow::Result<Arc<dyn Provider>> {
|
||||
match endpoint.wire_api {
|
||||
WireApi::OpenAiChat => Ok(Arc::new(OpenAIChatProvider::new(endpoint)?)),
|
||||
WireApi::OpenAiResponses => Ok(Arc::new(OpenAIResponsesProvider::new(endpoint)?)),
|
||||
WireApi::AnthropicMessages => Ok(Arc::new(AnthropicMessagesProvider::new(endpoint)?)),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
init_tracing();
|
||||
|
||||
let cfg = Config::load()
|
||||
.map_err(|e| agent_client_protocol::util::internal_error(format!("config: {e:#}")))?;
|
||||
tracing::info!(
|
||||
endpoints = cfg.endpoints.len(),
|
||||
default_endpoint = %cfg.default_endpoint().name,
|
||||
default_model = ?cfg.default_endpoint().default_model,
|
||||
"helexa-acp starting"
|
||||
);
|
||||
|
||||
// Build a provider for each configured endpoint up-front. Cheap —
|
||||
// just sets up a reqwest::Client and resolves the API key — and
|
||||
// surfaces config mistakes (missing API key env var, unsupported
|
||||
// wire_api) before the editor even sends an initialize request.
|
||||
let mut providers: Vec<Arc<dyn Provider>> = Vec::with_capacity(cfg.endpoints.len());
|
||||
for endpoint in &cfg.endpoints {
|
||||
match build_provider(endpoint.clone()) {
|
||||
Ok(p) => {
|
||||
tracing::info!(
|
||||
endpoint = %endpoint.name,
|
||||
base_url = %endpoint.base_url,
|
||||
wire_api = ?endpoint.wire_api,
|
||||
"registered provider"
|
||||
);
|
||||
providers.push(p);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
endpoint = %endpoint.name,
|
||||
error = %format!("{e:#}"),
|
||||
"skipping endpoint with invalid config"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let agent = Agent::new(&cfg, providers)
|
||||
.await
|
||||
.map_err(|e| agent_client_protocol::util::internal_error(format!("agent: {e:#}")))?;
|
||||
agent.serve(Stdio::new()).await
|
||||
}
|
||||
192
crates/helexa-acp/src/path_util.rs
Normal file
192
crates/helexa-acp/src/path_util.rs
Normal file
@@ -0,0 +1,192 @@
|
||||
//! Path expansion shared across every tool that takes a path.
|
||||
//!
|
||||
//! Models often emit shell-style paths like `~/git/repo/file.rs` or
|
||||
//! `$HOME/notes.md`. ACP's `fs/read_text_file` and friends — and our
|
||||
//! own local `std::fs` reads — both want a real absolute path; the
|
||||
//! `~` / `$HOME` forms reach them as literal strings and the open
|
||||
//! fails. The tool schemas already document "absolute path" but in
|
||||
//! practice the model slips up often enough that handling it
|
||||
//! server-side is the difference between "works" and "the agent is
|
||||
//! brittle".
|
||||
//!
|
||||
//! Scope is deliberately small:
|
||||
//!
|
||||
//! - `~` and `~/` (current user only — `~user` lookups would require
|
||||
//! pulling in passwd parsing).
|
||||
//! - `$HOME` and `$HOME/`.
|
||||
//!
|
||||
//! Any other shell variable (`$PWD`, `${HOME}`, …) passes through
|
||||
//! unchanged. The shell already expands them inside `bash` tool
|
||||
//! commands; for the file-tool argument fields, we deliberately
|
||||
//! limit the set so the behaviour is predictable.
|
||||
//!
|
||||
//! Falls back to the input path verbatim when `HOME` is unset
|
||||
//! (stripped-down container env). That preserves the "no surprise
|
||||
//! mutations" rule — never invent a path the caller didn't ask for.
|
||||
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
/// Process-global lock for tests that mutate `HOME`. Anyone in the
|
||||
/// crate touching `HOME` must hold this for the duration of the
|
||||
/// read-modify-restore window — otherwise concurrent `cargo test`
|
||||
/// workers race and flake.
|
||||
///
|
||||
/// Only built into the test binaries. Production code never mutates
|
||||
/// env vars.
|
||||
#[cfg(test)]
|
||||
pub(crate) static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
|
||||
|
||||
/// Expand `~`, `~/`, `$HOME`, and `$HOME/` prefixes against the
|
||||
/// current user's home directory. All other inputs pass through
|
||||
/// unchanged.
|
||||
///
|
||||
/// Returns the input verbatim if `HOME` isn't set in the env.
|
||||
pub fn expand_path(input: &Path) -> PathBuf {
|
||||
let Some(s) = input.to_str() else {
|
||||
return input.to_path_buf();
|
||||
};
|
||||
let Ok(home) = std::env::var("HOME") else {
|
||||
return input.to_path_buf();
|
||||
};
|
||||
let home = PathBuf::from(home);
|
||||
if s == "~" || s == "$HOME" {
|
||||
return home;
|
||||
}
|
||||
if let Some(rest) = s.strip_prefix("~/") {
|
||||
return home.join(rest);
|
||||
}
|
||||
if let Some(rest) = s.strip_prefix("$HOME/") {
|
||||
return home.join(rest);
|
||||
}
|
||||
input.to_path_buf()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
/// Set HOME for the duration of the test. Tests using this run
|
||||
/// serially under the crate-wide [`ENV_LOCK`] because env
|
||||
/// mutation isn't thread-safe — `cargo test` parallel workers
|
||||
/// would race without it.
|
||||
fn with_home<F: FnOnce()>(home: &str, body: F) {
|
||||
let _g = ENV_LOCK.lock().unwrap();
|
||||
let prior = std::env::var("HOME").ok();
|
||||
// SAFETY: tests touch process-global env. The mutex
|
||||
// serialises access; sub-threads in other test modules
|
||||
// touching HOME aren't expected (none in this crate).
|
||||
unsafe {
|
||||
std::env::set_var("HOME", home);
|
||||
}
|
||||
body();
|
||||
unsafe {
|
||||
match prior {
|
||||
Some(p) => std::env::set_var("HOME", p),
|
||||
None => std::env::remove_var("HOME"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn expands_tilde_slash() {
|
||||
with_home("/home/me", || {
|
||||
assert_eq!(
|
||||
expand_path(Path::new("~/git/repo/file.rs")),
|
||||
PathBuf::from("/home/me/git/repo/file.rs")
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn expands_bare_tilde() {
|
||||
with_home("/home/me", || {
|
||||
assert_eq!(expand_path(Path::new("~")), PathBuf::from("/home/me"));
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn expands_dollar_home_slash() {
|
||||
with_home("/home/me", || {
|
||||
assert_eq!(
|
||||
expand_path(Path::new("$HOME/notes.md")),
|
||||
PathBuf::from("/home/me/notes.md")
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn expands_bare_dollar_home() {
|
||||
with_home("/home/me", || {
|
||||
assert_eq!(expand_path(Path::new("$HOME")), PathBuf::from("/home/me"));
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn absolute_path_passes_through() {
|
||||
with_home("/home/me", || {
|
||||
assert_eq!(
|
||||
expand_path(Path::new("/etc/hostname")),
|
||||
PathBuf::from("/etc/hostname")
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn relative_path_passes_through() {
|
||||
with_home("/home/me", || {
|
||||
assert_eq!(
|
||||
expand_path(Path::new("src/main.rs")),
|
||||
PathBuf::from("src/main.rs")
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tilde_user_form_not_expanded() {
|
||||
// ~other is shell sugar for /home/other and would require
|
||||
// passwd parsing to resolve. Out of scope — pass it
|
||||
// through and let the open fail with a clear error.
|
||||
with_home("/home/me", || {
|
||||
assert_eq!(
|
||||
expand_path(Path::new("~other/x")),
|
||||
PathBuf::from("~other/x")
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_home_env_passes_through() {
|
||||
// Share the same crate-wide lock as `with_home` — otherwise
|
||||
// a parallel test setting HOME races this clear-and-assert
|
||||
// window.
|
||||
let _g = ENV_LOCK.lock().unwrap();
|
||||
let prior = std::env::var("HOME").ok();
|
||||
// SAFETY: serialised by LOCK above.
|
||||
unsafe {
|
||||
std::env::remove_var("HOME");
|
||||
}
|
||||
assert_eq!(
|
||||
expand_path(Path::new("~/git/repo")),
|
||||
PathBuf::from("~/git/repo")
|
||||
);
|
||||
unsafe {
|
||||
if let Some(p) = prior {
|
||||
std::env::set_var("HOME", p);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dollar_other_var_not_expanded() {
|
||||
with_home("/home/me", || {
|
||||
assert_eq!(
|
||||
expand_path(Path::new("$PWD/file")),
|
||||
PathBuf::from("$PWD/file")
|
||||
);
|
||||
assert_eq!(
|
||||
expand_path(Path::new("${HOME}/file")),
|
||||
PathBuf::from("${HOME}/file")
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
274
crates/helexa-acp/src/prompt.rs
Normal file
274
crates/helexa-acp/src/prompt.rs
Normal file
@@ -0,0 +1,274 @@
|
||||
//! System prompt assembly.
|
||||
//!
|
||||
//! The system message has two parts:
|
||||
//!
|
||||
//! 1. A short human-readable preamble (working directory, style
|
||||
//! instructions). Either the built-in [`DEFAULT_PROMPT`] or a
|
||||
//! user-supplied file at `HELEXA_ACP_SYSTEM_PROMPT_PATH` /
|
||||
//! `system_prompt_path`. `{cwd}` is substituted in both.
|
||||
//! 2. A `# Tools` block in Qwen3 Hermes format (see [`crate::qwen3`])
|
||||
//! describing the available functions. This is what makes the
|
||||
//! model actually call them — neuron/cortex don't honour the
|
||||
//! OpenAI `tools` API field, so the tool list has to live in the
|
||||
//! prompt itself.
|
||||
|
||||
use agent_client_protocol::schema::SessionModeId;
|
||||
use anyhow::Context;
|
||||
use std::path::Path;
|
||||
|
||||
use crate::provider::ToolSpec;
|
||||
use crate::qwen3;
|
||||
use crate::session::MODE_PLAN;
|
||||
|
||||
const DEFAULT_PROMPT: &str = "\
|
||||
You are helexa-acp, a coding assistant working inside an editor.
|
||||
|
||||
Working directory: {cwd}
|
||||
|
||||
Use the tools described below whenever the user's request involves
|
||||
looking at or modifying files, or running commands. Do not ask the
|
||||
user to paste file contents you could read yourself. All file paths
|
||||
must be absolute. Writes and shell commands may prompt the user for
|
||||
permission depending on the session mode.
|
||||
|
||||
Be concise; the user is reading your output in an editor pane.";
|
||||
|
||||
/// Build the system prompt for a session.
|
||||
///
|
||||
/// - `cwd`: session working directory (substituted for `{cwd}` in
|
||||
/// the preamble — both the default and any user-supplied template).
|
||||
/// - `override_path`: path to a user-supplied template, already
|
||||
/// resolved by [`crate::config::Config`]. The `# Tools` block is
|
||||
/// appended *after* the user's template so a custom preamble
|
||||
/// still gets the tool descriptions the model needs.
|
||||
/// - `tools`: the tools to advertise. Empty list → no `# Tools`
|
||||
/// block is appended at all.
|
||||
/// - `mode`: current session mode. When the mode is [`MODE_PLAN`]
|
||||
/// a plan-mode addendum describing the restrictions and the
|
||||
/// completion menu is appended *after* the `# Tools` block so it
|
||||
/// is the last thing the model reads before user input.
|
||||
/// - `plan_dir`: resolved plan directory for the cwd. Only consulted
|
||||
/// when `mode == MODE_PLAN`. `None` means the plan directory could
|
||||
/// not be resolved (no `HOME` / `XDG_DATA_HOME`) — the addendum
|
||||
/// still renders but with a placeholder so the model knows to
|
||||
/// surface the error to the user rather than guess a path.
|
||||
pub fn build_system_prompt(
|
||||
cwd: &Path,
|
||||
override_path: Option<&Path>,
|
||||
tools: &[ToolSpec],
|
||||
mode: &SessionModeId,
|
||||
plan_dir: Option<&Path>,
|
||||
) -> anyhow::Result<String> {
|
||||
let template = match override_path {
|
||||
Some(path) => std::fs::read_to_string(path)
|
||||
.with_context(|| format!("read system prompt from {}", path.display()))?,
|
||||
None => DEFAULT_PROMPT.to_string(),
|
||||
};
|
||||
let mut prompt = template.replace("{cwd}", &cwd.display().to_string());
|
||||
prompt.push_str(&qwen3::render_tool_block(tools));
|
||||
if mode.0.as_ref() == MODE_PLAN {
|
||||
prompt.push_str(&render_plan_mode_block(plan_dir));
|
||||
}
|
||||
Ok(prompt)
|
||||
}
|
||||
|
||||
/// Plan-mode instruction block. Tells the model:
|
||||
///
|
||||
/// 1. Where it may write — only inside `plan_dir`.
|
||||
/// 2. What it may *not* do — bash is disabled; writes outside
|
||||
/// `plan_dir` are refused by the runtime.
|
||||
/// 3. How to finish — emit the 3-option menu so the user can
|
||||
/// switch modes and either kick off implementation (with or
|
||||
/// without permission prompts) or keep iterating on the plan.
|
||||
fn render_plan_mode_block(plan_dir: Option<&Path>) -> String {
|
||||
let plan_path = plan_dir
|
||||
.map(|p| p.display().to_string())
|
||||
.unwrap_or_else(|| "<plan directory could not be resolved — tell the user>".to_string());
|
||||
format!(
|
||||
"\n\n# Plan mode\n\
|
||||
\n\
|
||||
You are in **plan mode**. Your task is to draft a written\n\
|
||||
implementation plan for the user; you must NOT modify any\n\
|
||||
project files or run shell commands.\n\
|
||||
\n\
|
||||
Rules in plan mode:\n\
|
||||
\n\
|
||||
- `read_file` and `list_dir` are unrestricted — use them to\n\
|
||||
explore the codebase as needed.\n\
|
||||
- `write_file` and `edit_file` are allowed ONLY under the\n\
|
||||
plan directory: `{plan_path}`. The runtime will refuse any\n\
|
||||
write outside it.\n\
|
||||
- `bash` is disabled. Do not call it.\n\
|
||||
\n\
|
||||
Write the plan as one or more Markdown files under\n\
|
||||
`{plan_path}`. Use descriptive filenames\n\
|
||||
(`01-overview.md`, `02-data-model.md`, etc.). It is fine to\n\
|
||||
iterate — overwrite the file when you refine a section.\n\
|
||||
\n\
|
||||
When the plan is complete, do NOT begin implementation.\n\
|
||||
Instead, end your turn with this menu, verbatim, so the\n\
|
||||
user can choose how to proceed:\n\
|
||||
\n\
|
||||
---\n\
|
||||
**Plan complete.** To proceed, switch the session mode in\n\
|
||||
the agent dropdown and send a follow-up message:\n\
|
||||
\n\
|
||||
1. **Bypass Permissions** — implement the plan now, skipping\n\
|
||||
per-tool permission prompts.\n\
|
||||
2. **Default** — implement the plan now, prompting before\n\
|
||||
each write or shell command.\n\
|
||||
3. **Plan** (stay here) — refine the plan; reply with the\n\
|
||||
change you want and I will revise it.\n\
|
||||
---\n"
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::session::{MODE_DEFAULT, MODE_PLAN};
|
||||
use std::io::Write;
|
||||
|
||||
fn default_mode() -> SessionModeId {
|
||||
SessionModeId::new(MODE_DEFAULT)
|
||||
}
|
||||
fn plan_mode() -> SessionModeId {
|
||||
SessionModeId::new(MODE_PLAN)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_prompt_substitutes_cwd() {
|
||||
let prompt =
|
||||
build_system_prompt(Path::new("/home/me/proj"), None, &[], &default_mode(), None)
|
||||
.unwrap();
|
||||
assert!(
|
||||
prompt.contains("/home/me/proj"),
|
||||
"cwd not interpolated: {prompt}"
|
||||
);
|
||||
assert!(prompt.contains("helexa-acp"));
|
||||
assert!(
|
||||
!prompt.contains("{cwd}"),
|
||||
"left-over placeholder in default prompt"
|
||||
);
|
||||
// With no tools, the # Tools block is absent.
|
||||
assert!(!prompt.contains("# Tools"));
|
||||
// Default mode does not get the plan-mode addendum.
|
||||
assert!(!prompt.contains("# Plan mode"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tools_are_appended_in_hermes_format() {
|
||||
let spec = ToolSpec {
|
||||
name: "read_file".into(),
|
||||
description: "Read a file.".into(),
|
||||
parameters: serde_json::json!({"type":"object","properties":{}, "required":[]}),
|
||||
};
|
||||
let prompt =
|
||||
build_system_prompt(Path::new("/x"), None, &[spec], &default_mode(), None).unwrap();
|
||||
assert!(prompt.contains("# Tools"));
|
||||
assert!(prompt.contains("<tools>"));
|
||||
assert!(prompt.contains("\"name\":\"read_file\""));
|
||||
assert!(prompt.contains("<tool_call>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn override_path_is_read_and_templated() {
|
||||
let mut tmp = tempfile_in_target("prompt.txt");
|
||||
tmp.write_all(b"custom prompt for {cwd} only").unwrap();
|
||||
tmp.flush().unwrap();
|
||||
|
||||
let path = tmp.path().to_path_buf();
|
||||
drop(tmp);
|
||||
|
||||
let prompt = build_system_prompt(
|
||||
Path::new("/etc"),
|
||||
Some(path.as_path()),
|
||||
&[],
|
||||
&default_mode(),
|
||||
None,
|
||||
)
|
||||
.expect("read override");
|
||||
assert_eq!(prompt, "custom prompt for /etc only");
|
||||
|
||||
let _ = std::fs::remove_file(&path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn missing_override_path_errors() {
|
||||
let err = build_system_prompt(
|
||||
Path::new("/tmp"),
|
||||
Some(Path::new("/definitely/not/a/real/path")),
|
||||
&[],
|
||||
&default_mode(),
|
||||
None,
|
||||
)
|
||||
.unwrap_err();
|
||||
assert!(format!("{err:#}").contains("read system prompt"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn plan_mode_addendum_includes_plan_dir_and_menu() {
|
||||
let plan_dir = Path::new("/home/me/.local/share/helexa-acp/plans/proj-deadbeef");
|
||||
let prompt = build_system_prompt(
|
||||
Path::new("/home/me/proj"),
|
||||
None,
|
||||
&[],
|
||||
&plan_mode(),
|
||||
Some(plan_dir),
|
||||
)
|
||||
.unwrap();
|
||||
assert!(prompt.contains("# Plan mode"));
|
||||
assert!(
|
||||
prompt.contains(plan_dir.to_str().unwrap()),
|
||||
"plan dir not interpolated: {prompt}"
|
||||
);
|
||||
// The 3-option menu must be present so the model emits it verbatim.
|
||||
assert!(prompt.contains("Bypass Permissions"));
|
||||
assert!(prompt.contains("**Default**"));
|
||||
assert!(prompt.contains("3. **Plan**"));
|
||||
// Bash disabled instruction must be present.
|
||||
assert!(prompt.contains("`bash` is disabled"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn plan_mode_addendum_handles_unresolved_plan_dir() {
|
||||
let prompt =
|
||||
build_system_prompt(Path::new("/home/me/proj"), None, &[], &plan_mode(), None).unwrap();
|
||||
assert!(prompt.contains("# Plan mode"));
|
||||
assert!(prompt.contains("could not be resolved"));
|
||||
}
|
||||
|
||||
/// Tiny temp-file helper that doesn't pull in the `tempfile` crate.
|
||||
/// Writes under `target/` so it's cleaned up by `cargo clean`.
|
||||
fn tempfile_in_target(name: &str) -> TempHandle {
|
||||
let base = std::env::var("CARGO_TARGET_TMPDIR")
|
||||
.ok()
|
||||
.map(std::path::PathBuf::from)
|
||||
.unwrap_or_else(std::env::temp_dir);
|
||||
let _ = std::fs::create_dir_all(&base);
|
||||
let pid = std::process::id();
|
||||
let path = base.join(format!("helexa-acp-{pid}-{name}"));
|
||||
let file = std::fs::File::create(&path).expect("create temp file");
|
||||
TempHandle { file, path }
|
||||
}
|
||||
|
||||
struct TempHandle {
|
||||
file: std::fs::File,
|
||||
path: std::path::PathBuf,
|
||||
}
|
||||
|
||||
impl TempHandle {
|
||||
fn path(&self) -> &Path {
|
||||
&self.path
|
||||
}
|
||||
}
|
||||
|
||||
impl Write for TempHandle {
|
||||
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
|
||||
self.file.write(buf)
|
||||
}
|
||||
fn flush(&mut self) -> std::io::Result<()> {
|
||||
self.file.flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
1200
crates/helexa-acp/src/provider/anthropic_messages.rs
Normal file
1200
crates/helexa-acp/src/provider/anthropic_messages.rs
Normal file
File diff suppressed because it is too large
Load Diff
230
crates/helexa-acp/src/provider/mod.rs
Normal file
230
crates/helexa-acp/src/provider/mod.rs
Normal file
@@ -0,0 +1,230 @@
|
||||
//! Provider trait — the seam between the ACP-side agent loop and
|
||||
//! whatever wire protocol an endpoint actually speaks.
|
||||
//!
|
||||
//! Every concrete provider (OpenAI chat completions, OpenAI Responses,
|
||||
//! Anthropic /v1/messages, Ollama native, …) implements
|
||||
//! [`Provider`]. The agent constructs a [`CompletionRequest`] using
|
||||
//! provider-agnostic types and consumes a stream of
|
||||
//! [`CompletionEvent`]s — neither end knows which wire format is on
|
||||
//! the other side of the trait.
|
||||
//!
|
||||
//! Day-1 provider: [`openai_chat::OpenAIChatProvider`]. Day-N
|
||||
//! providers slot in without touching `agent.rs`.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::stream::BoxStream;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
pub mod anthropic_messages;
|
||||
pub mod openai_chat;
|
||||
pub mod openai_responses;
|
||||
|
||||
/// Provider-agnostic LLM endpoint. Implementations translate between
|
||||
/// [`CompletionRequest`] / [`CompletionEvent`] and whatever wire
|
||||
/// format their endpoint speaks.
|
||||
#[async_trait]
|
||||
pub trait Provider: Send + Sync {
|
||||
/// Endpoint name as configured by the user (e.g. `"helexa"`,
|
||||
/// `"openrouter"`). Used in logs and in the `endpoint:model`
|
||||
/// selector.
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// List models available at this endpoint. Used to build the
|
||||
/// model-picker dropdown in editor clients (Stage 4). Should
|
||||
/// return quickly (cache if necessary).
|
||||
#[allow(dead_code)]
|
||||
async fn list_models(&self) -> anyhow::Result<Vec<ModelInfo>>;
|
||||
|
||||
/// Run a chat completion. Returns a stream of provider-agnostic
|
||||
/// events. The stream stops when the upstream finishes, when
|
||||
/// `cancel` is fired, or when the stream is dropped.
|
||||
async fn complete(
|
||||
&self,
|
||||
request: CompletionRequest,
|
||||
cancel: CancellationToken,
|
||||
) -> anyhow::Result<BoxStream<'static, anyhow::Result<CompletionEvent>>>;
|
||||
}
|
||||
|
||||
/// One model exposed by a provider. Constructed by `list_models` —
|
||||
/// Stage 4 is when the agent loop starts consuming it for the
|
||||
/// model-picker dropdown.
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelInfo {
|
||||
pub id: String,
|
||||
/// Human-friendly name, if the endpoint exposes one. Otherwise
|
||||
/// `id` is used as the display name.
|
||||
#[serde(default)]
|
||||
pub display_name: Option<String>,
|
||||
}
|
||||
|
||||
/// Inputs to a completion. Provider-agnostic — concrete providers
|
||||
/// translate this into their wire format.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CompletionRequest {
|
||||
/// Endpoint-local model id (without the `endpoint:` prefix).
|
||||
pub model: String,
|
||||
pub messages: Vec<Message>,
|
||||
/// Tools the model is allowed to call. Empty list means no tool
|
||||
/// support advertised.
|
||||
pub tools: Vec<ToolSpec>,
|
||||
pub temperature: Option<f64>,
|
||||
pub top_p: Option<f64>,
|
||||
pub max_tokens: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
pub role: Role,
|
||||
pub content: MessageContent,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum Role {
|
||||
System,
|
||||
User,
|
||||
Assistant,
|
||||
/// Tool result message. Provider impls turn this into whatever
|
||||
/// shape the upstream wire format wants (OpenAI uses
|
||||
/// `role: "tool"` + `tool_call_id`; Anthropic uses content blocks).
|
||||
/// Stage 3 (tools) constructs this; Stage 2 never does.
|
||||
Tool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum MessageContent {
|
||||
/// Plain text turn (system / user / assistant). Struct variant
|
||||
/// rather than newtype so the persisted JSON has an explicit
|
||||
/// `text` field — that lets us use internal tagging on the
|
||||
/// enum, which is incompatible with newtype-of-primitive
|
||||
/// variants.
|
||||
Text { text: String },
|
||||
/// Mixed text + image user turn. Stage 5 introduces this when
|
||||
/// Zed sends an `ImageContent` block alongside the user's prompt.
|
||||
/// Providers that don't support vision should down-convert by
|
||||
/// dropping image parts and concatenating text parts.
|
||||
MultiPart { parts: Vec<MessagePart> },
|
||||
/// Assistant turn that called one or more tools. Stage 3 starts
|
||||
/// constructing this when the provider stream yields a
|
||||
/// `ToolCallStart` / `ToolCallArgsDelta` sequence.
|
||||
ToolCalls {
|
||||
/// Optional text the assistant said alongside the tool calls.
|
||||
text: Option<String>,
|
||||
calls: Vec<ToolCall>,
|
||||
},
|
||||
/// Tool result. `tool_call_id` matches the assistant's call id.
|
||||
/// Stage 3 constructs this after the tool runner finishes.
|
||||
ToolResult {
|
||||
tool_call_id: String,
|
||||
content: String,
|
||||
},
|
||||
}
|
||||
|
||||
/// One part of a [`MessageContent::MultiPart`] message.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum MessagePart {
|
||||
Text { text: String },
|
||||
Image(ImageData),
|
||||
}
|
||||
|
||||
/// Inline image attachment. `data` is base64-encoded raw image
|
||||
/// bytes; the encoder constructs an `image_url` data URI from it
|
||||
/// at request time. `uri` carries any pointer the client supplied
|
||||
/// (e.g. `file:///tmp/x.png`) — we keep it on the message for
|
||||
/// debugging / future providers but the OpenAI encoder ignores it
|
||||
/// when `data` is present (data wins, since it round-trips through
|
||||
/// every wire format).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ImageData {
|
||||
pub mime_type: String,
|
||||
/// Base64-encoded image bytes (no `data:` prefix, no padding
|
||||
/// stripped — exactly what `ImageContent.data` carried).
|
||||
pub data: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub uri: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCall {
|
||||
/// Provider-assigned id that ties the call to its result. The
|
||||
/// Qwen3 wire format we use today doesn't carry this on the
|
||||
/// model side (calls and results are matched positionally inside
|
||||
/// a turn), so the field looks unused in the prod build — but it
|
||||
/// flows through to `MessageContent::ToolResult.tool_call_id` for
|
||||
/// history bookkeeping and a future strict-OpenAI backend will
|
||||
/// consume it directly.
|
||||
#[allow(dead_code)]
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
/// JSON-encoded arguments. Kept as a string because providers
|
||||
/// stream argument bytes incrementally and only validate at the
|
||||
/// end; the agent decodes once the call is complete.
|
||||
pub arguments: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ToolSpec {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
/// JSON Schema of the arguments object.
|
||||
pub parameters: Value,
|
||||
}
|
||||
|
||||
/// Events emitted by a provider during a streaming completion.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum CompletionEvent {
|
||||
/// Incremental visible text from the assistant.
|
||||
TextDelta(String),
|
||||
/// Incremental "reasoning" / thought text, if the model emits one
|
||||
/// (e.g. Qwen3 with `<think>` tags surfaced as a separate stream,
|
||||
/// or OpenAI reasoning models).
|
||||
ReasoningDelta(String),
|
||||
/// A new tool call has started. Stage 2 ignores the payload; the
|
||||
/// agent loop in Stage 3 reads `index` to correlate with
|
||||
/// [`Self::ToolCallArgsDelta`], `id` for the eventual tool-result
|
||||
/// turn, and `name` to dispatch the runner.
|
||||
#[allow(dead_code)]
|
||||
ToolCallStart {
|
||||
index: usize,
|
||||
id: String,
|
||||
name: String,
|
||||
},
|
||||
/// More argument bytes for a tool call already announced via
|
||||
/// [`Self::ToolCallStart`]. Stage 2 ignores; Stage 3 accumulates
|
||||
/// the bytes by `index` until the call's arguments are complete.
|
||||
#[allow(dead_code)]
|
||||
ToolCallArgsDelta { index: usize, args_delta: String },
|
||||
/// A `<tool_call>` block whose JSON couldn't be parsed even with
|
||||
/// the qwen3 module's repair attempts. The agent surfaces this
|
||||
/// as a Failed `SessionUpdate::ToolCall` card with the raw body
|
||||
/// visible (so the editor renders structured failure UI rather
|
||||
/// than dumping the body inline in the message pane), and feeds
|
||||
/// a synthetic tool-error message back into history so the
|
||||
/// model can self-correct on the next round.
|
||||
MalformedToolCall { raw: String },
|
||||
/// Stream finished. Carries the upstream `finish_reason` if it
|
||||
/// gave one (`"stop"`, `"length"`, `"tool_calls"`, …).
|
||||
Finish { reason: Option<String> },
|
||||
/// Final usage stats, if the provider supplied them. Stage 2
|
||||
/// matches the variant to drop it; Stage 6b (token metrics) is
|
||||
/// when the payload starts being read.
|
||||
#[allow(dead_code)]
|
||||
Usage(UsageStats),
|
||||
}
|
||||
|
||||
/// Token accounting reported by the provider at the end of a stream.
|
||||
/// Stage 2 doesn't surface usage anywhere — the stable `PromptResponse`
|
||||
/// has no usage field, and the unstable variant is gated. Stage 6b
|
||||
/// turns these on with Prometheus metrics.
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
pub struct UsageStats {
|
||||
pub prompt_tokens: u64,
|
||||
pub completion_tokens: u64,
|
||||
pub total_tokens: u64,
|
||||
}
|
||||
1002
crates/helexa-acp/src/provider/openai_chat.rs
Normal file
1002
crates/helexa-acp/src/provider/openai_chat.rs
Normal file
File diff suppressed because it is too large
Load Diff
987
crates/helexa-acp/src/provider/openai_responses.rs
Normal file
987
crates/helexa-acp/src/provider/openai_responses.rs
Normal file
@@ -0,0 +1,987 @@
|
||||
//! OpenAI Responses API (`POST /v1/responses`) provider.
|
||||
//!
|
||||
//! Mirror image of [`super::openai_chat`]: same `Provider` trait
|
||||
//! impl, same back-pressured SSE decoder, but speaking OpenAI's
|
||||
//! newer Responses surface instead of chat completions.
|
||||
//!
|
||||
//! Differences from the chat provider, all contained in this file:
|
||||
//!
|
||||
//! - **Request encoding**: history flattens into an `input` array
|
||||
//! of typed items (`message`, `function_call`, `function_call_output`)
|
||||
//! plus a top-level `instructions` field for the system prompt.
|
||||
//! Multi-part user content stays in the same `[{type:"input_text"},
|
||||
//! {type:"input_image"}]` shape neuron's `request_to_chat` already
|
||||
//! accepts.
|
||||
//! - **Streaming decoder**: events are named (`response.created`,
|
||||
//! `response.output_text.delta`, `response.completed`, …) carried
|
||||
//! on the SSE `event:` line. The chat path's `[DONE]` terminator
|
||||
//! doesn't apply; the stream ends after `response.completed`.
|
||||
//! - **Tool calls** plumb through the `response.output_item.added`
|
||||
//! (item type `function_call`) → `response.function_call_arguments.delta`
|
||||
//! → `response.function_call_arguments.done` event sequence. The
|
||||
//! neuron candle harness doesn't synthesize these yet (tracked as
|
||||
//! issue #6), but the decoder is wired so the day the upstream
|
||||
//! does, downstream `CompletionEvent::ToolCall*` plumbing just
|
||||
//! works.
|
||||
//!
|
||||
//! Tool-name handling: the model knows its tool descriptions via
|
||||
//! the [`crate::qwen3`] system-prompt block exactly the way the chat
|
||||
//! provider does. We don't echo them in the request body because
|
||||
//! neuron currently ignores `tools` on /v1/responses (same as on
|
||||
//! /v1/chat/completions). Once neuron honours request-side tool
|
||||
//! definitions, both providers add them in the same place.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use eventsource_stream::Eventsource;
|
||||
use futures::{Stream, StreamExt, stream::BoxStream};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{Value, json};
|
||||
use std::collections::HashMap;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use super::{
|
||||
CompletionEvent, CompletionRequest, Message, MessageContent, MessagePart, ModelInfo, Provider,
|
||||
Role, UsageStats,
|
||||
};
|
||||
use crate::config::EndpointConfig;
|
||||
|
||||
pub struct OpenAIResponsesProvider {
|
||||
endpoint: EndpointConfig,
|
||||
#[allow(dead_code)] // Read in `complete()`'s HTTP path; tests don't stand up a server.
|
||||
api_key: Option<String>,
|
||||
#[allow(dead_code)]
|
||||
http: reqwest::Client,
|
||||
}
|
||||
|
||||
impl OpenAIResponsesProvider {
|
||||
pub fn new(endpoint: EndpointConfig) -> anyhow::Result<Self> {
|
||||
let api_key = endpoint.resolve_api_key()?;
|
||||
let http = reqwest::Client::builder()
|
||||
// Same generous timeout as the chat provider: cortex may
|
||||
// need to cold-load a model before serving the first
|
||||
// chunk, which can be tens of seconds. Cancellation
|
||||
// handles early termination, not timeout.
|
||||
.timeout(std::time::Duration::from_secs(600))
|
||||
.build()?;
|
||||
Ok(Self {
|
||||
endpoint,
|
||||
api_key,
|
||||
http,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for OpenAIResponsesProvider {
|
||||
fn name(&self) -> &str {
|
||||
&self.endpoint.name
|
||||
}
|
||||
|
||||
async fn list_models(&self) -> anyhow::Result<Vec<ModelInfo>> {
|
||||
let mut req = self.http.get(self.endpoint.models_url());
|
||||
if let Some(key) = &self.api_key {
|
||||
req = req.bearer_auth(key);
|
||||
}
|
||||
let resp = req
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("{} list_models: {e}", self.endpoint.name))?;
|
||||
let status = resp.status();
|
||||
if !status.is_success() {
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!(
|
||||
"{} list_models returned {}: {}",
|
||||
self.endpoint.name,
|
||||
status,
|
||||
body
|
||||
);
|
||||
}
|
||||
let body: WireModelsResponse = resp.json().await?;
|
||||
Ok(body
|
||||
.data
|
||||
.into_iter()
|
||||
.map(|m| ModelInfo {
|
||||
id: m.id,
|
||||
display_name: None,
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
async fn complete(
|
||||
&self,
|
||||
request: CompletionRequest,
|
||||
cancel: CancellationToken,
|
||||
) -> anyhow::Result<BoxStream<'static, anyhow::Result<CompletionEvent>>> {
|
||||
let body = encode_request(&request);
|
||||
tracing::debug!(
|
||||
endpoint = %self.endpoint.name,
|
||||
url = %self.endpoint.responses_url(),
|
||||
body = %serde_json::to_string(&body).unwrap_or_else(|_| "<unserializable>".into()),
|
||||
"POST /responses"
|
||||
);
|
||||
let mut req = self.http.post(self.endpoint.responses_url()).json(&body);
|
||||
if let Some(key) = &self.api_key {
|
||||
req = req.bearer_auth(key);
|
||||
}
|
||||
let resp = req
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("{} responses send: {e}", self.endpoint.name))?;
|
||||
let status = resp.status();
|
||||
if !status.is_success() {
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!(
|
||||
"{} responses returned {}: {}",
|
||||
self.endpoint.name,
|
||||
status,
|
||||
body
|
||||
);
|
||||
}
|
||||
let sse = resp.bytes_stream().eventsource();
|
||||
let stream = decode_stream(sse, cancel);
|
||||
Ok(Box::pin(stream))
|
||||
}
|
||||
}
|
||||
|
||||
// ── Request encoding ─────────────────────────────────────────────────
|
||||
|
||||
fn encode_request(req: &CompletionRequest) -> Value {
|
||||
// Pull the system messages out of history into a single
|
||||
// `instructions` string — the Responses API expects them there,
|
||||
// not inline as an `input` item. Multiple system messages
|
||||
// concatenate with blank lines so we don't lose ordering.
|
||||
let mut instructions: Vec<String> = Vec::new();
|
||||
let mut input_items: Vec<Value> = Vec::new();
|
||||
for msg in &req.messages {
|
||||
if msg.role == Role::System
|
||||
&& let MessageContent::Text { text } = &msg.content
|
||||
{
|
||||
instructions.push(text.clone());
|
||||
continue;
|
||||
}
|
||||
if let Some(item) = encode_message_as_input_item(msg) {
|
||||
input_items.push(item);
|
||||
}
|
||||
}
|
||||
|
||||
let mut body = json!({
|
||||
"model": req.model,
|
||||
"input": input_items,
|
||||
"stream": true,
|
||||
});
|
||||
if let Value::Object(map) = &mut body {
|
||||
if !instructions.is_empty() {
|
||||
map.insert(
|
||||
"instructions".into(),
|
||||
Value::String(instructions.join("\n\n")),
|
||||
);
|
||||
}
|
||||
if let Some(t) = req.temperature {
|
||||
map.insert("temperature".into(), json!(t));
|
||||
}
|
||||
if let Some(p) = req.top_p {
|
||||
map.insert("top_p".into(), json!(p));
|
||||
}
|
||||
if let Some(m) = req.max_tokens {
|
||||
// Responses calls it `max_output_tokens`; preserve the
|
||||
// semantic (response cap) when we translate.
|
||||
map.insert("max_output_tokens".into(), json!(m));
|
||||
}
|
||||
}
|
||||
body
|
||||
}
|
||||
|
||||
fn encode_message_as_input_item(msg: &Message) -> Option<Value> {
|
||||
match (msg.role, &msg.content) {
|
||||
(Role::System, _) => None, // handled out-of-band as `instructions`
|
||||
(Role::User, MessageContent::Text { text }) => Some(json!({
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": text,
|
||||
})),
|
||||
(Role::User, MessageContent::MultiPart { parts }) => Some(json!({
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": encode_user_parts(parts),
|
||||
})),
|
||||
(Role::Assistant, MessageContent::Text { text }) => Some(json!({
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{
|
||||
"type": "output_text",
|
||||
"text": text,
|
||||
"annotations": [],
|
||||
}],
|
||||
})),
|
||||
(Role::Assistant, MessageContent::ToolCalls { text, calls }) => {
|
||||
// Assistant turns that called tools become a sequence of
|
||||
// items: an optional `message` (any prose alongside the
|
||||
// call) followed by one `function_call` per call. Mirrors
|
||||
// OpenAI Responses' "each item is one structural slot"
|
||||
// shape.
|
||||
//
|
||||
// We can't return multiple items from one call site, so
|
||||
// we encode this by side-stuffing additional items into a
|
||||
// single composite value and have the caller flatten —
|
||||
// but that complicates the API. Easier: build the array
|
||||
// ourselves in the caller path. For now, emit just the
|
||||
// function_calls (the assistant's prose lives in the next
|
||||
// turn's chat history anyway because the model isn't
|
||||
// looking back at its own previous narration). If the
|
||||
// text is non-empty AND we have calls, we lose the text;
|
||||
// qwen3 rarely emits prose alongside tool calls so this
|
||||
// is a deliberate simplification — revisit if it bites.
|
||||
let _ = text;
|
||||
// Take the first call only for the moment; multi-call
|
||||
// turns would need the caller-flattening above.
|
||||
let call = calls.first()?;
|
||||
Some(json!({
|
||||
"type": "function_call",
|
||||
"call_id": call.id,
|
||||
"name": call.name,
|
||||
"arguments": call.arguments,
|
||||
}))
|
||||
}
|
||||
(
|
||||
Role::Tool,
|
||||
MessageContent::ToolResult {
|
||||
tool_call_id,
|
||||
content,
|
||||
},
|
||||
) => Some(json!({
|
||||
"type": "function_call_output",
|
||||
"call_id": tool_call_id,
|
||||
"output": content,
|
||||
})),
|
||||
(role, content) => {
|
||||
tracing::warn!(
|
||||
?role,
|
||||
?content,
|
||||
"openai_responses: unexpected (role, content) shape"
|
||||
);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn encode_user_parts(parts: &[MessagePart]) -> Value {
|
||||
let items: Vec<Value> = parts
|
||||
.iter()
|
||||
.map(|p| match p {
|
||||
MessagePart::Text { text } => json!({"type": "input_text", "text": text}),
|
||||
MessagePart::Image(img) => json!({
|
||||
"type": "input_image",
|
||||
"image_url": format!("data:{};base64,{}", img.mime_type, img.data),
|
||||
}),
|
||||
})
|
||||
.collect();
|
||||
Value::Array(items)
|
||||
}
|
||||
|
||||
// ── Wire types ──────────────────────────────────────────────────────
|
||||
|
||||
#[allow(dead_code)] // fields read only when list_models runs against a real endpoint
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct WireModelsResponse {
|
||||
data: Vec<WireModelObject>,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct WireModelObject {
|
||||
id: String,
|
||||
}
|
||||
|
||||
// SSE event payload shapes. We only model the fields we care about;
|
||||
// `#[serde(default)]` + `Option` everywhere else lets the upstream
|
||||
// add optional fields without breaking deserialise.
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct OutputItemAddedEvent {
|
||||
#[serde(default)]
|
||||
output_index: u32,
|
||||
item: OutputItem,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
enum OutputItem {
|
||||
Message {
|
||||
#[serde(default)]
|
||||
id: Option<String>,
|
||||
},
|
||||
FunctionCall {
|
||||
#[serde(default)]
|
||||
id: Option<String>,
|
||||
#[serde(default)]
|
||||
call_id: Option<String>,
|
||||
#[serde(default)]
|
||||
name: Option<String>,
|
||||
/// Some upstreams populate `arguments` already on the
|
||||
/// `output_item.added` event for a fully-buffered tool call
|
||||
/// (i.e. when the model finalised the call before the SSE
|
||||
/// flush). Capture it so we can emit a single args delta.
|
||||
#[serde(default)]
|
||||
arguments: Option<String>,
|
||||
},
|
||||
/// `reasoning`, `web_search_call`, etc. We capture-and-ignore
|
||||
/// any item we don't model; the decoder still emits the
|
||||
/// outer events correctly.
|
||||
#[serde(other)]
|
||||
Unknown,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct OutputTextDeltaEvent {
|
||||
#[serde(default)]
|
||||
item_id: Option<String>,
|
||||
#[serde(default)]
|
||||
output_index: u32,
|
||||
#[serde(default)]
|
||||
delta: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct FunctionCallArgumentsDeltaEvent {
|
||||
#[serde(default)]
|
||||
item_id: Option<String>,
|
||||
#[serde(default)]
|
||||
output_index: u32,
|
||||
#[serde(default)]
|
||||
delta: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct ResponseCompletedEvent {
|
||||
response: ResponseShell,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct ResponseShell {
|
||||
#[serde(default)]
|
||||
status: Option<String>,
|
||||
#[serde(default)]
|
||||
usage: Option<WireUsage>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct WireUsage {
|
||||
#[serde(default)]
|
||||
input_tokens: u64,
|
||||
#[serde(default)]
|
||||
output_tokens: u64,
|
||||
#[serde(default)]
|
||||
total_tokens: u64,
|
||||
}
|
||||
|
||||
// ── Streaming decoder ───────────────────────────────────────────────
|
||||
|
||||
/// Translate the named-event Responses SSE into the provider-agnostic
|
||||
/// [`CompletionEvent`] stream the agent loop expects. The decoder
|
||||
/// holds per-stream state — output_index → tool-call-index plus
|
||||
/// the next available tool-call slot — so it can fire
|
||||
/// `ToolCallStart` exactly once per item.
|
||||
fn decode_stream<S>(
|
||||
sse: S,
|
||||
cancel: CancellationToken,
|
||||
) -> impl Stream<Item = anyhow::Result<CompletionEvent>>
|
||||
where
|
||||
S: Stream<
|
||||
Item = Result<
|
||||
eventsource_stream::Event,
|
||||
eventsource_stream::EventStreamError<reqwest::Error>,
|
||||
>,
|
||||
> + Send
|
||||
+ 'static,
|
||||
{
|
||||
async_stream::stream! {
|
||||
let mut sse = Box::pin(sse);
|
||||
// Maps an output_index that's a function_call to the tool-call
|
||||
// slot we hand downstream. Lets us correlate later
|
||||
// `function_call_arguments.delta` events back to the index
|
||||
// we already announced on `output_item.added`.
|
||||
let mut tool_index_by_output: HashMap<u32, usize> = HashMap::new();
|
||||
let mut next_tool_index: usize = 0;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
biased;
|
||||
_ = cancel.cancelled() => {
|
||||
tracing::debug!("openai_responses: cancellation requested, ending stream");
|
||||
break;
|
||||
}
|
||||
next = sse.next() => {
|
||||
let Some(event) = next else { break };
|
||||
let event = match event {
|
||||
Ok(e) => e,
|
||||
Err(e) => {
|
||||
yield Err(anyhow::anyhow!("SSE transport: {e}"));
|
||||
break;
|
||||
}
|
||||
};
|
||||
// Event name lives on `event.event`; data is JSON.
|
||||
let event_name = event.event.as_str();
|
||||
let data = event.data.as_str();
|
||||
match event_name {
|
||||
"response.output_text.delta" => {
|
||||
match serde_json::from_str::<OutputTextDeltaEvent>(data) {
|
||||
Ok(d) if !d.delta.is_empty() => {
|
||||
yield Ok(CompletionEvent::TextDelta(d.delta));
|
||||
}
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
error = %e,
|
||||
raw = %data,
|
||||
"openai_responses: failed to parse output_text.delta; skipping"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
"response.output_item.added" => {
|
||||
match serde_json::from_str::<OutputItemAddedEvent>(data) {
|
||||
Ok(ev) => {
|
||||
if let OutputItem::FunctionCall {
|
||||
id,
|
||||
call_id,
|
||||
name,
|
||||
arguments,
|
||||
} = ev.item
|
||||
{
|
||||
let idx = next_tool_index;
|
||||
next_tool_index += 1;
|
||||
tool_index_by_output.insert(ev.output_index, idx);
|
||||
// Prefer the user-facing
|
||||
// `call_id` (what gets paired
|
||||
// with tool results) over the
|
||||
// internal item `id` when
|
||||
// both are present. Falls
|
||||
// back to a synthetic id so
|
||||
// history bookkeeping never
|
||||
// breaks.
|
||||
let final_id = call_id
|
||||
.or(id)
|
||||
.unwrap_or_else(|| format!("call_{idx}"));
|
||||
let final_name = name.unwrap_or_default();
|
||||
yield Ok(CompletionEvent::ToolCallStart {
|
||||
index: idx,
|
||||
id: final_id,
|
||||
name: final_name,
|
||||
});
|
||||
// Some upstreams attach the
|
||||
// fully-buffered arguments on
|
||||
// the `output_item.added`
|
||||
// event itself (rare; happens
|
||||
// when the model finalised
|
||||
// before the SSE flush).
|
||||
// Emit as a single args
|
||||
// delta if present.
|
||||
if let Some(args) = arguments
|
||||
&& !args.is_empty()
|
||||
{
|
||||
yield Ok(CompletionEvent::ToolCallArgsDelta {
|
||||
index: idx,
|
||||
args_delta: args,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
error = %e,
|
||||
raw = %data,
|
||||
"openai_responses: failed to parse output_item.added; skipping"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
"response.function_call_arguments.delta" => {
|
||||
match serde_json::from_str::<FunctionCallArgumentsDeltaEvent>(data) {
|
||||
Ok(ev) => {
|
||||
let Some(&idx) = tool_index_by_output.get(&ev.output_index)
|
||||
else {
|
||||
// Args delta for an item we
|
||||
// never saw an `output_item.added`
|
||||
// for. Could happen if the
|
||||
// upstream reordered events;
|
||||
// log + skip.
|
||||
tracing::warn!(
|
||||
output_index = ev.output_index,
|
||||
"openai_responses: function_call_arguments.delta for unknown output_index"
|
||||
);
|
||||
continue;
|
||||
};
|
||||
if !ev.delta.is_empty() {
|
||||
yield Ok(CompletionEvent::ToolCallArgsDelta {
|
||||
index: idx,
|
||||
args_delta: ev.delta,
|
||||
});
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
error = %e,
|
||||
raw = %data,
|
||||
"openai_responses: failed to parse function_call_arguments.delta; skipping"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
"response.completed" => {
|
||||
// Final event. Pull usage + status off
|
||||
// the response shell. Status maps:
|
||||
// "completed" → no special handling
|
||||
// (caller treats as EndTurn),
|
||||
// "incomplete" → length stop.
|
||||
let (reason, usage) =
|
||||
match serde_json::from_str::<ResponseCompletedEvent>(data) {
|
||||
Ok(ev) => {
|
||||
let reason = match ev.response.status.as_deref() {
|
||||
Some("incomplete") => Some("length".to_string()),
|
||||
_ => Some("stop".to_string()),
|
||||
};
|
||||
let usage = ev.response.usage.map(|u| UsageStats {
|
||||
prompt_tokens: u.input_tokens,
|
||||
completion_tokens: u.output_tokens,
|
||||
total_tokens: u.total_tokens,
|
||||
});
|
||||
(reason, usage)
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
error = %e,
|
||||
raw = %data,
|
||||
"openai_responses: failed to parse response.completed; ending stream with EndTurn"
|
||||
);
|
||||
(Some("stop".to_string()), None)
|
||||
}
|
||||
};
|
||||
if let Some(u) = usage {
|
||||
yield Ok(CompletionEvent::Usage(u));
|
||||
}
|
||||
yield Ok(CompletionEvent::Finish { reason });
|
||||
break;
|
||||
}
|
||||
// Bookkeeping events we don't need to surface:
|
||||
// response.created, response.in_progress,
|
||||
// response.content_part.added/.done,
|
||||
// response.output_text.done,
|
||||
// response.output_item.done,
|
||||
// response.function_call_arguments.done,
|
||||
// response.reasoning_*. Logged at debug for
|
||||
// wire-tracing.
|
||||
other => {
|
||||
tracing::trace!(
|
||||
event = other,
|
||||
"openai_responses: bookkeeping event"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::provider::ToolCall;
|
||||
use crate::provider::{ImageData, MessagePart};
|
||||
use futures::stream;
|
||||
use url::Url;
|
||||
|
||||
fn ep() -> EndpointConfig {
|
||||
EndpointConfig {
|
||||
name: "test".into(),
|
||||
base_url: Url::parse("http://localhost:9999/v1").unwrap(),
|
||||
wire_api: crate::config::WireApi::OpenAiResponses,
|
||||
default_model: None,
|
||||
api_key: None,
|
||||
api_key_env: None,
|
||||
max_tokens: None,
|
||||
context_window: None,
|
||||
}
|
||||
}
|
||||
|
||||
// ── encode_request ──────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn system_messages_collapse_to_instructions() {
|
||||
let req = CompletionRequest {
|
||||
model: "m".into(),
|
||||
messages: vec![
|
||||
Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text {
|
||||
text: "you are helpful".into(),
|
||||
},
|
||||
},
|
||||
Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text { text: "hi".into() },
|
||||
},
|
||||
],
|
||||
tools: vec![],
|
||||
temperature: Some(0.7),
|
||||
top_p: None,
|
||||
max_tokens: Some(256),
|
||||
};
|
||||
let body = encode_request(&req);
|
||||
assert_eq!(body["model"], "m");
|
||||
assert_eq!(body["instructions"], "you are helpful");
|
||||
assert_eq!(body["stream"], true);
|
||||
assert_eq!(body["max_output_tokens"], 256);
|
||||
assert_eq!(body["temperature"], 0.7);
|
||||
let input = body["input"].as_array().unwrap();
|
||||
// System message NOT echoed in input — it's only in
|
||||
// instructions.
|
||||
assert_eq!(input.len(), 1);
|
||||
assert_eq!(input[0]["type"], "message");
|
||||
assert_eq!(input[0]["role"], "user");
|
||||
assert_eq!(input[0]["content"], "hi");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn multiple_system_messages_concatenate() {
|
||||
let req = CompletionRequest {
|
||||
model: "m".into(),
|
||||
messages: vec![
|
||||
Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text {
|
||||
text: "first".into(),
|
||||
},
|
||||
},
|
||||
Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text {
|
||||
text: "second".into(),
|
||||
},
|
||||
},
|
||||
Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text { text: "hi".into() },
|
||||
},
|
||||
],
|
||||
tools: vec![],
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
max_tokens: None,
|
||||
};
|
||||
let body = encode_request(&req);
|
||||
assert_eq!(body["instructions"], "first\n\nsecond");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn user_multipart_becomes_input_parts_array() {
|
||||
let req = CompletionRequest {
|
||||
model: "vl".into(),
|
||||
messages: vec![Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::MultiPart {
|
||||
parts: vec![
|
||||
MessagePart::Text {
|
||||
text: "what's in this?".into(),
|
||||
},
|
||||
MessagePart::Image(ImageData {
|
||||
mime_type: "image/png".into(),
|
||||
data: "AAA=".into(),
|
||||
uri: None,
|
||||
}),
|
||||
],
|
||||
},
|
||||
}],
|
||||
tools: vec![],
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
max_tokens: None,
|
||||
};
|
||||
let body = encode_request(&req);
|
||||
let content = &body["input"][0]["content"].as_array().unwrap().clone();
|
||||
assert_eq!(content.len(), 2);
|
||||
assert_eq!(content[0]["type"], "input_text");
|
||||
assert_eq!(content[0]["text"], "what's in this?");
|
||||
assert_eq!(content[1]["type"], "input_image");
|
||||
assert_eq!(content[1]["image_url"], "data:image/png;base64,AAA=");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn assistant_text_becomes_output_text_content_part() {
|
||||
let req = CompletionRequest {
|
||||
model: "m".into(),
|
||||
messages: vec![
|
||||
Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text { text: "hi".into() },
|
||||
},
|
||||
Message {
|
||||
role: Role::Assistant,
|
||||
content: MessageContent::Text {
|
||||
text: "hello there".into(),
|
||||
},
|
||||
},
|
||||
Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text {
|
||||
text: "more".into(),
|
||||
},
|
||||
},
|
||||
],
|
||||
tools: vec![],
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
max_tokens: None,
|
||||
};
|
||||
let body = encode_request(&req);
|
||||
let input = body["input"].as_array().unwrap();
|
||||
assert_eq!(input.len(), 3);
|
||||
assert_eq!(input[1]["type"], "message");
|
||||
assert_eq!(input[1]["role"], "assistant");
|
||||
assert_eq!(input[1]["content"][0]["type"], "output_text");
|
||||
assert_eq!(input[1]["content"][0]["text"], "hello there");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_calls_and_results_round_trip_via_function_call_items() {
|
||||
let req = CompletionRequest {
|
||||
model: "m".into(),
|
||||
messages: vec![
|
||||
Message {
|
||||
role: Role::Assistant,
|
||||
content: MessageContent::ToolCalls {
|
||||
text: None,
|
||||
calls: vec![ToolCall {
|
||||
id: "call_42".into(),
|
||||
name: "read_file".into(),
|
||||
arguments: r#"{"path":"/etc/hostname"}"#.into(),
|
||||
}],
|
||||
},
|
||||
},
|
||||
Message {
|
||||
role: Role::Tool,
|
||||
content: MessageContent::ToolResult {
|
||||
tool_call_id: "call_42".into(),
|
||||
content: "host".into(),
|
||||
},
|
||||
},
|
||||
],
|
||||
tools: vec![],
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
max_tokens: None,
|
||||
};
|
||||
let body = encode_request(&req);
|
||||
let input = body["input"].as_array().unwrap();
|
||||
assert_eq!(input.len(), 2);
|
||||
assert_eq!(input[0]["type"], "function_call");
|
||||
assert_eq!(input[0]["call_id"], "call_42");
|
||||
assert_eq!(input[0]["name"], "read_file");
|
||||
assert_eq!(input[0]["arguments"], r#"{"path":"/etc/hostname"}"#);
|
||||
assert_eq!(input[1]["type"], "function_call_output");
|
||||
assert_eq!(input[1]["call_id"], "call_42");
|
||||
assert_eq!(input[1]["output"], "host");
|
||||
}
|
||||
|
||||
// ── decode_stream ───────────────────────────────────────────────
|
||||
|
||||
fn sse_event(name: &str, data: &str) -> eventsource_stream::Event {
|
||||
eventsource_stream::Event {
|
||||
id: String::new(),
|
||||
retry: None,
|
||||
event: name.into(),
|
||||
data: data.into(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn collect_events(
|
||||
items: Vec<eventsource_stream::Event>,
|
||||
) -> Vec<anyhow::Result<CompletionEvent>> {
|
||||
let sse = stream::iter(
|
||||
items
|
||||
.into_iter()
|
||||
.map(Ok::<_, eventsource_stream::EventStreamError<reqwest::Error>>),
|
||||
);
|
||||
let decoded = decode_stream(sse, CancellationToken::new());
|
||||
decoded.collect().await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn decodes_text_then_finish() {
|
||||
let events = collect_events(vec![
|
||||
sse_event("response.created", "{}"),
|
||||
sse_event(
|
||||
"response.output_text.delta",
|
||||
r#"{"item_id":"msg_1","output_index":0,"delta":"hel"}"#,
|
||||
),
|
||||
sse_event(
|
||||
"response.output_text.delta",
|
||||
r#"{"item_id":"msg_1","output_index":0,"delta":"lo"}"#,
|
||||
),
|
||||
sse_event(
|
||||
"response.completed",
|
||||
r#"{"response":{"status":"completed","usage":{"input_tokens":3,"output_tokens":2,"total_tokens":5}}}"#,
|
||||
),
|
||||
])
|
||||
.await;
|
||||
let events: Vec<CompletionEvent> = events.into_iter().map(|r| r.unwrap()).collect();
|
||||
let mut iter = events.into_iter();
|
||||
assert!(matches!(iter.next(), Some(CompletionEvent::TextDelta(t)) if t == "hel"));
|
||||
assert!(matches!(iter.next(), Some(CompletionEvent::TextDelta(t)) if t == "lo"));
|
||||
assert!(matches!(iter.next(), Some(CompletionEvent::Usage(u)) if u.total_tokens == 5));
|
||||
assert!(matches!(
|
||||
iter.next(),
|
||||
Some(CompletionEvent::Finish { reason: Some(r) }) if r == "stop"
|
||||
));
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn empty_delta_is_dropped() {
|
||||
let events = collect_events(vec![
|
||||
sse_event(
|
||||
"response.output_text.delta",
|
||||
r#"{"item_id":"m","output_index":0,"delta":""}"#,
|
||||
),
|
||||
sse_event(
|
||||
"response.completed",
|
||||
r#"{"response":{"status":"completed"}}"#,
|
||||
),
|
||||
])
|
||||
.await;
|
||||
let mut completion_events = events.into_iter().map(|r| r.unwrap());
|
||||
// First event MUST be the Finish — the empty delta dropped.
|
||||
assert!(matches!(
|
||||
completion_events.next(),
|
||||
Some(CompletionEvent::Finish { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn incomplete_status_maps_to_length_finish_reason() {
|
||||
let events = collect_events(vec![sse_event(
|
||||
"response.completed",
|
||||
r#"{"response":{"status":"incomplete"}}"#,
|
||||
)])
|
||||
.await;
|
||||
let events: Vec<CompletionEvent> = events.into_iter().map(|r| r.unwrap()).collect();
|
||||
assert!(matches!(
|
||||
events.last(),
|
||||
Some(CompletionEvent::Finish { reason: Some(r) }) if r == "length"
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn function_call_items_emit_toolcall_events() {
|
||||
let events = collect_events(vec![
|
||||
sse_event(
|
||||
"response.output_item.added",
|
||||
r#"{"output_index":0,"item":{"type":"function_call","id":"item_1","call_id":"call_xyz","name":"read_file"}}"#,
|
||||
),
|
||||
sse_event(
|
||||
"response.function_call_arguments.delta",
|
||||
r#"{"item_id":"item_1","output_index":0,"delta":"{\"path"}"#,
|
||||
),
|
||||
sse_event(
|
||||
"response.function_call_arguments.delta",
|
||||
r#"{"item_id":"item_1","output_index":0,"delta":"\":\"/etc/hostname\"}"}"#,
|
||||
),
|
||||
sse_event("response.completed", r#"{"response":{"status":"completed"}}"#),
|
||||
])
|
||||
.await;
|
||||
let events: Vec<CompletionEvent> = events.into_iter().map(|r| r.unwrap()).collect();
|
||||
let mut iter = events.into_iter();
|
||||
assert!(matches!(
|
||||
iter.next(),
|
||||
Some(CompletionEvent::ToolCallStart { index: 0, ref id, ref name })
|
||||
if id == "call_xyz" && name == "read_file"
|
||||
));
|
||||
assert!(matches!(
|
||||
iter.next(),
|
||||
Some(CompletionEvent::ToolCallArgsDelta { index: 0, ref args_delta })
|
||||
if args_delta == r#"{"path"#
|
||||
));
|
||||
assert!(matches!(
|
||||
iter.next(),
|
||||
Some(CompletionEvent::ToolCallArgsDelta { index: 0, ref args_delta })
|
||||
if args_delta == r#"":"/etc/hostname"}"#
|
||||
));
|
||||
assert!(matches!(iter.next(), Some(CompletionEvent::Finish { .. })));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn function_call_added_with_inline_arguments_emits_single_args_delta() {
|
||||
// Some upstreams (rare) include the fully-buffered arguments
|
||||
// on the `output_item.added` event when the model finalised
|
||||
// the call before SSE flush. Verify both ToolCallStart and a
|
||||
// single args delta fire.
|
||||
let events = collect_events(vec![
|
||||
sse_event(
|
||||
"response.output_item.added",
|
||||
r#"{"output_index":0,"item":{"type":"function_call","call_id":"call_a","name":"f","arguments":"{\"x\":1}"}}"#,
|
||||
),
|
||||
sse_event("response.completed", r#"{"response":{"status":"completed"}}"#),
|
||||
])
|
||||
.await;
|
||||
let events: Vec<CompletionEvent> = events.into_iter().map(|r| r.unwrap()).collect();
|
||||
let mut iter = events.into_iter();
|
||||
assert!(matches!(
|
||||
iter.next(),
|
||||
Some(CompletionEvent::ToolCallStart { .. })
|
||||
));
|
||||
assert!(matches!(
|
||||
iter.next(),
|
||||
Some(CompletionEvent::ToolCallArgsDelta { index: 0, ref args_delta })
|
||||
if args_delta == r#"{"x":1}"#
|
||||
));
|
||||
assert!(matches!(iter.next(), Some(CompletionEvent::Finish { .. })));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cancellation_ends_stream_promptly() {
|
||||
// Hand the decoder an empty stream + a triggered cancellation
|
||||
// token; it should terminate without yielding anything.
|
||||
let sse = stream::iter(Vec::<
|
||||
Result<eventsource_stream::Event, eventsource_stream::EventStreamError<reqwest::Error>>,
|
||||
>::new());
|
||||
let cancel = CancellationToken::new();
|
||||
cancel.cancel();
|
||||
let decoded = decode_stream(sse, cancel);
|
||||
let events: Vec<_> = decoded.collect().await;
|
||||
assert!(events.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn malformed_event_payload_is_skipped() {
|
||||
let events = collect_events(vec![
|
||||
sse_event("response.output_text.delta", "{not valid json"),
|
||||
sse_event(
|
||||
"response.output_text.delta",
|
||||
r#"{"item_id":"m","output_index":0,"delta":"ok"}"#,
|
||||
),
|
||||
sse_event(
|
||||
"response.completed",
|
||||
r#"{"response":{"status":"completed"}}"#,
|
||||
),
|
||||
])
|
||||
.await;
|
||||
let events: Vec<CompletionEvent> = events.into_iter().map(|r| r.unwrap()).collect();
|
||||
// First text delta dropped; second one fires.
|
||||
assert!(
|
||||
events
|
||||
.iter()
|
||||
.any(|e| matches!(e, CompletionEvent::TextDelta(t) if t == "ok"))
|
||||
);
|
||||
// No errors yielded (parse failures are warn-and-skip).
|
||||
assert!(
|
||||
events
|
||||
.iter()
|
||||
.all(|e| !matches!(e, CompletionEvent::Finish { reason: None }))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn provider_construction_is_cheap() {
|
||||
let _ = OpenAIResponsesProvider::new(ep()).unwrap();
|
||||
}
|
||||
}
|
||||
1018
crates/helexa-acp/src/qwen3.rs
Normal file
1018
crates/helexa-acp/src/qwen3.rs
Normal file
File diff suppressed because it is too large
Load Diff
188
crates/helexa-acp/src/session.rs
Normal file
188
crates/helexa-acp/src/session.rs
Normal file
@@ -0,0 +1,188 @@
|
||||
//! Per-session state for the ACP agent loop.
|
||||
//!
|
||||
//! Concurrency:
|
||||
//!
|
||||
//! - [`SessionStore`] is an `Arc<RwLock<HashMap<SessionId, …>>>`. The map
|
||||
//! itself is read-mostly: it changes only on `session/new` and never
|
||||
//! shrinks during Stage 2, so an `RwLock` keeps concurrent reads
|
||||
//! contention-free.
|
||||
//! - Each session is wrapped in its own `Arc<Mutex<SessionState>>`. Holding
|
||||
//! one session's lock doesn't block requests against any other session,
|
||||
//! which matters once a client opens multiple sessions in parallel.
|
||||
//!
|
||||
//! All operations hold a lock only long enough to copy out (or mutate) the
|
||||
//! state they need — never across an `await` that drives the upstream
|
||||
//! provider stream.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use agent_client_protocol::schema::{SessionId, SessionModeId};
|
||||
use tokio::sync::{Mutex, RwLock};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::provider::Message;
|
||||
|
||||
/// Mode id advertised as the gated default. Writes / bash prompt for
|
||||
/// permission via `session/request_permission`.
|
||||
pub const MODE_DEFAULT: &str = "default";
|
||||
|
||||
/// Mode id advertised as "auto-allow everything". Matches the
|
||||
/// favorite name (`bypassPermissions`) Zed clients tend to reference.
|
||||
pub const MODE_BYPASS: &str = "bypassPermissions";
|
||||
|
||||
/// Mode id for read-and-plan-only operation. The model may read files
|
||||
/// and list directories freely, may write *only* into the per-project
|
||||
/// plan directory under `$XDG_DATA_HOME/helexa-acp/plans/<project-id>/`,
|
||||
/// and cannot run shell commands. Designed for "draft the
|
||||
/// implementation plan, then I'll review and let you execute" flows.
|
||||
pub const MODE_PLAN: &str = "plan";
|
||||
|
||||
/// State carried for a single ACP session.
|
||||
///
|
||||
/// Mutated under `Mutex<SessionState>`; never share a clone across
|
||||
/// tasks expecting to see the same `cancel` token — clone the token
|
||||
/// explicitly when handing it to the streaming task.
|
||||
#[derive(Debug)]
|
||||
pub struct SessionState {
|
||||
/// Conversation history in chronological order (user / assistant
|
||||
/// turns). The system prompt is *not* stored here — it's built
|
||||
/// fresh per request so any cwd / config changes take effect.
|
||||
pub history: Vec<Message>,
|
||||
/// Working directory the client opened the session against. Used
|
||||
/// by [`crate::prompt::build_system_prompt`] and (Stage 3) by
|
||||
/// filesystem tools.
|
||||
pub cwd: PathBuf,
|
||||
/// Currently-selected model id. Format is either a bare model id
|
||||
/// (resolved against the default endpoint) or `endpoint:model`.
|
||||
/// Mutated by `session/set_model` in Stage 4; Stage 2 sets it
|
||||
/// once at session creation and never changes it.
|
||||
pub model_id: String,
|
||||
/// Cancellation handle for the in-flight prompt, if any. A fresh
|
||||
/// token is installed at the start of every `session/prompt`
|
||||
/// request; `session/cancel` fires this one. Between prompts the
|
||||
/// token is "spent" — firing it does nothing — which is fine,
|
||||
/// `session/cancel` is a no-op when there's nothing to cancel.
|
||||
pub cancel: CancellationToken,
|
||||
/// Permission gating mode. Stage 3 advertises two ids in
|
||||
/// `NewSessionResponse.modes`: [`MODE_DEFAULT`] (writes / bash
|
||||
/// prompt the user) and [`MODE_BYPASS`] (auto-allow). Mutated by
|
||||
/// `session/set_mode`.
|
||||
pub mode_id: SessionModeId,
|
||||
}
|
||||
|
||||
impl SessionState {
|
||||
pub fn new(cwd: PathBuf, model_id: String) -> Self {
|
||||
Self {
|
||||
history: Vec::new(),
|
||||
cwd,
|
||||
model_id,
|
||||
cancel: CancellationToken::new(),
|
||||
mode_id: SessionModeId::new(MODE_DEFAULT),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Concurrent map of live sessions.
|
||||
///
|
||||
/// Cloning is cheap (`Arc` bump). Pass clones into every handler that
|
||||
/// needs session access; never hold a clone across an `.await` that
|
||||
/// could outlive the request.
|
||||
pub type SessionStore = Arc<RwLock<HashMap<SessionId, Arc<Mutex<SessionState>>>>>;
|
||||
|
||||
/// Fresh, empty session store.
|
||||
pub fn new_store() -> SessionStore {
|
||||
Arc::new(RwLock::new(HashMap::new()))
|
||||
}
|
||||
|
||||
/// Look up a session by id. Returns `None` if no such session is registered.
|
||||
pub async fn get(store: &SessionStore, id: &SessionId) -> Option<Arc<Mutex<SessionState>>> {
|
||||
store.read().await.get(id).cloned()
|
||||
}
|
||||
|
||||
/// Register a fresh session. Overwrites any prior entry with the same id
|
||||
/// (which should never happen — ids are uniquely generated by the agent).
|
||||
pub async fn insert(store: &SessionStore, id: SessionId, state: SessionState) {
|
||||
store.write().await.insert(id, Arc::new(Mutex::new(state)));
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::provider::{MessageContent, Role};
|
||||
|
||||
fn id(s: &str) -> SessionId {
|
||||
SessionId::new(s)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn insert_then_get_round_trip() {
|
||||
let store = new_store();
|
||||
let state = SessionState::new(PathBuf::from("/tmp"), "m".into());
|
||||
insert(&store, id("s1"), state).await;
|
||||
let got = get(&store, &id("s1")).await.expect("session present");
|
||||
let locked = got.lock().await;
|
||||
assert_eq!(locked.cwd, PathBuf::from("/tmp"));
|
||||
assert_eq!(locked.model_id, "m");
|
||||
assert!(locked.history.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn missing_session_is_none() {
|
||||
let store = new_store();
|
||||
assert!(get(&store, &id("nope")).await.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn history_is_per_session() {
|
||||
let store = new_store();
|
||||
insert(
|
||||
&store,
|
||||
id("a"),
|
||||
SessionState::new(PathBuf::from("/a"), "m".into()),
|
||||
)
|
||||
.await;
|
||||
insert(
|
||||
&store,
|
||||
id("b"),
|
||||
SessionState::new(PathBuf::from("/b"), "m".into()),
|
||||
)
|
||||
.await;
|
||||
|
||||
// Appending to a's history must not affect b's.
|
||||
get(&store, &id("a"))
|
||||
.await
|
||||
.unwrap()
|
||||
.lock()
|
||||
.await
|
||||
.history
|
||||
.push(Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text {
|
||||
text: "hello".into(),
|
||||
},
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
get(&store, &id("a"))
|
||||
.await
|
||||
.unwrap()
|
||||
.lock()
|
||||
.await
|
||||
.history
|
||||
.len(),
|
||||
1
|
||||
);
|
||||
assert_eq!(
|
||||
get(&store, &id("b"))
|
||||
.await
|
||||
.unwrap()
|
||||
.lock()
|
||||
.await
|
||||
.history
|
||||
.len(),
|
||||
0
|
||||
);
|
||||
}
|
||||
}
|
||||
462
crates/helexa-acp/src/store.rs
Normal file
462
crates/helexa-acp/src/store.rs
Normal file
@@ -0,0 +1,462 @@
|
||||
//! On-disk session persistence for `session/load` support.
|
||||
//!
|
||||
//! Storage layout:
|
||||
//!
|
||||
//! ```text
|
||||
//! $XDG_DATA_HOME/helexa-acp/sessions/{session_id}.json
|
||||
//! ```
|
||||
//!
|
||||
//! (Fallback to `~/.local/share/helexa-acp/sessions/` when
|
||||
//! `$XDG_DATA_HOME` is unset.) One JSON file per session. Writes
|
||||
//! happen at the end of every `session/prompt` round through
|
||||
//! [`save`], using tempfile-plus-rename so a crash mid-write can't
|
||||
//! corrupt the store. Reads happen on `session/load` via [`load`].
|
||||
//!
|
||||
//! No compaction, no rotation: files accumulate until the user
|
||||
//! cleans them up. That's deliberate — disk is cheap, and the
|
||||
//! resume-on-restart workflow matters more than tidiness. The
|
||||
//! [`SESSIONS_DIRNAME`] subdirectory is created lazily on first
|
||||
//! save so an unprivileged install path never errors at startup.
|
||||
|
||||
use std::path::PathBuf;
|
||||
use std::time::SystemTime;
|
||||
|
||||
use agent_client_protocol::schema::SessionId;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::provider::Message;
|
||||
|
||||
const APP_DIRNAME: &str = "helexa-acp";
|
||||
const SESSIONS_DIRNAME: &str = "sessions";
|
||||
const PLANS_DIRNAME: &str = "plans";
|
||||
|
||||
/// The shape persisted to disk for one session. Only what we can't
|
||||
/// rebuild from the running config goes in here: the conversation
|
||||
/// history, the mode toggle, the model id, and the cwd-at-creation.
|
||||
///
|
||||
/// `created_at` / `updated_at` are seconds-since-epoch — cheap to
|
||||
/// compare, no third-party time crate, and stable across runs.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PersistedSession {
|
||||
pub session_id: String,
|
||||
pub cwd: PathBuf,
|
||||
pub model_id: String,
|
||||
pub mode_id: String,
|
||||
pub history: Vec<Message>,
|
||||
pub created_at: u64,
|
||||
pub updated_at: u64,
|
||||
}
|
||||
|
||||
/// Resolve the directory that holds session JSON files. Honors
|
||||
/// `$XDG_DATA_HOME`; falls back to `~/.local/share/helexa-acp/sessions/`.
|
||||
/// Returns `None` if neither is resolvable (no `HOME` set — possible
|
||||
/// in stripped-down container environments).
|
||||
pub fn sessions_dir() -> Option<PathBuf> {
|
||||
let base = std::env::var("XDG_DATA_HOME")
|
||||
.ok()
|
||||
.filter(|s| !s.is_empty())
|
||||
.map(PathBuf::from)
|
||||
.or_else(|| {
|
||||
std::env::var("HOME")
|
||||
.ok()
|
||||
.map(|h| PathBuf::from(h).join(".local").join("share"))
|
||||
})?;
|
||||
Some(base.join(APP_DIRNAME).join(SESSIONS_DIRNAME))
|
||||
}
|
||||
|
||||
/// Atomic save into the default sessions directory.
|
||||
pub fn save(session: &PersistedSession) -> anyhow::Result<()> {
|
||||
let dir = sessions_dir()
|
||||
.ok_or_else(|| anyhow::anyhow!("can't resolve XDG_DATA_HOME or HOME for session store"))?;
|
||||
save_to_dir(&dir, session)
|
||||
}
|
||||
|
||||
/// Load from the default sessions directory.
|
||||
pub fn load(session_id: &SessionId) -> anyhow::Result<PersistedSession> {
|
||||
let dir = sessions_dir()
|
||||
.ok_or_else(|| anyhow::anyhow!("can't resolve XDG_DATA_HOME or HOME for session store"))?;
|
||||
load_from_dir(&dir, session_id)
|
||||
}
|
||||
|
||||
/// Atomic save into an explicit directory. Writes to
|
||||
/// `{id}.json.tmp` then renames over `{id}.json`. Creates the
|
||||
/// target directory if it doesn't exist. Split from [`save`] so
|
||||
/// unit tests can target a per-test scratch dir without mutating
|
||||
/// process-global env vars.
|
||||
pub fn save_to_dir(dir: &std::path::Path, session: &PersistedSession) -> anyhow::Result<()> {
|
||||
std::fs::create_dir_all(dir).map_err(|e| anyhow::anyhow!("create {}: {e}", dir.display()))?;
|
||||
let safe = sanitize_id(&session.session_id);
|
||||
let final_path = dir.join(format!("{safe}.json"));
|
||||
let tmp_path = dir.join(format!("{safe}.json.tmp"));
|
||||
let json = serde_json::to_string_pretty(session)?;
|
||||
std::fs::write(&tmp_path, json)
|
||||
.map_err(|e| anyhow::anyhow!("write {}: {e}", tmp_path.display()))?;
|
||||
std::fs::rename(&tmp_path, &final_path)
|
||||
.map_err(|e| anyhow::anyhow!("rename → {}: {e}", final_path.display()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Load from an explicit directory. Returns a friendly error
|
||||
/// message when the session id has no file on disk so the caller
|
||||
/// can map it to a clean ACP error response.
|
||||
pub fn load_from_dir(
|
||||
dir: &std::path::Path,
|
||||
session_id: &SessionId,
|
||||
) -> anyhow::Result<PersistedSession> {
|
||||
let safe = sanitize_id(session_id.0.as_ref());
|
||||
let path = dir.join(format!("{safe}.json"));
|
||||
let bytes = std::fs::read(&path).map_err(|e| {
|
||||
if e.kind() == std::io::ErrorKind::NotFound {
|
||||
anyhow::anyhow!("no persisted session at {}", path.display())
|
||||
} else {
|
||||
anyhow::anyhow!("read {}: {e}", path.display())
|
||||
}
|
||||
})?;
|
||||
let session: PersistedSession = serde_json::from_slice(&bytes)
|
||||
.map_err(|e| anyhow::anyhow!("parse {}: {e}", path.display()))?;
|
||||
Ok(session)
|
||||
}
|
||||
|
||||
/// List all persisted sessions, optionally filtered by `cwd`. Used
|
||||
/// by the `session/list` handler so a client (Zed) can find the
|
||||
/// session that belongs to the workspace it's reopening.
|
||||
///
|
||||
/// `filter_cwd = None` returns every session on disk. `Some(path)`
|
||||
/// returns only sessions whose persisted `cwd` is exactly equal.
|
||||
///
|
||||
/// Files that fail to parse are skipped with a warning rather than
|
||||
/// aborting the whole list — one corrupt session shouldn't make
|
||||
/// the resume picker unusable.
|
||||
pub fn list(filter_cwd: Option<&std::path::Path>) -> anyhow::Result<Vec<PersistedSession>> {
|
||||
let dir = sessions_dir()
|
||||
.ok_or_else(|| anyhow::anyhow!("can't resolve XDG_DATA_HOME or HOME for session store"))?;
|
||||
list_in_dir(&dir, filter_cwd)
|
||||
}
|
||||
|
||||
/// Explicit-dir variant for tests, mirroring [`save_to_dir`] /
|
||||
/// [`load_from_dir`].
|
||||
pub fn list_in_dir(
|
||||
dir: &std::path::Path,
|
||||
filter_cwd: Option<&std::path::Path>,
|
||||
) -> anyhow::Result<Vec<PersistedSession>> {
|
||||
let read = match std::fs::read_dir(dir) {
|
||||
Ok(r) => r,
|
||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(Vec::new()),
|
||||
Err(e) => return Err(anyhow::anyhow!("read_dir {}: {e}", dir.display())),
|
||||
};
|
||||
let mut out = Vec::new();
|
||||
for entry in read.flatten() {
|
||||
let path = entry.path();
|
||||
if path.extension().and_then(|s| s.to_str()) != Some("json") {
|
||||
continue;
|
||||
}
|
||||
match std::fs::read(&path).and_then(|bytes| {
|
||||
serde_json::from_slice::<PersistedSession>(&bytes).map_err(std::io::Error::other)
|
||||
}) {
|
||||
Ok(session) => {
|
||||
if let Some(want) = filter_cwd
|
||||
&& session.cwd != want
|
||||
{
|
||||
continue;
|
||||
}
|
||||
out.push(session);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
path = %path.display(),
|
||||
error = %e,
|
||||
"store: skipping unparseable session file"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Most-recent first by updated_at.
|
||||
out.sort_by_key(|s| std::cmp::Reverse(s.updated_at));
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
/// Seconds-since-epoch, saturating to 0 if the system clock is
|
||||
/// behind epoch (which shouldn't happen but the type system
|
||||
/// requires a fallible read).
|
||||
pub fn now_secs() -> u64 {
|
||||
SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.map(|d| d.as_secs())
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
/// Root directory for plan-mode artefacts. Mirrors [`sessions_dir`]
|
||||
/// but under `…/helexa-acp/plans/` so plans and conversation
|
||||
/// transcripts are siblings, not nested.
|
||||
pub fn plans_root() -> Option<PathBuf> {
|
||||
sessions_dir().and_then(|s| s.parent().map(|p| p.join(PLANS_DIRNAME)))
|
||||
}
|
||||
|
||||
/// Per-project plan directory:
|
||||
/// `$XDG_DATA_HOME/helexa-acp/plans/<project-id>/`. The id derives
|
||||
/// from the session's cwd so plans for the same project survive
|
||||
/// across cwd-changes (a `/home/foo/git/bar` ↔ symlinked
|
||||
/// `/srv/checkout/bar` would technically diverge, accepted as a
|
||||
/// won't-fix corner case).
|
||||
pub fn plan_dir_for(cwd: &std::path::Path) -> Option<PathBuf> {
|
||||
plans_root().map(|root| root.join(project_id_for(cwd)))
|
||||
}
|
||||
|
||||
/// Deterministic, human-readable project identifier. Format:
|
||||
/// `<basename>-<8-hex>` where the 8-hex suffix is FNV-1a of the
|
||||
/// full path. Basename keeps the path skim-readable when poking
|
||||
/// around `$XDG_DATA_HOME` by hand; the hash suffix disambiguates
|
||||
/// repos that share a final path component (e.g. multiple
|
||||
/// `/.../checkout/beat` checkouts).
|
||||
///
|
||||
/// FNV-1a rather than `std::collections::hash::DefaultHasher`
|
||||
/// because the latter (SipHash) reseeds per process, so it'd give
|
||||
/// us a different project_id on every run.
|
||||
pub fn project_id_for(cwd: &std::path::Path) -> String {
|
||||
let basename = cwd
|
||||
.file_name()
|
||||
.and_then(|s| s.to_str())
|
||||
.unwrap_or("unknown");
|
||||
let sanitised: String = basename
|
||||
.chars()
|
||||
.map(|c| {
|
||||
if c.is_ascii_alphanumeric() || c == '-' || c == '_' {
|
||||
c
|
||||
} else {
|
||||
'_'
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let hash = fnv1a_32(cwd.to_string_lossy().as_bytes());
|
||||
format!("{sanitised}-{hash:08x}")
|
||||
}
|
||||
|
||||
/// FNV-1a (32-bit). Deterministic, no third-party crate. Used for
|
||||
/// project ids only — not cryptographic.
|
||||
fn fnv1a_32(bytes: &[u8]) -> u32 {
|
||||
let mut h: u32 = 0x811c_9dc5;
|
||||
for b in bytes {
|
||||
h ^= u32::from(*b);
|
||||
h = h.wrapping_mul(0x0100_0193);
|
||||
}
|
||||
h
|
||||
}
|
||||
|
||||
/// Format seconds-since-epoch as an ISO 8601 / RFC 3339 string
|
||||
/// (`YYYY-MM-DDTHH:MM:SSZ`) for `SessionInfo.updated_at`. Returns
|
||||
/// `None` for values outside the representable range, in which
|
||||
/// case the caller should omit the field.
|
||||
pub fn unix_to_iso8601(secs: u64) -> Option<String> {
|
||||
use chrono::TimeZone;
|
||||
let dt = chrono::Utc.timestamp_opt(secs as i64, 0).single()?;
|
||||
Some(dt.to_rfc3339_opts(chrono::SecondsFormat::Secs, true))
|
||||
}
|
||||
|
||||
/// Strip anything that isn't a safe filename character so a
|
||||
/// mischievous (or just unconventional) session id can't escape
|
||||
/// the sessions directory.
|
||||
fn sanitize_id(id: &str) -> String {
|
||||
id.chars()
|
||||
.map(|c| {
|
||||
if c.is_ascii_alphanumeric() || c == '-' || c == '_' {
|
||||
c
|
||||
} else {
|
||||
'_'
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::provider::{MessageContent, Role};
|
||||
|
||||
/// Unique scratch dir per test invocation. We use this dir
|
||||
/// directly with the `*_to_dir` / `*_from_dir` functions so
|
||||
/// the tests never mutate `$XDG_DATA_HOME` — that env var
|
||||
/// would race across the parallel test harness.
|
||||
fn unique_dir() -> PathBuf {
|
||||
let base = std::env::var("CARGO_TARGET_TMPDIR")
|
||||
.ok()
|
||||
.map(PathBuf::from)
|
||||
.unwrap_or_else(std::env::temp_dir);
|
||||
let pid = std::process::id();
|
||||
let nanos = SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.map(|d| d.subsec_nanos())
|
||||
.unwrap_or(0);
|
||||
let dir = base.join(format!("helexa-acp-store-test-{pid}-{nanos}"));
|
||||
std::fs::create_dir_all(&dir).expect("create test dir");
|
||||
dir
|
||||
}
|
||||
|
||||
fn sample(id: &str) -> PersistedSession {
|
||||
PersistedSession {
|
||||
session_id: id.into(),
|
||||
cwd: PathBuf::from("/home/me/proj"),
|
||||
model_id: "Qwen/Qwen3.6-27B".into(),
|
||||
mode_id: "default".into(),
|
||||
history: vec![
|
||||
Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text {
|
||||
text: "hello".into(),
|
||||
},
|
||||
},
|
||||
Message {
|
||||
role: Role::Assistant,
|
||||
content: MessageContent::Text { text: "hi".into() },
|
||||
},
|
||||
],
|
||||
created_at: 1_700_000_000,
|
||||
updated_at: 1_700_000_001,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn round_trip_save_then_load() {
|
||||
let dir = unique_dir();
|
||||
save_to_dir(&dir, &sample("hxa-1")).expect("save");
|
||||
let loaded = load_from_dir(&dir, &SessionId::new("hxa-1")).expect("load");
|
||||
assert_eq!(loaded.session_id, "hxa-1");
|
||||
assert_eq!(loaded.cwd, PathBuf::from("/home/me/proj"));
|
||||
assert_eq!(loaded.history.len(), 2);
|
||||
let _ = std::fs::remove_dir_all(&dir);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_missing_session_errors_with_not_found_message() {
|
||||
let dir = unique_dir();
|
||||
let err = load_from_dir(&dir, &SessionId::new("nope")).unwrap_err();
|
||||
let msg = format!("{err}");
|
||||
assert!(
|
||||
msg.contains("no persisted session"),
|
||||
"want NotFound, got: {msg}"
|
||||
);
|
||||
let _ = std::fs::remove_dir_all(&dir);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn save_overwrites_existing_atomically() {
|
||||
let dir = unique_dir();
|
||||
save_to_dir(&dir, &sample("hxa-1")).expect("save");
|
||||
let mut updated = sample("hxa-1");
|
||||
updated.history.push(Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text {
|
||||
text: "third turn".into(),
|
||||
},
|
||||
});
|
||||
updated.updated_at = 1_700_000_500;
|
||||
save_to_dir(&dir, &updated).expect("re-save");
|
||||
let loaded = load_from_dir(&dir, &SessionId::new("hxa-1")).expect("load");
|
||||
assert_eq!(loaded.history.len(), 3);
|
||||
assert_eq!(loaded.updated_at, 1_700_000_500);
|
||||
let _ = std::fs::remove_dir_all(&dir);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn save_then_load_preserves_tool_calls_and_results() {
|
||||
use crate::provider::ToolCall;
|
||||
let dir = unique_dir();
|
||||
let mut session = sample("hxa-2");
|
||||
session.history.push(Message {
|
||||
role: Role::Assistant,
|
||||
content: MessageContent::ToolCalls {
|
||||
text: Some("calling".into()),
|
||||
calls: vec![ToolCall {
|
||||
id: "call_0".into(),
|
||||
name: "read_file".into(),
|
||||
arguments: r#"{"path":"/etc/hostname"}"#.into(),
|
||||
}],
|
||||
},
|
||||
});
|
||||
session.history.push(Message {
|
||||
role: Role::Tool,
|
||||
content: MessageContent::ToolResult {
|
||||
tool_call_id: "call_0".into(),
|
||||
content: "host".into(),
|
||||
},
|
||||
});
|
||||
save_to_dir(&dir, &session).expect("save");
|
||||
let loaded = load_from_dir(&dir, &SessionId::new("hxa-2")).expect("load");
|
||||
assert_eq!(loaded.history.len(), 4);
|
||||
match &loaded.history[2].content {
|
||||
MessageContent::ToolCalls { calls, .. } => {
|
||||
assert_eq!(calls[0].name, "read_file");
|
||||
}
|
||||
other => panic!("expected ToolCalls, got {other:?}"),
|
||||
}
|
||||
let _ = std::fs::remove_dir_all(&dir);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_filters_by_cwd_and_sorts_recent_first() {
|
||||
let dir = unique_dir();
|
||||
let mut a = sample("a");
|
||||
a.cwd = PathBuf::from("/home/me/proj-x");
|
||||
a.updated_at = 1_700_000_010;
|
||||
let mut b = sample("b");
|
||||
b.cwd = PathBuf::from("/home/me/proj-x");
|
||||
b.updated_at = 1_700_000_020;
|
||||
let mut c = sample("c");
|
||||
c.cwd = PathBuf::from("/home/me/elsewhere");
|
||||
c.updated_at = 1_700_000_030;
|
||||
save_to_dir(&dir, &a).unwrap();
|
||||
save_to_dir(&dir, &b).unwrap();
|
||||
save_to_dir(&dir, &c).unwrap();
|
||||
|
||||
let proj_x = PathBuf::from("/home/me/proj-x");
|
||||
let list = list_in_dir(&dir, Some(&proj_x)).unwrap();
|
||||
let ids: Vec<&str> = list.iter().map(|s| s.session_id.as_str()).collect();
|
||||
// Filtered to proj-x; b before a because b is more recent.
|
||||
assert_eq!(ids, vec!["b", "a"]);
|
||||
|
||||
let all = list_in_dir(&dir, None).unwrap();
|
||||
assert_eq!(all.len(), 3);
|
||||
// Global list still sorted recent-first across all cwds.
|
||||
assert_eq!(all[0].session_id, "c");
|
||||
|
||||
let _ = std::fs::remove_dir_all(&dir);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_returns_empty_for_missing_dir() {
|
||||
let dir = unique_dir().join("does-not-exist");
|
||||
let list = list_in_dir(&dir, None).unwrap();
|
||||
assert!(list.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_skips_unparseable_files() {
|
||||
let dir = unique_dir();
|
||||
save_to_dir(&dir, &sample("good")).unwrap();
|
||||
std::fs::write(dir.join("garbage.json"), b"{not valid json").unwrap();
|
||||
let list = list_in_dir(&dir, None).unwrap();
|
||||
// Garbage skipped; good survives.
|
||||
assert_eq!(list.len(), 1);
|
||||
assert_eq!(list[0].session_id, "good");
|
||||
let _ = std::fs::remove_dir_all(&dir);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn iso8601_formats_unix_seconds() {
|
||||
// 2024-01-01T00:00:00Z is 1704067200 unix seconds.
|
||||
assert_eq!(
|
||||
unix_to_iso8601(1_704_067_200),
|
||||
Some("2024-01-01T00:00:00Z".into())
|
||||
);
|
||||
assert_eq!(unix_to_iso8601(0), Some("1970-01-01T00:00:00Z".into()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sanitize_id_rejects_path_traversal() {
|
||||
// `../../etc/passwd` — 6 non-alnum chars before "etc"
|
||||
// (`.`, `.`, `/`, `.`, `.`, `/`), one between, none
|
||||
// after, none before nothing. Every disallowed char
|
||||
// collapses to `_`.
|
||||
assert_eq!(sanitize_id("../../etc/passwd"), "______etc_passwd");
|
||||
assert_eq!(sanitize_id("ok-name_42"), "ok-name_42");
|
||||
}
|
||||
}
|
||||
1469
crates/helexa-acp/src/tool_runner.rs
Normal file
1469
crates/helexa-acp/src/tool_runner.rs
Normal file
File diff suppressed because it is too large
Load Diff
300
crates/helexa-acp/src/tools.rs
Normal file
300
crates/helexa-acp/src/tools.rs
Normal file
@@ -0,0 +1,300 @@
|
||||
//! Tool schemas sent to the upstream model on every completion.
|
||||
//!
|
||||
//! These are the OpenAI-function-style declarations the LLM sees in
|
||||
//! `CompletionRequest.tools`; the runtime dispatch happens in
|
||||
//! [`crate::tool_runner`]. Keeping declarations and execution in
|
||||
//! separate modules makes it easy to add a tool without touching the
|
||||
//! runner, and vice versa.
|
||||
//!
|
||||
//! Stage 3 ships five: filesystem read / write / edit, directory
|
||||
//! listing, and `bash`. Image generation, web fetch, MCP-derived
|
||||
//! tools, etc. are out of scope here.
|
||||
|
||||
use serde_json::json;
|
||||
|
||||
use crate::provider::ToolSpec;
|
||||
|
||||
pub const READ_FILE: &str = "read_file";
|
||||
pub const WRITE_FILE: &str = "write_file";
|
||||
pub const EDIT_FILE: &str = "edit_file";
|
||||
pub const LIST_DIR: &str = "list_dir";
|
||||
pub const BASH: &str = "bash";
|
||||
|
||||
/// Build the static tool list passed to the model on every prompt.
|
||||
/// Cheap — the JSON Schema fragments are constructed each call but
|
||||
/// the bodies are small constants. If this ever shows up in a
|
||||
/// profile we can `OnceLock` the Vec.
|
||||
pub fn all_tools() -> Vec<ToolSpec> {
|
||||
vec![
|
||||
ToolSpec {
|
||||
name: READ_FILE.to_string(),
|
||||
description: "Read the contents of a text file. Returns the file's text.".to_string(),
|
||||
parameters: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Absolute path to the file."
|
||||
},
|
||||
"line": {
|
||||
"type": "integer",
|
||||
"description": "Optional 1-based line number to start reading from.",
|
||||
"minimum": 1
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Optional maximum number of lines to read.",
|
||||
"minimum": 1
|
||||
}
|
||||
},
|
||||
"required": ["path"],
|
||||
"additionalProperties": false
|
||||
}),
|
||||
},
|
||||
ToolSpec {
|
||||
name: WRITE_FILE.to_string(),
|
||||
description: "Write text content to a file, replacing any existing contents. \
|
||||
Creates the file (and parent directories) if needed."
|
||||
.to_string(),
|
||||
parameters: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Absolute path to the file."
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "Full new contents of the file."
|
||||
}
|
||||
},
|
||||
"required": ["path", "content"],
|
||||
"additionalProperties": false
|
||||
}),
|
||||
},
|
||||
ToolSpec {
|
||||
name: EDIT_FILE.to_string(),
|
||||
description: "Replace one exact substring in a file with another. \
|
||||
Fails if `old_text` does not appear in the file, or appears more than once. \
|
||||
Use multiple edit_file calls for multiple edits."
|
||||
.to_string(),
|
||||
parameters: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Absolute path to the file."
|
||||
},
|
||||
"old_text": {
|
||||
"type": "string",
|
||||
"description": "Exact text fragment to replace. Must be unique within the file."
|
||||
},
|
||||
"new_text": {
|
||||
"type": "string",
|
||||
"description": "Replacement text."
|
||||
}
|
||||
},
|
||||
"required": ["path", "old_text", "new_text"],
|
||||
"additionalProperties": false
|
||||
}),
|
||||
},
|
||||
ToolSpec {
|
||||
name: LIST_DIR.to_string(),
|
||||
description:
|
||||
"List the entries of a directory. Returns names and a (f|d|l) kind per entry."
|
||||
.to_string(),
|
||||
parameters: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Absolute path to the directory."
|
||||
}
|
||||
},
|
||||
"required": ["path"],
|
||||
"additionalProperties": false
|
||||
}),
|
||||
},
|
||||
ToolSpec {
|
||||
name: BASH.to_string(),
|
||||
description: "Run a shell command via `sh -c`. \
|
||||
Returns combined stdout+stderr and the exit status. \
|
||||
The command runs in the session's working directory unless `cwd` is given."
|
||||
.to_string(),
|
||||
parameters: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "Shell command line, evaluated by `sh -c`."
|
||||
},
|
||||
"cwd": {
|
||||
"type": "string",
|
||||
"description": "Optional absolute path to run the command from."
|
||||
}
|
||||
},
|
||||
"required": ["command"],
|
||||
"additionalProperties": false
|
||||
}),
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
/// Try to infer which tool was intended from the shape of an
|
||||
/// `arguments` object alone. Used by the agent when the model
|
||||
/// emits a `<tool_call>` whose JSON has the right arguments but a
|
||||
/// missing or invalid top-level `name` field — a recurring
|
||||
/// Qwen3.6-27B failure mode.
|
||||
///
|
||||
/// Returns `Some(name)` only when the argument keys uniquely match
|
||||
/// exactly one tool in the catalogue. Ambiguous shapes (`{path}`
|
||||
/// alone could be either [`READ_FILE`] or [`LIST_DIR`]) return
|
||||
/// `None` so the caller surfaces a Failed-card and lets the model
|
||||
/// retry rather than guessing wrong.
|
||||
///
|
||||
/// Inference table (key set → tool):
|
||||
///
|
||||
/// | Keys | Tool |
|
||||
/// |---------------------------------------|--------------|
|
||||
/// | `{command}` or `{command, cwd}` | `bash` |
|
||||
/// | `{path, content}` | `write_file` |
|
||||
/// | `{path, old_text, new_text}` | `edit_file` |
|
||||
/// | `{path}` / `{path, line}` / `{path, line, limit}` | *ambiguous* — None |
|
||||
/// | (anything else) | None |
|
||||
pub fn infer_tool_name(arguments: &serde_json::Value) -> Option<&'static str> {
|
||||
let obj = arguments.as_object()?;
|
||||
let keys: std::collections::HashSet<&str> = obj.keys().map(|s| s.as_str()).collect();
|
||||
|
||||
// `command` is unique to bash. Allow the optional `cwd` arg
|
||||
// alongside but nothing else (any unrecognised keys → bail and
|
||||
// let the model retry rather than misroute).
|
||||
if keys.contains("command") && keys.iter().all(|k| matches!(*k, "command" | "cwd")) {
|
||||
return Some(BASH);
|
||||
}
|
||||
// `content` is unique to write_file.
|
||||
if keys.contains("content") && keys.contains("path") && keys.len() == 2 {
|
||||
return Some(WRITE_FILE);
|
||||
}
|
||||
// `old_text` + `new_text` are unique to edit_file.
|
||||
if keys.contains("old_text")
|
||||
&& keys.contains("new_text")
|
||||
&& keys.contains("path")
|
||||
&& keys.len() == 3
|
||||
{
|
||||
return Some(EDIT_FILE);
|
||||
}
|
||||
// `{path}` / `{path, line}` / `{path, line, limit}` overlap
|
||||
// between read_file (file contents) and list_dir (directory
|
||||
// contents). No safe inference — refuse.
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn all_tools_has_five_named_entries() {
|
||||
let tools = all_tools();
|
||||
let names: Vec<&str> = tools.iter().map(|t| t.name.as_str()).collect();
|
||||
assert_eq!(
|
||||
names,
|
||||
vec![READ_FILE, WRITE_FILE, EDIT_FILE, LIST_DIR, BASH]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn infer_bash_from_command_only() {
|
||||
let args = serde_json::json!({"command": "ls /tmp"});
|
||||
assert_eq!(infer_tool_name(&args), Some(BASH));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn infer_bash_from_command_and_cwd() {
|
||||
let args = serde_json::json!({"command": "ls", "cwd": "/tmp"});
|
||||
assert_eq!(infer_tool_name(&args), Some(BASH));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn infer_bash_from_mkdir_like_real_failure() {
|
||||
// Lifted verbatim from the agent failure that motivated
|
||||
// this helper (helexa-acp.log @ 10:03:11).
|
||||
let args = serde_json::json!({
|
||||
"command": "mkdir -p /home/grenade/git/beat/beat/doc/plan/{01-discovery,02-segmentation,03-description,04-summary,05-output}"
|
||||
});
|
||||
assert_eq!(infer_tool_name(&args), Some(BASH));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn infer_write_file() {
|
||||
let args = serde_json::json!({"path": "/tmp/x", "content": "hi"});
|
||||
assert_eq!(infer_tool_name(&args), Some(WRITE_FILE));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn infer_edit_file() {
|
||||
let args = serde_json::json!({
|
||||
"path": "/tmp/x", "old_text": "a", "new_text": "b"
|
||||
});
|
||||
assert_eq!(infer_tool_name(&args), Some(EDIT_FILE));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn refuse_ambiguous_path_only() {
|
||||
let args = serde_json::json!({"path": "/tmp/x"});
|
||||
assert_eq!(infer_tool_name(&args), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn refuse_ambiguous_path_with_optionals() {
|
||||
// read_file accepts these optionals; list_dir doesn't —
|
||||
// but Qwen wouldn't reliably emit them either, so we
|
||||
// can't use their presence to disambiguate. Refuse.
|
||||
let args = serde_json::json!({"path": "/tmp/x", "line": 1, "limit": 50});
|
||||
assert_eq!(infer_tool_name(&args), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn refuse_command_with_extra_unknown_keys() {
|
||||
// Defence in depth: an unrecognised key alongside
|
||||
// `command` means we don't really know what tool the
|
||||
// model wanted; refuse rather than guess.
|
||||
let args = serde_json::json!({"command": "ls", "extra": "?"});
|
||||
assert_eq!(infer_tool_name(&args), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn refuse_empty_args() {
|
||||
let args = serde_json::json!({});
|
||||
assert_eq!(infer_tool_name(&args), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn refuse_non_object_args() {
|
||||
let args = serde_json::json!("not an object");
|
||||
assert_eq!(infer_tool_name(&args), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn every_tool_has_an_object_parameter_schema() {
|
||||
for tool in all_tools() {
|
||||
let ty = tool.parameters.get("type").and_then(|v| v.as_str());
|
||||
assert_eq!(
|
||||
ty,
|
||||
Some("object"),
|
||||
"tool {} parameters.type must be \"object\"",
|
||||
tool.name
|
||||
);
|
||||
assert!(
|
||||
tool.parameters.get("properties").is_some(),
|
||||
"tool {} missing properties",
|
||||
tool.name
|
||||
);
|
||||
assert!(
|
||||
tool.parameters.get("required").is_some(),
|
||||
"tool {} missing required list",
|
||||
tool.name
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -12,6 +12,36 @@ path = "src/lib.rs"
|
||||
name = "neuron"
|
||||
path = "src/main.rs"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
# Enables CUDA acceleration in candle and the cudarc/nccl bindings the
|
||||
# TP worker pool uses. Without this feature, candle compiles for CPU
|
||||
# only, Device::new_cuda calls fall back to CPU, and TP Init/sanity
|
||||
# requests return Error{kind="cuda_feature_not_enabled"}.
|
||||
cuda = [
|
||||
"candle-core/cuda",
|
||||
"candle-core/nccl",
|
||||
"candle-nn/cuda",
|
||||
"candle-transformers/cuda",
|
||||
"dep:cudarc",
|
||||
"dep:half",
|
||||
"dep:cudaforge",
|
||||
]
|
||||
# Use cuDNN for convolution / attention kernels. Requires CUDA.
|
||||
cudnn = [
|
||||
"cuda",
|
||||
"candle-core/cudnn",
|
||||
"candle-nn/cudnn",
|
||||
"candle-transformers/cudnn",
|
||||
]
|
||||
# FlashAttention kernels. Requires CUDA.
|
||||
flash-attn = [
|
||||
"cuda",
|
||||
"candle-transformers/flash-attn",
|
||||
]
|
||||
# Reserved for GPU-only integration tests in later stages.
|
||||
cuda-integration = ["cuda"]
|
||||
|
||||
[dependencies]
|
||||
cortex-core.workspace = true
|
||||
tokio.workspace = true
|
||||
@@ -24,9 +54,54 @@ tracing-subscriber.workspace = true
|
||||
anyhow.workspace = true
|
||||
async-trait.workspace = true
|
||||
clap.workspace = true
|
||||
thiserror.workspace = true
|
||||
futures.workspace = true
|
||||
tokio-stream.workspace = true
|
||||
figment.workspace = true
|
||||
toml.workspace = true
|
||||
|
||||
# candle for in-process inference. CUDA support is gated behind the
|
||||
# crate's `cuda` feature (default off) so the workspace builds on
|
||||
# non-CUDA hosts and CI runners.
|
||||
candle-core = "0.10.2"
|
||||
candle-nn = "0.10.2"
|
||||
candle-transformers = "0.10.2"
|
||||
# Direct dep on cudarc (matching candle's transitive version) so the
|
||||
# TP worker pool can call cudarc::nccl::{Comm, Id} directly. Gated on
|
||||
# the `cuda` feature; same toolchain requirement as candle's CUDA path.
|
||||
cudarc = { version = "0.19", optional = true, default-features = false, features = ["nccl", "cuda-version-from-build-system"] }
|
||||
# Used by the AllReduce CustomOp1 to type-dispatch on bf16/f16 candle
|
||||
# storages. Matches candle-core's pinned major version to avoid double-
|
||||
# compiling the `half` crate at conflicting versions.
|
||||
half = { version = "2.5", optional = true }
|
||||
tokenizers = { version = "0.22", default-features = false, features = ["onig"] }
|
||||
hf-hub = { version = "0.4", features = ["tokio"] }
|
||||
# Jinja-compatible template renderer for the model's
|
||||
# `tokenizer_config.json::chat_template`. Hugging Face's chat
|
||||
# templates use a strict subset of Jinja2 that minijinja supports
|
||||
# out of the box. ~80KB compiled; pure Rust, no async surface.
|
||||
# Features: `builtins` for the `is defined` / `default` filters HF
|
||||
# templates use; `json` for `tojson` (some Qwen3 templates emit
|
||||
# tool definitions via tojson); `serde` so we can hand it a
|
||||
# serde_json::Value as the context.
|
||||
minijinja = { version = "2", features = ["builtins", "json", "serde"] }
|
||||
# Direct dep on `safetensors` (re-exported by candle but its `TensorView`
|
||||
# / `slice::IndexOp` types are public-but-not-re-exported). Used by the
|
||||
# tp `fused_load` module to read per-rank slices of fused QKV tensors
|
||||
# without materialising the full tensor on device.
|
||||
safetensors = "0.7"
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { workspace = true, features = ["test-util"] }
|
||||
reqwest.workspace = true
|
||||
tempfile = "3"
|
||||
|
||||
[build-dependencies]
|
||||
# Used by `build.rs` to compile `src/cuda/*.cu` into `libneuroncuda.a`
|
||||
# under the `cuda` feature. Matches mistralrs's upstream build setup
|
||||
# (their `mistralrs-core/build.rs` uses the same constructor).
|
||||
cudaforge = { version = "0.1", optional = true }
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
# Skip the CUDA path on docs.rs (it lacks nvcc).
|
||||
no-default-features = true
|
||||
|
||||
66
crates/neuron/build.rs
Normal file
66
crates/neuron/build.rs
Normal file
@@ -0,0 +1,66 @@
|
||||
//! Build script: compile the CUDA kernels in `src/cuda/*.cu` into a
|
||||
//! static library and link it under the `cuda` feature.
|
||||
//!
|
||||
//! Patterned on `EricLBuehler/mistral.rs::mistralrs-core/build.rs` —
|
||||
//! same `cudaforge::KernelBuilder` invocation, same NVCC flag set.
|
||||
|
||||
fn main() {
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
use std::path::PathBuf;
|
||||
println!("cargo:rerun-if-changed=build.rs");
|
||||
println!("cargo:rerun-if-changed=src/cuda/");
|
||||
|
||||
let build_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap());
|
||||
|
||||
let mut builder = cudaforge::KernelBuilder::new()
|
||||
.source_glob("src/cuda/*.cu")
|
||||
.out_dir(&build_dir)
|
||||
.arg("-std=c++17")
|
||||
.arg("-O3")
|
||||
.arg("-U__CUDA_NO_HALF_OPERATORS__")
|
||||
.arg("-U__CUDA_NO_HALF_CONVERSIONS__")
|
||||
.arg("-U__CUDA_NO_HALF2_OPERATORS__")
|
||||
.arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
|
||||
.arg("--expt-relaxed-constexpr")
|
||||
.arg("--expt-extended-lambda")
|
||||
.arg("--use_fast_math")
|
||||
.arg("--compiler-options")
|
||||
.arg("-fPIC");
|
||||
|
||||
// sm_<80 doesn't have bf16 intrinsics for WMMA — gate the
|
||||
// bf16-only kernels off in that case. (Mirrors upstream.)
|
||||
if let Some(compute_cap) = builder.get_compute_cap()
|
||||
&& compute_cap < 80
|
||||
{
|
||||
builder = builder.arg("-DNO_BF16_KERNEL");
|
||||
}
|
||||
|
||||
let target = std::env::var("TARGET").unwrap();
|
||||
let out_file = if target.contains("msvc") {
|
||||
build_dir.join("neuroncuda.lib")
|
||||
} else {
|
||||
build_dir.join("libneuroncuda.a")
|
||||
};
|
||||
|
||||
builder
|
||||
.build_lib(out_file)
|
||||
.expect("neuron cuda build failed");
|
||||
println!("cargo:rustc-link-search={}", build_dir.display());
|
||||
println!("cargo:rustc-link-lib=neuroncuda");
|
||||
println!("cargo:rustc-link-lib=dylib=cudart");
|
||||
|
||||
if target.contains("msvc") {
|
||||
// No extra runtime library needed.
|
||||
} else if target.contains("apple")
|
||||
|| target.contains("freebsd")
|
||||
|| target.contains("openbsd")
|
||||
{
|
||||
println!("cargo:rustc-link-lib=dylib=c++");
|
||||
} else if target.contains("android") {
|
||||
println!("cargo:rustc-link-lib=dylib=c++_shared");
|
||||
} else {
|
||||
println!("cargo:rustc-link-lib=dylib=stdc++");
|
||||
}
|
||||
}
|
||||
}
|
||||
93
crates/neuron/src/activation.rs
Normal file
93
crates/neuron/src/activation.rs
Normal file
@@ -0,0 +1,93 @@
|
||||
//! Activation-time pre-warm progress tracking.
|
||||
//!
|
||||
//! Wraps the [`ActivationStatus`] snapshot in an async RwLock so the
|
||||
//! background pre-warm task can update it per-model while the
|
||||
//! `/health` handler reads coherent snapshots. The tracker exists
|
||||
//! because `default_models` loading moved from synchronous-before-bind
|
||||
//! to background-after-bind on 2026-05-26: the listener is up
|
||||
//! immediately, but `/health` now needs to tell callers which of the
|
||||
//! configured defaults are still warming.
|
||||
|
||||
use cortex_core::discovery::{ActivationState, ActivationStatus, PreWarmFailure};
|
||||
use cortex_core::harness::ModelSpec;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// Shared, async-safe handle to the daemon's activation progress.
|
||||
///
|
||||
/// Construct once in `main` with the configured `default_models` so
|
||||
/// the initial `pending` list matches the spec; clone the `Arc` into
|
||||
/// the `NeuronState` for HTTP handlers and into the spawned pre-warm
|
||||
/// task for updates.
|
||||
pub struct ActivationTracker {
|
||||
inner: RwLock<ActivationStatus>,
|
||||
}
|
||||
|
||||
impl ActivationTracker {
|
||||
/// Build a tracker primed with one entry per spec. An empty spec
|
||||
/// list yields a `Ready` tracker — no point reporting PreWarming
|
||||
/// when there's nothing queued.
|
||||
pub fn new(default_models: &[ModelSpec]) -> Self {
|
||||
let pending: Vec<String> = default_models.iter().map(|s| s.model_id.clone()).collect();
|
||||
let state = if pending.is_empty() {
|
||||
ActivationState::Ready
|
||||
} else {
|
||||
ActivationState::PreWarming
|
||||
};
|
||||
Self {
|
||||
inner: RwLock::new(ActivationStatus {
|
||||
state,
|
||||
pending,
|
||||
in_progress: None,
|
||||
completed: vec![],
|
||||
failed: vec![],
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Mark a model as in-progress: remove it from `pending`, set as
|
||||
/// `in_progress`. Called immediately before `registry.load_model`.
|
||||
pub async fn start_loading(&self, model_id: &str) {
|
||||
let mut s = self.inner.write().await;
|
||||
s.pending.retain(|m| m != model_id);
|
||||
s.in_progress = Some(model_id.to_string());
|
||||
}
|
||||
|
||||
/// Mark a model as completed: clear `in_progress` (if it matches),
|
||||
/// append to `completed`.
|
||||
pub async fn complete_loading(&self, model_id: &str) {
|
||||
let mut s = self.inner.write().await;
|
||||
if s.in_progress.as_deref() == Some(model_id) {
|
||||
s.in_progress = None;
|
||||
}
|
||||
s.completed.push(model_id.to_string());
|
||||
}
|
||||
|
||||
/// Mark a model as failed: clear `in_progress` (if it matches),
|
||||
/// append a `PreWarmFailure` carrying the rendered error chain.
|
||||
pub async fn fail_loading(&self, model_id: &str, error: &str) {
|
||||
let mut s = self.inner.write().await;
|
||||
if s.in_progress.as_deref() == Some(model_id) {
|
||||
s.in_progress = None;
|
||||
}
|
||||
s.failed.push(PreWarmFailure {
|
||||
model_id: model_id.to_string(),
|
||||
error: error.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
/// Flip the high-level `state` to `Ready` once the pre-warm task
|
||||
/// is done iterating. Pending should be empty by this point; if a
|
||||
/// caller bails early it's a stuck activation and the operator
|
||||
/// will see entries in `pending` even with `state=ready` — that's
|
||||
/// a useful diagnostic, not an inconsistency to scrub.
|
||||
pub async fn mark_ready(&self) {
|
||||
let mut s = self.inner.write().await;
|
||||
s.state = ActivationState::Ready;
|
||||
s.in_progress = None;
|
||||
}
|
||||
|
||||
/// Cheap clone of the current state for the `/health` handler.
|
||||
pub async fn snapshot(&self) -> ActivationStatus {
|
||||
self.inner.read().await.clone()
|
||||
}
|
||||
}
|
||||
@@ -1,23 +1,41 @@
|
||||
//! HTTP API handlers for the neuron daemon.
|
||||
|
||||
use crate::activation::ActivationTracker;
|
||||
use crate::harness::HarnessRegistry;
|
||||
use crate::harness::candle::{CandleHarness, InferenceError};
|
||||
use crate::harness::preflight::PreflightError;
|
||||
use crate::health::HealthCache;
|
||||
use crate::wire::{openai_chat, openai_responses};
|
||||
use axum::Router;
|
||||
use axum::extract::{Path, State};
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||
use axum::response::{IntoResponse, Json};
|
||||
use axum::routing::{get, post};
|
||||
use cortex_core::discovery::{DiscoveryResponse, HealthResponse};
|
||||
use cortex_core::harness::ModelSpec;
|
||||
use cortex_core::openai::{ChatCompletionRequest, MessageContent};
|
||||
use cortex_core::responses::{ResponsesRequest, ResponsesUsage};
|
||||
use futures::stream::{self, StreamExt};
|
||||
use serde_json::{Value, json};
|
||||
use std::convert::Infallible;
|
||||
use std::sync::Arc;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
use tokio::sync::RwLock;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
|
||||
/// Shared state for the neuron HTTP server.
|
||||
pub struct NeuronState {
|
||||
pub discovery: DiscoveryResponse,
|
||||
pub health_cache: Arc<HealthCache>,
|
||||
pub registry: RwLock<HarnessRegistry>,
|
||||
/// Typed handle to the candle harness for inference routes. Cached at
|
||||
/// startup so `/v1/chat/completions` doesn't have to hold the registry
|
||||
/// read lock or perform dyn-Trait dispatch per request.
|
||||
pub candle: Option<Arc<CandleHarness>>,
|
||||
/// Activation-time pre-warm progress. Updated by the background
|
||||
/// `load_default_models` task, read by the `/health` handler.
|
||||
pub activation: Arc<ActivationTracker>,
|
||||
}
|
||||
|
||||
/// Build the neuron API router.
|
||||
@@ -29,6 +47,8 @@ pub fn neuron_routes() -> Router<Arc<NeuronState>> {
|
||||
.route("/models/load", post(load_model))
|
||||
.route("/models/unload", post(unload_model))
|
||||
.route("/models/{model_id}/endpoint", get(model_endpoint))
|
||||
.route("/v1/chat/completions", post(chat_completions))
|
||||
.route("/v1/responses", post(responses))
|
||||
}
|
||||
|
||||
async fn discovery_handler(State(state): State<Arc<NeuronState>>) -> Json<DiscoveryResponse> {
|
||||
@@ -36,7 +56,13 @@ async fn discovery_handler(State(state): State<Arc<NeuronState>>) -> Json<Discov
|
||||
}
|
||||
|
||||
async fn health_handler(State(state): State<Arc<NeuronState>>) -> Json<HealthResponse> {
|
||||
Json(state.health_cache.snapshot().await)
|
||||
// HealthCache owns the uptime + per-device readings; the activation
|
||||
// tracker owns the pre-warm progress. We compose the response here
|
||||
// so the cache stays a thin runtime-state cache and doesn't need to
|
||||
// know about activation lifecycle.
|
||||
let mut snapshot = state.health_cache.snapshot().await;
|
||||
snapshot.activation = state.activation.snapshot().await;
|
||||
Json(snapshot)
|
||||
}
|
||||
|
||||
async fn list_models(State(state): State<Arc<NeuronState>>) -> impl IntoResponse {
|
||||
@@ -45,7 +71,7 @@ async fn list_models(State(state): State<Arc<NeuronState>>) -> impl IntoResponse
|
||||
Ok(models) => Json(json!(models)).into_response(),
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": e.to_string()})),
|
||||
Json(json!({"error": format!("{e:#}")})),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
@@ -58,11 +84,52 @@ async fn load_model(
|
||||
let registry = state.registry.read().await;
|
||||
match registry.load_model(&spec).await {
|
||||
Ok(()) => Json(json!({"status": "loaded"})).into_response(),
|
||||
Err(e) => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({"error": e.to_string()})),
|
||||
)
|
||||
.into_response(),
|
||||
Err(e) => {
|
||||
// If the underlying failure is a structured preflight
|
||||
// rejection, surface it as 422 Unprocessable Entity with
|
||||
// the typed JSON body. The kind/model_id/suggestion/etc.
|
||||
// fields let cortex (and operators reading the response
|
||||
// directly) act on the failure without parsing free text.
|
||||
if let Some(pf) = e.downcast_ref::<PreflightError>() {
|
||||
tracing::warn!(
|
||||
model = %spec.model_id,
|
||||
reason = preflight_kind(pf),
|
||||
detail = %pf,
|
||||
"load_model rejected by preflight"
|
||||
);
|
||||
return (
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(json!({ "error": pf })),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
// Log the full anyhow chain server-side so journalctl shows
|
||||
// the underlying failure (hf-hub timeout, permission denied,
|
||||
// disk full, etc.) without needing to inspect the HTTP
|
||||
// response body separately.
|
||||
tracing::warn!(
|
||||
model = %spec.model_id,
|
||||
error = %format!("{e:#}"),
|
||||
"load_model failed"
|
||||
);
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({"error": format!("{e:#}")})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Short kebab-case tag for a preflight failure, used as a structured
|
||||
/// log field for journalctl-side filtering. Mirrors the same helper in
|
||||
/// `startup.rs`; duplicated to keep the module surfaces independent.
|
||||
fn preflight_kind(err: &PreflightError) -> &'static str {
|
||||
match err {
|
||||
PreflightError::RepoFetchFailed { .. } => "repo_fetch_failed",
|
||||
PreflightError::EmptyRepo { .. } => "empty_repo",
|
||||
PreflightError::TpRequiresSafetensors { .. } => "tp_requires_safetensors",
|
||||
PreflightError::QuantNotFound { .. } => "quant_not_found",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -84,7 +151,11 @@ async fn unload_model(
|
||||
let registry = state.registry.read().await;
|
||||
match registry.unload_model(&model_id).await {
|
||||
Ok(()) => Json(json!({"status": "unloaded"})).into_response(),
|
||||
Err(e) => (StatusCode::NOT_FOUND, Json(json!({"error": e.to_string()}))).into_response(),
|
||||
Err(e) => (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": format!("{e:#}")})),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,3 +173,311 @@ async fn model_endpoint(
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
/// OpenAI-compatible chat completions. Dispatches to streaming SSE when
|
||||
/// `stream: true` is set on the request; otherwise returns a single
|
||||
/// `ChatCompletionResponse`.
|
||||
async fn chat_completions(
|
||||
State(state): State<Arc<NeuronState>>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Json(req): Json<ChatCompletionRequest>,
|
||||
) -> impl IntoResponse {
|
||||
let Some(candle) = state.candle.as_ref().map(Arc::clone) else {
|
||||
return (
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(json!({"error": "candle harness not enabled on this neuron"})),
|
||||
)
|
||||
.into_response();
|
||||
};
|
||||
|
||||
// Reasoning-content opt-in. Off by default → naïve clients
|
||||
// (Zed's commit-message generator, vanilla OpenAI clients)
|
||||
// never see `<think>` blocks. On when the caller sends
|
||||
// `x-include-thinking: true` (helexa-acp does this so its
|
||||
// own ThinkParser keeps working unchanged).
|
||||
let include_thinking = headers
|
||||
.get("x-include-thinking")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| matches!(s.trim().to_ascii_lowercase().as_str(), "1" | "true" | "yes"))
|
||||
.unwrap_or(false);
|
||||
let chat_config = openai_chat::ChatProjectionConfig {
|
||||
include_thinking,
|
||||
reasoning_markers: None, // filled in from the loaded model inside candle
|
||||
};
|
||||
|
||||
if req.stream.unwrap_or(false) {
|
||||
match candle.chat_completion_stream_with(req, chat_config).await {
|
||||
Ok(rx) => {
|
||||
// Each chunk → one SSE `data: {json}` line. After the
|
||||
// channel closes, append the OpenAI [DONE] terminator.
|
||||
let body_stream = ReceiverStream::new(rx).map(|chunk| {
|
||||
let body = serde_json::to_string(&chunk).unwrap_or_default();
|
||||
Ok::<_, Infallible>(Event::default().data(body))
|
||||
});
|
||||
let done_stream =
|
||||
stream::once(async { Ok::<_, Infallible>(Event::default().data("[DONE]")) });
|
||||
Sse::new(body_stream.chain(done_stream))
|
||||
.keep_alive(KeepAlive::default())
|
||||
.into_response()
|
||||
}
|
||||
Err(InferenceError::ModelNotLoaded(id)) => (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": format!("model '{id}' not loaded on this neuron")})),
|
||||
)
|
||||
.into_response(),
|
||||
Err(InferenceError::PromptTooLong { prompt_len, max }) => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({
|
||||
"error": format!("prompt has {prompt_len} tokens but max is {max}"),
|
||||
"code": "prompt_too_long",
|
||||
"prompt_len": prompt_len,
|
||||
"max": max,
|
||||
})),
|
||||
)
|
||||
.into_response(),
|
||||
Err(InferenceError::InsufficientVram {
|
||||
free_mb,
|
||||
required_mb,
|
||||
}) => (
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(json!({
|
||||
"error": format!(
|
||||
"insufficient free VRAM: {free_mb} MiB free, need at least {required_mb} MiB"
|
||||
),
|
||||
"code": "insufficient_vram",
|
||||
"free_mb": free_mb,
|
||||
"required_mb": required_mb,
|
||||
})),
|
||||
)
|
||||
.into_response(),
|
||||
Err(InferenceError::Other(e)) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("{e:#}")})),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
} else {
|
||||
match candle.chat_completion(req).await {
|
||||
Ok(resp) => Json(resp).into_response(),
|
||||
Err(InferenceError::ModelNotLoaded(id)) => (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": format!("model '{id}' not loaded on this neuron")})),
|
||||
)
|
||||
.into_response(),
|
||||
Err(InferenceError::PromptTooLong { prompt_len, max }) => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({
|
||||
"error": format!("prompt has {prompt_len} tokens but max is {max}"),
|
||||
"code": "prompt_too_long",
|
||||
"prompt_len": prompt_len,
|
||||
"max": max,
|
||||
})),
|
||||
)
|
||||
.into_response(),
|
||||
Err(InferenceError::InsufficientVram {
|
||||
free_mb,
|
||||
required_mb,
|
||||
}) => (
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(json!({
|
||||
"error": format!(
|
||||
"insufficient free VRAM: {free_mb} MiB free, need at least {required_mb} MiB"
|
||||
),
|
||||
"code": "insufficient_vram",
|
||||
"free_mb": free_mb,
|
||||
"required_mb": required_mb,
|
||||
})),
|
||||
)
|
||||
.into_response(),
|
||||
Err(InferenceError::Other(e)) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("{e:#}")})),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// OpenAI Responses API (`POST /v1/responses`). Translates the
|
||||
/// Responses-shaped request into a chat-completions one the candle
|
||||
/// harness already understands, then re-projects the harness's
|
||||
/// event stream into the Responses event family.
|
||||
async fn responses(
|
||||
State(state): State<Arc<NeuronState>>,
|
||||
Json(req): Json<ResponsesRequest>,
|
||||
) -> impl IntoResponse {
|
||||
let Some(candle) = state.candle.as_ref().map(Arc::clone) else {
|
||||
return (
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(json!({"error": "candle harness not enabled on this neuron"})),
|
||||
)
|
||||
.into_response();
|
||||
};
|
||||
|
||||
let stream_requested = req.stream;
|
||||
let model_id = req.model.clone();
|
||||
let response_id = mint_response_id();
|
||||
let message_item_id = mint_message_item_id();
|
||||
|
||||
// Translate Responses → chat completions. The only failure
|
||||
// mode today is `previous_response_id` set, which we reject
|
||||
// with 400 — stateful conversations need a persistence layer
|
||||
// we haven't built.
|
||||
let mut chat_req = match openai_responses::request_to_chat(req) {
|
||||
Ok(r) => r,
|
||||
Err(openai_responses::TranslateError::ChainedConversationNotSupported) => {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({
|
||||
"error": "previous_response_id is not supported on this neuron",
|
||||
"code": "chained_conversation_not_supported"
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
chat_req.stream = Some(stream_requested);
|
||||
|
||||
if stream_requested {
|
||||
match candle
|
||||
.responses_stream(chat_req, response_id, message_item_id)
|
||||
.await
|
||||
{
|
||||
Ok(rx) => {
|
||||
// Each ResponseStreamFrame → one SSE event carrying
|
||||
// both an event name and JSON data. The Responses
|
||||
// API doesn't use a `[DONE]` terminator — clients
|
||||
// see the `response.completed` event as the end of
|
||||
// the stream.
|
||||
let body_stream = ReceiverStream::new(rx).map(|frame| {
|
||||
let body = serde_json::to_string(&frame.data).unwrap_or_else(|_| "{}".into());
|
||||
Ok::<_, Infallible>(Event::default().event(frame.event_name).data(body))
|
||||
});
|
||||
Sse::new(body_stream)
|
||||
.keep_alive(KeepAlive::default())
|
||||
.into_response()
|
||||
}
|
||||
Err(e) => inference_error_response(e),
|
||||
}
|
||||
} else {
|
||||
// Non-streaming: drive the existing chat completion path
|
||||
// and translate the result. We don't currently re-tokenise
|
||||
// to compute usage; the harness returns it via the chat
|
||||
// response and we pass it through.
|
||||
match candle.chat_completion(chat_req).await {
|
||||
Ok(chat_resp) => {
|
||||
// Extract the assistant text (chat completions
|
||||
// always emits one choice on the candle path).
|
||||
let text = chat_resp
|
||||
.choices
|
||||
.first()
|
||||
.map(|c| match &c.message.content {
|
||||
MessageContent::Text(t) => t.clone(),
|
||||
MessageContent::Parts(_) => {
|
||||
// Candle output is always text today;
|
||||
// a Parts response would be surprising.
|
||||
// Empty-string fallback is safer than
|
||||
// a panic.
|
||||
String::new()
|
||||
}
|
||||
})
|
||||
.unwrap_or_default();
|
||||
let finish = chat_resp
|
||||
.choices
|
||||
.first()
|
||||
.and_then(|c| c.finish_reason.as_deref())
|
||||
.map(finish_reason_from_str)
|
||||
.unwrap_or(crate::wire::FinishReason::Stop);
|
||||
let usage = chat_resp.usage.as_ref().map(|u| ResponsesUsage {
|
||||
input_tokens: u.prompt_tokens,
|
||||
output_tokens: u.completion_tokens,
|
||||
total_tokens: u.prompt_tokens + u.completion_tokens,
|
||||
});
|
||||
let meta = openai_responses::ResponseMeta {
|
||||
response_id: mint_response_id(),
|
||||
created_at: unix_now_secs(),
|
||||
model_id,
|
||||
message_item_id: mint_message_item_id(),
|
||||
};
|
||||
let _ = chat_resp; // make the borrow-checker happy if `text` consumed it
|
||||
let resp = openai_responses::build_response(&meta, text, finish, usage);
|
||||
Json(resp).into_response()
|
||||
}
|
||||
Err(e) => inference_error_response(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn finish_reason_from_str(s: &str) -> crate::wire::FinishReason {
|
||||
use crate::wire::FinishReason;
|
||||
match s {
|
||||
"length" => FinishReason::Length,
|
||||
"tool_calls" => FinishReason::ToolCalls,
|
||||
_ => FinishReason::Stop,
|
||||
}
|
||||
}
|
||||
|
||||
/// Centralised mapping from [`InferenceError`] to an HTTP response.
|
||||
/// Lifted out so the chat-completions and responses handlers stay
|
||||
/// readable and changes to error-code semantics happen in one spot.
|
||||
fn inference_error_response(err: InferenceError) -> axum::response::Response {
|
||||
match err {
|
||||
InferenceError::ModelNotLoaded(id) => (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": format!("model '{id}' not loaded on this neuron")})),
|
||||
)
|
||||
.into_response(),
|
||||
InferenceError::PromptTooLong { prompt_len, max } => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({
|
||||
"error": format!("prompt has {prompt_len} tokens but max is {max}"),
|
||||
"code": "prompt_too_long",
|
||||
"prompt_len": prompt_len,
|
||||
"max": max,
|
||||
})),
|
||||
)
|
||||
.into_response(),
|
||||
InferenceError::InsufficientVram {
|
||||
free_mb,
|
||||
required_mb,
|
||||
} => (
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(json!({
|
||||
"error": format!(
|
||||
"insufficient free VRAM: {free_mb} MiB free, need at least {required_mb} MiB"
|
||||
),
|
||||
"code": "insufficient_vram",
|
||||
"free_mb": free_mb,
|
||||
"required_mb": required_mb,
|
||||
})),
|
||||
)
|
||||
.into_response(),
|
||||
InferenceError::Other(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("{e:#}")})),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
fn mint_response_id() -> String {
|
||||
format!("resp_{:x}", unix_subsec_nanos())
|
||||
}
|
||||
|
||||
fn mint_message_item_id() -> String {
|
||||
format!("msg_{:x}", unix_subsec_nanos())
|
||||
}
|
||||
|
||||
fn unix_now_secs() -> u64 {
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.map(|d| d.as_secs())
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
fn unix_subsec_nanos() -> u64 {
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.map(|d| d.as_nanos() as u64)
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
//! Neuron configuration loaded from neuron.toml.
|
||||
|
||||
use cortex_core::harness::HarnessConfig;
|
||||
use cortex_core::harness::{HarnessConfig, ModelSpec};
|
||||
use figment::{
|
||||
Figment,
|
||||
providers::{Env, Format, Toml},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::Path;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct NeuronConfig {
|
||||
@@ -14,6 +14,31 @@ pub struct NeuronConfig {
|
||||
pub port: u16,
|
||||
#[serde(default)]
|
||||
pub harnesses: Vec<HarnessConfig>,
|
||||
/// Per-harness configuration. Currently only `candle` is recognised.
|
||||
#[serde(default)]
|
||||
pub harness: HarnessSettings,
|
||||
/// Models to auto-load when the neuron service activates. Each entry
|
||||
/// is loaded sequentially before the HTTP listener binds. A failure
|
||||
/// on any single entry logs a warning and proceeds — broken entries
|
||||
/// don't prevent the rest of the fleet from starting.
|
||||
#[serde(default)]
|
||||
pub default_models: Vec<ModelSpec>,
|
||||
}
|
||||
|
||||
/// Settings for individual harness implementations. Each harness owns
|
||||
/// its own sub-table so users only configure the harnesses they enable.
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct HarnessSettings {
|
||||
#[serde(default)]
|
||||
pub candle: CandleHarnessConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct CandleHarnessConfig {
|
||||
/// HuggingFace cache directory for model weights.
|
||||
/// When unset, defers to hf-hub's default (~/.cache/huggingface).
|
||||
#[serde(default)]
|
||||
pub hf_cache: Option<PathBuf>,
|
||||
}
|
||||
|
||||
fn default_port() -> u16 {
|
||||
@@ -35,6 +60,8 @@ impl Default for NeuronConfig {
|
||||
Self {
|
||||
port: 13131,
|
||||
harnesses: vec![],
|
||||
harness: HarnessSettings::default(),
|
||||
default_models: vec![],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
84
crates/neuron/src/cuda/ffi.rs
Normal file
84
crates/neuron/src/cuda/ffi.rs
Normal file
@@ -0,0 +1,84 @@
|
||||
//! FFI declarations for the CUDA kernels in `gdn.cu`.
|
||||
//!
|
||||
//! Subset of `EricLBuehler/mistral.rs::mistralrs-core/src/cuda/ffi.rs`
|
||||
//! covering only the Gated DeltaNet kernels we currently use. Other
|
||||
//! kernels in the upstream file (MoE GEMM, top-k, Mamba selective
|
||||
//! scan, etc.) would land here too as we absorb them.
|
||||
//!
|
||||
//! All function declarations are MIT-licensed from upstream and
|
||||
//! unchanged apart from this header.
|
||||
|
||||
use std::ffi::c_void;
|
||||
|
||||
#[allow(dead_code)]
|
||||
unsafe extern "C" {
|
||||
// GDN (Gated Delta Net) kernels for qwen3_5 / Qwen3-Next.
|
||||
pub(crate) fn gated_delta_rule_recurrence(
|
||||
q: *const f32,
|
||||
k: *const f32,
|
||||
v: *const f32,
|
||||
g: *const f32,
|
||||
beta: *const f32,
|
||||
state: *mut f32,
|
||||
output: *mut f32,
|
||||
bh: i32,
|
||||
seq_len: i32,
|
||||
k_dim: i32,
|
||||
v_dim: i32,
|
||||
stream: i64,
|
||||
);
|
||||
|
||||
/// Chunked GDN recurrence for prefill (processes tokens in BT=64 chunks).
|
||||
pub(crate) fn chunked_gated_delta_rule_recurrence(
|
||||
q: *const f32,
|
||||
k: *const f32,
|
||||
v: *const f32,
|
||||
g: *const f32,
|
||||
beta: *const f32,
|
||||
state: *mut f32,
|
||||
output: *mut f32,
|
||||
bh: i32,
|
||||
seq_len: i32,
|
||||
k_dim: i32,
|
||||
v_dim: i32,
|
||||
stream: i64,
|
||||
);
|
||||
|
||||
pub(crate) fn causal_conv1d_update(
|
||||
x: *const c_void,
|
||||
weight: *const c_void,
|
||||
conv_state: *mut c_void,
|
||||
output: *mut c_void,
|
||||
batch_size: i32,
|
||||
conv_dim: i32,
|
||||
kernel_size: i32,
|
||||
dtype: i32,
|
||||
stream: i64,
|
||||
);
|
||||
|
||||
pub(crate) fn causal_conv1d_full(
|
||||
x: *const c_void,
|
||||
weight: *const c_void,
|
||||
conv_state_out: *mut c_void,
|
||||
output: *mut c_void,
|
||||
batch_size: i32,
|
||||
conv_dim: i32,
|
||||
seq_len: i32,
|
||||
kernel_size: i32,
|
||||
dtype: i32,
|
||||
stream: i64,
|
||||
);
|
||||
|
||||
pub(crate) fn fused_gdn_gating(
|
||||
b: *const c_void,
|
||||
a: *const c_void,
|
||||
a_log: *const f32,
|
||||
dt_bias: *const f32,
|
||||
beta_out: *mut c_void,
|
||||
g_out: *mut c_void,
|
||||
total_elements: i32,
|
||||
num_heads: i32,
|
||||
dtype: i32,
|
||||
stream: i64,
|
||||
);
|
||||
}
|
||||
711
crates/neuron/src/cuda/gdn.cu
Normal file
711
crates/neuron/src/cuda/gdn.cu
Normal file
@@ -0,0 +1,711 @@
|
||||
// Gated DeltaNet CUDA kernels for Qwen3-Next (`model_type = "qwen3_5"`).
|
||||
//
|
||||
// Ported verbatim from `EricLBuehler/mistral.rs` under MIT terms.
|
||||
// Upstream path: mistralrs-core/src/cuda/gdn.cu. Local edits in this
|
||||
// file are limited to this banner; the kernels are unchanged so a
|
||||
// diff against upstream stays minimal.
|
||||
//
|
||||
// Five kernels exposed via `extern "C"` shims at the bottom:
|
||||
// - gated_delta_rule_recurrence (per-token decode)
|
||||
// - chunked_gated_delta_rule_recurrence (BT=64 chunked prefill)
|
||||
// - causal_conv1d_update (single-token conv decode)
|
||||
// - causal_conv1d_full (multi-token conv prefill)
|
||||
// - fused_gdn_gating (beta = sigmoid(b);
|
||||
// g = -exp(A_log) * softplus(a + dt_bias))
|
||||
|
||||
#include "cuda_bf16.h"
|
||||
#include "cuda_fp16.h"
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
// ============================================================================
|
||||
// Kernel 1: gated_delta_rule_recurrence (optimized)
|
||||
//
|
||||
// V-tiled recurrence with compile-time K dimension for register residency.
|
||||
// Grid: (ceil(V/BV), B*H), Block: (BV,). Each thread owns BK registers of
|
||||
// state. Shared memory holds k_buf and q_buf (2*BK floats).
|
||||
//
|
||||
// Optimizations over naive version:
|
||||
// - Template BK -> float s[BK] lives in true registers (1 cycle vs ~30)
|
||||
// - #pragma unroll on all k-loops -> full ILP
|
||||
// - Fused decay+kv_mem pass and fused state_update+output pass
|
||||
// - __fmaf_rn intrinsics for guaranteed fused multiply-add
|
||||
// - BV=64 threads -> 2 warps, 6 blocks/SM on Ampere
|
||||
//
|
||||
// q,k: [BH, S, K] v: [BH, S, V] g,beta: [BH, S]
|
||||
// state: [BH, K, V] (in/out) output: [BH, S, V]
|
||||
// ============================================================================
|
||||
|
||||
// Optimized kernel: BK known at compile time -> registers + full unrolling
|
||||
template <int BK, int BV>
|
||||
__global__ void gated_delta_rule_recurrence_kernel_tiled(
|
||||
const float *__restrict__ q, // [BH, S, K]
|
||||
const float *__restrict__ k, // [BH, S, K]
|
||||
const float *__restrict__ v, // [BH, S, V]
|
||||
const float *__restrict__ g, // [BH, S]
|
||||
const float *__restrict__ beta, // [BH, S]
|
||||
float *__restrict__ state, // [BH, K, V]
|
||||
float *__restrict__ output, // [BH, S, V]
|
||||
int seq_len, int v_dim) {
|
||||
|
||||
const int v_tile = blockIdx.x; // which V-tile
|
||||
const int bh = blockIdx.y; // batch*head index
|
||||
const int tid = threadIdx.x; // thread within tile [0, BV)
|
||||
const int v_idx = v_tile * BV + tid; // global V index
|
||||
|
||||
if (v_idx >= v_dim)
|
||||
return;
|
||||
|
||||
// Pointers for this (batch, head)
|
||||
const float *q_bh = q + bh * seq_len * BK;
|
||||
const float *k_bh = k + bh * seq_len * BK;
|
||||
const float *v_bh = v + bh * seq_len * v_dim;
|
||||
const float *g_bh = g + bh * seq_len;
|
||||
const float *beta_bh = beta + bh * seq_len;
|
||||
float *state_bh = state + bh * BK * v_dim;
|
||||
float *out_bh = output + bh * seq_len * v_dim;
|
||||
|
||||
// Shared memory: k_buf[BK] + q_buf[BK]
|
||||
__shared__ float k_buf[BK];
|
||||
__shared__ float q_buf[BK];
|
||||
|
||||
// Load state column into registers — BK is compile-time, so this is
|
||||
// a true register array (not spilled to local memory)
|
||||
float s[BK];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < BK; j++) {
|
||||
s[j] = state_bh[j * v_dim + v_idx];
|
||||
}
|
||||
|
||||
for (int t = 0; t < seq_len; t++) {
|
||||
// Collaboratively load k_t into shared memory
|
||||
// BK / BV loads per thread (e.g. 128/64 = 2)
|
||||
#pragma unroll
|
||||
for (int j = tid; j < BK; j += BV) {
|
||||
k_buf[j] = k_bh[t * BK + j];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Load scalars for this timestep
|
||||
float decay = expf(g_bh[t]);
|
||||
float beta_t = beta_bh[t];
|
||||
float v_t = v_bh[t * v_dim + v_idx];
|
||||
|
||||
// Fused pass 1: decay state + compute kv_mem
|
||||
float kv_mem = 0.0f;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < BK; j++) {
|
||||
s[j] *= decay;
|
||||
kv_mem = __fmaf_rn(s[j], k_buf[j], kv_mem);
|
||||
}
|
||||
|
||||
// Delta rule
|
||||
float delta = (v_t - kv_mem) * beta_t;
|
||||
|
||||
// Collaboratively load q_t into shared memory
|
||||
#pragma unroll
|
||||
for (int j = tid; j < BK; j += BV) {
|
||||
q_buf[j] = q_bh[t * BK + j];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Fused pass 2: update state + compute output
|
||||
float y_t = 0.0f;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < BK; j++) {
|
||||
s[j] = __fmaf_rn(k_buf[j], delta, s[j]);
|
||||
y_t = __fmaf_rn(s[j], q_buf[j], y_t);
|
||||
}
|
||||
|
||||
out_bh[t * v_dim + v_idx] = y_t;
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Write state back
|
||||
#pragma unroll
|
||||
for (int j = 0; j < BK; j++) {
|
||||
state_bh[j * v_dim + v_idx] = s[j];
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback kernel: runtime k_dim, still V-tiled for occupancy
|
||||
template <int BV, int MAX_K>
|
||||
__global__ void gated_delta_rule_recurrence_kernel_fallback(
|
||||
const float *__restrict__ q, const float *__restrict__ k,
|
||||
const float *__restrict__ v, const float *__restrict__ g,
|
||||
const float *__restrict__ beta, float *__restrict__ state,
|
||||
float *__restrict__ output, int seq_len, int k_dim, int v_dim) {
|
||||
|
||||
const int v_tile = blockIdx.x;
|
||||
const int bh = blockIdx.y;
|
||||
const int tid = threadIdx.x;
|
||||
const int v_idx = v_tile * BV + tid;
|
||||
|
||||
if (v_idx >= v_dim)
|
||||
return;
|
||||
|
||||
const float *q_bh = q + bh * seq_len * k_dim;
|
||||
const float *k_bh = k + bh * seq_len * k_dim;
|
||||
const float *v_bh = v + bh * seq_len * v_dim;
|
||||
const float *g_bh = g + bh * seq_len;
|
||||
const float *beta_bh = beta + bh * seq_len;
|
||||
float *state_bh = state + bh * k_dim * v_dim;
|
||||
float *out_bh = output + bh * seq_len * v_dim;
|
||||
|
||||
extern __shared__ float shared[];
|
||||
float *k_buf = shared;
|
||||
float *q_buf = shared + k_dim;
|
||||
|
||||
float s[MAX_K];
|
||||
for (int j = 0; j < k_dim; j++) {
|
||||
s[j] = state_bh[j * v_dim + v_idx];
|
||||
}
|
||||
|
||||
for (int t = 0; t < seq_len; t++) {
|
||||
for (int j = tid; j < k_dim; j += BV) {
|
||||
k_buf[j] = k_bh[t * k_dim + j];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float decay = expf(g_bh[t]);
|
||||
float beta_t = beta_bh[t];
|
||||
float v_t = v_bh[t * v_dim + v_idx];
|
||||
|
||||
float kv_mem = 0.0f;
|
||||
for (int j = 0; j < k_dim; j++) {
|
||||
s[j] *= decay;
|
||||
kv_mem = __fmaf_rn(s[j], k_buf[j], kv_mem);
|
||||
}
|
||||
|
||||
float delta = (v_t - kv_mem) * beta_t;
|
||||
|
||||
for (int j = tid; j < k_dim; j += BV) {
|
||||
q_buf[j] = q_bh[t * k_dim + j];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float y_t = 0.0f;
|
||||
for (int j = 0; j < k_dim; j++) {
|
||||
s[j] = __fmaf_rn(k_buf[j], delta, s[j]);
|
||||
y_t = __fmaf_rn(s[j], q_buf[j], y_t);
|
||||
}
|
||||
|
||||
out_bh[t * v_dim + v_idx] = y_t;
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
for (int j = 0; j < k_dim; j++) {
|
||||
state_bh[j * v_dim + v_idx] = s[j];
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" void gated_delta_rule_recurrence(const float *q, const float *k,
|
||||
const float *v, const float *g,
|
||||
const float *beta, float *state,
|
||||
float *output, int bh, int seq_len,
|
||||
int k_dim, int v_dim,
|
||||
int64_t stream) {
|
||||
|
||||
const cudaStream_t custream = (cudaStream_t)stream;
|
||||
|
||||
if (k_dim == 128) {
|
||||
// Fast path for Qwen3-Next (k_dim=128)
|
||||
constexpr int BK = 128;
|
||||
constexpr int BV = 64;
|
||||
dim3 grid((v_dim + BV - 1) / BV, bh);
|
||||
dim3 block(BV);
|
||||
gated_delta_rule_recurrence_kernel_tiled<BK, BV>
|
||||
<<<grid, block, 0, custream>>>(q, k, v, g, beta, state, output, seq_len,
|
||||
v_dim);
|
||||
} else if (k_dim == 64) {
|
||||
// Fast path for models with k_dim=64
|
||||
constexpr int BK = 64;
|
||||
constexpr int BV = 64;
|
||||
dim3 grid((v_dim + BV - 1) / BV, bh);
|
||||
dim3 block(BV);
|
||||
gated_delta_rule_recurrence_kernel_tiled<BK, BV>
|
||||
<<<grid, block, 0, custream>>>(q, k, v, g, beta, state, output, seq_len,
|
||||
v_dim);
|
||||
} else {
|
||||
// Fallback for other k_dim values (runtime loop, still V-tiled)
|
||||
constexpr int BV = 64;
|
||||
constexpr int MAX_K = 256;
|
||||
dim3 grid((v_dim + BV - 1) / BV, bh);
|
||||
dim3 block(BV);
|
||||
size_t smem = 2 * k_dim * sizeof(float);
|
||||
gated_delta_rule_recurrence_kernel_fallback<BV, MAX_K>
|
||||
<<<grid, block, smem, custream>>>(q, k, v, g, beta, state, output,
|
||||
seq_len, k_dim, v_dim);
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Kernel 1b: chunked_gated_delta_rule_recurrence (prefill optimization)
|
||||
//
|
||||
// Processes prefill tokens in BT-token chunks instead of one at a time.
|
||||
// Within each chunk: parallel prefix sum of g, cooperative kk_dot computation,
|
||||
// forward substitution (triangular solve), output computation, and state
|
||||
// update.
|
||||
//
|
||||
// Same thread model as Kernel 1: one block per (v_tile, batch*head),
|
||||
// one thread per V-column. Each thread owns BK registers of state.
|
||||
//
|
||||
// Shared memory holds:
|
||||
// k_chunk[BT * BK] -- key vectors for current chunk
|
||||
// kk_dot[BT * BT] -- dot(k[i], k[j]) lower-triangular matrix
|
||||
// gcum[BT] -- cumulative sum of g within chunk
|
||||
// beta_s[BT] -- beta values for chunk
|
||||
// q_buf[BK] -- q vector (loaded one row at a time)
|
||||
//
|
||||
// q,k: [BH, S, K] v: [BH, S, V] g,beta: [BH, S]
|
||||
// state: [BH, K, V] (in/out) output: [BH, S, V]
|
||||
// ============================================================================
|
||||
|
||||
template <int BT, int BK, int BV>
|
||||
__global__ void
|
||||
chunked_gated_delta_rule_kernel(const float *__restrict__ q, // [BH, S, K]
|
||||
const float *__restrict__ k, // [BH, S, K]
|
||||
const float *__restrict__ v, // [BH, S, V]
|
||||
const float *__restrict__ g, // [BH, S]
|
||||
const float *__restrict__ beta, // [BH, S]
|
||||
float *__restrict__ state, // [BH, K, V]
|
||||
float *__restrict__ output, // [BH, S, V]
|
||||
int seq_len, int v_dim) {
|
||||
|
||||
const int v_tile = blockIdx.x;
|
||||
const int bh = blockIdx.y;
|
||||
const int tid = threadIdx.x;
|
||||
const int v_idx = v_tile * BV + tid;
|
||||
|
||||
if (v_idx >= v_dim)
|
||||
return;
|
||||
|
||||
const int num_chunks = (seq_len + BT - 1) / BT;
|
||||
|
||||
// Pointers for this (batch, head)
|
||||
const float *q_bh = q + bh * seq_len * BK;
|
||||
const float *k_bh = k + bh * seq_len * BK;
|
||||
const float *v_bh = v + bh * seq_len * v_dim;
|
||||
const float *g_bh = g + bh * seq_len;
|
||||
const float *beta_bh = beta + bh * seq_len;
|
||||
float *state_bh = state + bh * BK * v_dim;
|
||||
float *out_bh = output + bh * seq_len * v_dim;
|
||||
|
||||
// Dynamic shared memory layout
|
||||
extern __shared__ float smem[];
|
||||
float *k_chunk = smem; // [BT * BK]
|
||||
float *kk_dot = smem + BT * BK; // [BT * BT]
|
||||
float *gcum = smem + BT * BK + BT * BT; // [BT]
|
||||
float *beta_s = gcum + BT; // [BT]
|
||||
float *q_buf = beta_s + BT; // [BK]
|
||||
|
||||
// Load state column into registers
|
||||
float s[BK];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < BK; j++) {
|
||||
s[j] = state_bh[j * v_dim + v_idx];
|
||||
}
|
||||
|
||||
// Per-thread register array for corrected deltas
|
||||
float delta[BT];
|
||||
|
||||
for (int c = 0; c < num_chunks; c++) {
|
||||
const int chunk_start = c * BT;
|
||||
const int chunk_len = min(BT, seq_len - chunk_start);
|
||||
|
||||
// === Phase 1: Cooperative load of k, beta, g into shared memory ===
|
||||
for (int t = 0; t < chunk_len; t++) {
|
||||
for (int j = tid; j < BK; j += BV) {
|
||||
k_chunk[t * BK + j] = k_bh[(chunk_start + t) * BK + j];
|
||||
}
|
||||
}
|
||||
if (tid < chunk_len) {
|
||||
beta_s[tid] = beta_bh[chunk_start + tid];
|
||||
gcum[tid] = g_bh[chunk_start + tid];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// === Phase 1b: Parallel prefix sum of g (Hillis-Steele) ===
|
||||
for (int stride = 1; stride < BT; stride <<= 1) {
|
||||
float prev = 0.0f;
|
||||
if (tid < chunk_len && (int)tid >= stride)
|
||||
prev = gcum[tid - stride];
|
||||
__syncthreads();
|
||||
if (tid < chunk_len && (int)tid >= stride)
|
||||
gcum[tid] += prev;
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// === Phase 2: Compute kk_dot[i][j] = dot(k[i], k[j]) for j < i ===
|
||||
// Only lower-triangular entries needed (strictly lower)
|
||||
for (int idx = tid; idx < chunk_len * chunk_len; idx += BV) {
|
||||
int i = idx / chunk_len;
|
||||
int j = idx % chunk_len;
|
||||
if (j < i) {
|
||||
float dot = 0.0f;
|
||||
for (int d = 0; d < BK; d++) {
|
||||
dot = __fmaf_rn(k_chunk[i * BK + d], k_chunk[j * BK + d], dot);
|
||||
}
|
||||
kk_dot[i * BT + j] = dot;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// === Phase 3: Forward substitution (per V-column, in registers) ===
|
||||
// Computes corrected delta values via triangular solve
|
||||
for (int i = 0; i < chunk_len; i++) {
|
||||
float v_i = v_bh[(chunk_start + i) * v_dim + v_idx];
|
||||
float decay_i = expf(gcum[i]);
|
||||
float beta_i = beta_s[i];
|
||||
|
||||
// Inter-chunk contribution: state @ k[i] with decay
|
||||
float kv_mem = 0.0f;
|
||||
#pragma unroll
|
||||
for (int d = 0; d < BK; d++) {
|
||||
kv_mem = __fmaf_rn(s[d] * decay_i, k_chunk[i * BK + d], kv_mem);
|
||||
}
|
||||
|
||||
float rhs = beta_i * (v_i - kv_mem);
|
||||
|
||||
// Subtract lower-triangular contributions (intra-chunk)
|
||||
for (int j = 0; j < i; j++) {
|
||||
float a_ij = beta_i * kk_dot[i * BT + j] * expf(gcum[i] - gcum[j]);
|
||||
rhs -= a_ij * delta[j];
|
||||
}
|
||||
delta[i] = rhs;
|
||||
}
|
||||
|
||||
// === Phase 4: Output computation (per V-column) ===
|
||||
for (int i = 0; i < chunk_len; i++) {
|
||||
// Cooperatively load q[i] into shared
|
||||
for (int j = tid; j < BK; j += BV) {
|
||||
q_buf[j] = q_bh[(chunk_start + i) * BK + j];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float decay_i = expf(gcum[i]);
|
||||
|
||||
// Inter-chunk contribution: q[i] @ (state * decay)
|
||||
float o_val = 0.0f;
|
||||
#pragma unroll
|
||||
for (int d = 0; d < BK; d++) {
|
||||
o_val = __fmaf_rn(q_buf[d], s[d] * decay_i, o_val);
|
||||
}
|
||||
|
||||
// Intra-chunk contribution: sum_{j<=i} dot(q[i], k[j]) * delta[j] *
|
||||
// exp(gcum[i] - gcum[j])
|
||||
for (int j = 0; j <= i; j++) {
|
||||
float qk_dot = 0.0f;
|
||||
for (int d = 0; d < BK; d++) {
|
||||
qk_dot = __fmaf_rn(q_buf[d], k_chunk[j * BK + d], qk_dot);
|
||||
}
|
||||
o_val += qk_dot * delta[j] * expf(gcum[i] - gcum[j]);
|
||||
}
|
||||
|
||||
out_bh[(chunk_start + i) * v_dim + v_idx] = o_val;
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// === Phase 5: State update for next chunk ===
|
||||
float g_total = gcum[chunk_len - 1];
|
||||
#pragma unroll
|
||||
for (int d = 0; d < BK; d++) {
|
||||
float s_new = s[d] * expf(g_total);
|
||||
for (int t = 0; t < chunk_len; t++) {
|
||||
s_new += k_chunk[t * BK + d] * delta[t] * expf(g_total - gcum[t]);
|
||||
}
|
||||
s[d] = s_new;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Write final state back
|
||||
#pragma unroll
|
||||
for (int j = 0; j < BK; j++) {
|
||||
state_bh[j * v_dim + v_idx] = s[j];
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" void chunked_gated_delta_rule_recurrence(
|
||||
const float *q, const float *k, const float *v, const float *g,
|
||||
const float *beta, float *state, float *output, int bh, int seq_len,
|
||||
int k_dim, int v_dim, int64_t stream) {
|
||||
|
||||
const cudaStream_t custream = (cudaStream_t)stream;
|
||||
|
||||
if (k_dim == 128) {
|
||||
constexpr int BT = 64;
|
||||
constexpr int BK = 128;
|
||||
constexpr int BV = 64;
|
||||
// Shared memory: BT*BK + BT*BT + BT + BT + BK floats
|
||||
size_t smem = (BT * BK + BT * BT + 2 * BT + BK) * sizeof(float);
|
||||
|
||||
// Request extended shared memory
|
||||
auto kernel = chunked_gated_delta_rule_kernel<BT, BK, BV>;
|
||||
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem);
|
||||
|
||||
dim3 grid((v_dim + BV - 1) / BV, bh);
|
||||
dim3 block(BV);
|
||||
kernel<<<grid, block, smem, custream>>>(q, k, v, g, beta, state, output,
|
||||
seq_len, v_dim);
|
||||
} else if (k_dim == 64) {
|
||||
constexpr int BT = 64;
|
||||
constexpr int BK = 64;
|
||||
constexpr int BV = 64;
|
||||
size_t smem = (BT * BK + BT * BT + 2 * BT + BK) * sizeof(float);
|
||||
|
||||
auto kernel = chunked_gated_delta_rule_kernel<BT, BK, BV>;
|
||||
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem);
|
||||
|
||||
dim3 grid((v_dim + BV - 1) / BV, bh);
|
||||
dim3 block(BV);
|
||||
kernel<<<grid, block, smem, custream>>>(q, k, v, g, beta, state, output,
|
||||
seq_len, v_dim);
|
||||
} else {
|
||||
// Fallback: use the sequential kernel for unsupported k_dim
|
||||
gated_delta_rule_recurrence(q, k, v, g, beta, state, output, bh, seq_len,
|
||||
k_dim, v_dim, stream);
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Kernel 2a: causal_conv1d_update (decode path, single step)
|
||||
//
|
||||
// Each thread handles one channel: shift conv_state left by 1,
|
||||
// insert new value, dot product with weight, apply SiLU.
|
||||
//
|
||||
// x: [B, conv_dim, 1] weight: [conv_dim, kernel_size]
|
||||
// conv_state: [B, conv_dim, kernel_size] (in/out)
|
||||
// output: [B, conv_dim, 1]
|
||||
// ============================================================================
|
||||
|
||||
template <typename T>
|
||||
__global__ void causal_conv1d_update_kernel(
|
||||
const T *__restrict__ x, // [B, conv_dim, 1]
|
||||
const T *__restrict__ weight, // [conv_dim, kernel_size]
|
||||
T *__restrict__ conv_state, // [B, conv_dim, kernel_size]
|
||||
T *__restrict__ output, // [B, conv_dim, 1]
|
||||
int batch_size, int conv_dim, int kernel_size) {
|
||||
|
||||
const int ch = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int b = blockIdx.y;
|
||||
|
||||
if (ch >= conv_dim || b >= batch_size)
|
||||
return;
|
||||
|
||||
// Pointer to this batch/channel's conv state
|
||||
T *cs = conv_state + (b * conv_dim + ch) * kernel_size;
|
||||
const T *w = weight + ch * kernel_size;
|
||||
|
||||
// Shift state left by 1
|
||||
for (int i = 0; i < kernel_size - 1; i++) {
|
||||
cs[i] = cs[i + 1];
|
||||
}
|
||||
// Insert new value
|
||||
cs[kernel_size - 1] = x[b * conv_dim + ch];
|
||||
|
||||
// Dot product with weight
|
||||
float acc = 0.0f;
|
||||
for (int i = 0; i < kernel_size; i++) {
|
||||
acc += (float)cs[i] * (float)w[i];
|
||||
}
|
||||
|
||||
// SiLU activation: x * sigmoid(x)
|
||||
float sig = 1.0f / (1.0f + expf(-acc));
|
||||
float result = acc * sig;
|
||||
|
||||
output[b * conv_dim + ch] = (T)result;
|
||||
}
|
||||
|
||||
extern "C" void causal_conv1d_update(const void *x, const void *weight,
|
||||
void *conv_state, void *output,
|
||||
int batch_size, int conv_dim,
|
||||
int kernel_size, int dtype,
|
||||
int64_t stream) {
|
||||
const cudaStream_t custream = (cudaStream_t)stream;
|
||||
dim3 block(256);
|
||||
dim3 grid((conv_dim + 255) / 256, batch_size);
|
||||
|
||||
if (dtype == 0) {
|
||||
// f16
|
||||
causal_conv1d_update_kernel<__half><<<grid, block, 0, custream>>>(
|
||||
(const __half *)x, (const __half *)weight, (__half *)conv_state,
|
||||
(__half *)output, batch_size, conv_dim, kernel_size);
|
||||
} else {
|
||||
// bf16
|
||||
causal_conv1d_update_kernel<__nv_bfloat16><<<grid, block, 0, custream>>>(
|
||||
(const __nv_bfloat16 *)x, (const __nv_bfloat16 *)weight,
|
||||
(__nv_bfloat16 *)conv_state, (__nv_bfloat16 *)output, batch_size,
|
||||
conv_dim, kernel_size);
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Kernel 2b: causal_conv1d_full (prefill path)
|
||||
//
|
||||
// Each thread handles one (channel, position): causal window with
|
||||
// zero-padding, dot product with weight, SiLU.
|
||||
// A second pass writes the conv_state from the last kernel_size positions.
|
||||
//
|
||||
// x: [B, conv_dim, S] weight: [conv_dim, kernel_size]
|
||||
// conv_state_out: [B, conv_dim, kernel_size] output: [B, conv_dim, S]
|
||||
// ============================================================================
|
||||
|
||||
template <typename T>
|
||||
__global__ void causal_conv1d_full_kernel(
|
||||
const T *__restrict__ x, // [B, conv_dim, S]
|
||||
const T *__restrict__ weight, // [conv_dim, kernel_size]
|
||||
T *__restrict__ output, // [B, conv_dim, S]
|
||||
int batch_size, int conv_dim, int seq_len, int kernel_size) {
|
||||
|
||||
const int ch = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int pos = blockIdx.y;
|
||||
const int b = blockIdx.z;
|
||||
|
||||
if (ch >= conv_dim || pos >= seq_len || b >= batch_size)
|
||||
return;
|
||||
|
||||
const T *x_bch = x + (b * conv_dim + ch) * seq_len;
|
||||
const T *w = weight + ch * kernel_size;
|
||||
|
||||
// Causal convolution: sum over kernel_size window ending at pos
|
||||
float acc = 0.0f;
|
||||
for (int i = 0; i < kernel_size; i++) {
|
||||
int src_pos = pos - (kernel_size - 1) + i;
|
||||
float x_val = (src_pos >= 0) ? (float)x_bch[src_pos] : 0.0f;
|
||||
acc += x_val * (float)w[i];
|
||||
}
|
||||
|
||||
// SiLU
|
||||
float sig = 1.0f / (1.0f + expf(-acc));
|
||||
float result = acc * sig;
|
||||
|
||||
output[(b * conv_dim + ch) * seq_len + pos] = (T)result;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void save_conv_state_kernel(
|
||||
const T *__restrict__ x, // [B, conv_dim, S]
|
||||
T *__restrict__ conv_state_out, // [B, conv_dim, kernel_size]
|
||||
int batch_size, int conv_dim, int seq_len, int kernel_size) {
|
||||
|
||||
const int ch = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int b = blockIdx.y;
|
||||
|
||||
if (ch >= conv_dim || b >= batch_size)
|
||||
return;
|
||||
|
||||
const T *x_bch = x + (b * conv_dim + ch) * seq_len;
|
||||
T *cs = conv_state_out + (b * conv_dim + ch) * kernel_size;
|
||||
|
||||
// Save last kernel_size positions (zero-pad if seq_len < kernel_size)
|
||||
int pad = kernel_size - seq_len;
|
||||
for (int i = 0; i < kernel_size; i++) {
|
||||
if (i < pad) {
|
||||
cs[i] = (T)0.0f;
|
||||
} else {
|
||||
cs[i] = x_bch[seq_len - kernel_size + i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" void causal_conv1d_full(const void *x, const void *weight,
|
||||
void *conv_state_out, void *output,
|
||||
int batch_size, int conv_dim, int seq_len,
|
||||
int kernel_size, int dtype, int64_t stream) {
|
||||
const cudaStream_t custream = (cudaStream_t)stream;
|
||||
|
||||
// Main convolution kernel
|
||||
dim3 block(256);
|
||||
dim3 grid((conv_dim + 255) / 256, seq_len, batch_size);
|
||||
|
||||
if (dtype == 0) {
|
||||
causal_conv1d_full_kernel<__half><<<grid, block, 0, custream>>>(
|
||||
(const __half *)x, (const __half *)weight, (__half *)output, batch_size,
|
||||
conv_dim, seq_len, kernel_size);
|
||||
// Save conv state
|
||||
dim3 grid2((conv_dim + 255) / 256, batch_size);
|
||||
save_conv_state_kernel<__half><<<grid2, block, 0, custream>>>(
|
||||
(const __half *)x, (__half *)conv_state_out, batch_size, conv_dim,
|
||||
seq_len, kernel_size);
|
||||
} else {
|
||||
causal_conv1d_full_kernel<__nv_bfloat16><<<grid, block, 0, custream>>>(
|
||||
(const __nv_bfloat16 *)x, (const __nv_bfloat16 *)weight,
|
||||
(__nv_bfloat16 *)output, batch_size, conv_dim, seq_len, kernel_size);
|
||||
dim3 grid2((conv_dim + 255) / 256, batch_size);
|
||||
save_conv_state_kernel<__nv_bfloat16><<<grid2, block, 0, custream>>>(
|
||||
(const __nv_bfloat16 *)x, (__nv_bfloat16 *)conv_state_out, batch_size,
|
||||
conv_dim, seq_len, kernel_size);
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Kernel 3: fused_gdn_gating
|
||||
//
|
||||
// Fuses: beta = sigmoid(b), g = -exp(a_log) * softplus(a + dt_bias)
|
||||
// a_log and dt_bias are per-head (broadcast over batch*seq).
|
||||
//
|
||||
// b, a: [total] a_log, dt_bias: [num_heads]
|
||||
// beta_out, g_out: [total]
|
||||
// ============================================================================
|
||||
|
||||
template <typename T>
|
||||
__global__ void
|
||||
fused_gdn_gating_kernel(const T *__restrict__ b, // [total]
|
||||
const T *__restrict__ a, // [total]
|
||||
const float *__restrict__ a_log, // [num_heads]
|
||||
const float *__restrict__ dt_bias, // [num_heads]
|
||||
T *__restrict__ beta_out, // [total]
|
||||
T *__restrict__ g_out, // [total]
|
||||
int total_elements, int num_heads) {
|
||||
|
||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx >= total_elements)
|
||||
return;
|
||||
|
||||
// Head index: elements are laid out as [..., num_heads]
|
||||
int head_idx = idx % num_heads;
|
||||
|
||||
// beta = sigmoid(b)
|
||||
float b_val = (float)b[idx];
|
||||
float beta = 1.0f / (1.0f + expf(-b_val));
|
||||
|
||||
// g = -exp(a_log) * softplus(a + dt_bias)
|
||||
float a_val = (float)a[idx];
|
||||
float a_log_val = a_log[head_idx];
|
||||
float dt_bias_val = dt_bias[head_idx];
|
||||
|
||||
float sp_input = a_val + dt_bias_val;
|
||||
float softplus_val = logf(1.0f + expf(sp_input));
|
||||
float g_val = -expf(a_log_val) * softplus_val;
|
||||
|
||||
beta_out[idx] = (T)beta;
|
||||
g_out[idx] = (T)g_val;
|
||||
}
|
||||
|
||||
extern "C" void fused_gdn_gating(const void *b, const void *a,
|
||||
const float *a_log, const float *dt_bias,
|
||||
void *beta_out, void *g_out,
|
||||
int total_elements, int num_heads, int dtype,
|
||||
int64_t stream) {
|
||||
const cudaStream_t custream = (cudaStream_t)stream;
|
||||
dim3 block(256);
|
||||
dim3 grid((total_elements + 255) / 256);
|
||||
|
||||
if (dtype == 0) {
|
||||
fused_gdn_gating_kernel<__half><<<grid, block, 0, custream>>>(
|
||||
(const __half *)b, (const __half *)a, a_log, dt_bias,
|
||||
(__half *)beta_out, (__half *)g_out, total_elements, num_heads);
|
||||
} else {
|
||||
fused_gdn_gating_kernel<__nv_bfloat16><<<grid, block, 0, custream>>>(
|
||||
(const __nv_bfloat16 *)b, (const __nv_bfloat16 *)a, a_log, dt_bias,
|
||||
(__nv_bfloat16 *)beta_out, (__nv_bfloat16 *)g_out, total_elements,
|
||||
num_heads);
|
||||
}
|
||||
}
|
||||
486
crates/neuron/src/cuda/gdn.rs
Normal file
486
crates/neuron/src/cuda/gdn.rs
Normal file
@@ -0,0 +1,486 @@
|
||||
//! Rust wrappers around the Gated DeltaNet CUDA kernels in `gdn.cu`.
|
||||
//!
|
||||
//! Ported verbatim from `EricLBuehler/mistral.rs` under MIT terms.
|
||||
//! Upstream path: `mistralrs-core/src/cuda/gdn.rs`. The only edits in
|
||||
//! this file are this header comment — the FFI path module name is
|
||||
//! `crate::cuda::ffi`, identical to upstream's layout.
|
||||
|
||||
#![allow(clippy::cast_possible_truncation)]
|
||||
|
||||
use candle_core::{Result, Tensor};
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
use candle_core::DType;
|
||||
|
||||
/// CUDA-accelerated gated delta rule recurrence.
|
||||
///
|
||||
/// Inputs (all contiguous, f32):
|
||||
/// q, k: [BH, S, K] v: [BH, S, V] g, beta: [BH, S]
|
||||
/// state: [BH, K, V] (mutated in place)
|
||||
///
|
||||
/// Returns: output [BH, S, V]
|
||||
#[cfg(feature = "cuda")]
|
||||
pub fn gated_delta_rule_recurrence_cuda(
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
v: &Tensor,
|
||||
g: &Tensor,
|
||||
beta: &Tensor,
|
||||
state: &mut Tensor,
|
||||
) -> Result<Tensor> {
|
||||
use candle::cuda_backend::cudarc::driver::DevicePtr;
|
||||
use candle_core as candle;
|
||||
|
||||
let (bh, seq_len, k_dim) = q.dims3()?;
|
||||
let v_dim = v.dim(2)?;
|
||||
|
||||
let dev = q.device().as_cuda_device()?;
|
||||
|
||||
let (q_s, q_l) = q.storage_and_layout();
|
||||
let q_s = match &*q_s {
|
||||
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||
_ => candle::bail!("q must be a cuda tensor"),
|
||||
};
|
||||
let q_offset = q_l.start_offset();
|
||||
|
||||
let (k_s, k_l) = k.storage_and_layout();
|
||||
let k_s = match &*k_s {
|
||||
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||
_ => candle::bail!("k must be a cuda tensor"),
|
||||
};
|
||||
let k_offset = k_l.start_offset();
|
||||
|
||||
let (v_s, v_l) = v.storage_and_layout();
|
||||
let v_s = match &*v_s {
|
||||
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||
_ => candle::bail!("v must be a cuda tensor"),
|
||||
};
|
||||
let v_offset = v_l.start_offset();
|
||||
|
||||
let (g_s, g_l) = g.storage_and_layout();
|
||||
let g_s = match &*g_s {
|
||||
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||
_ => candle::bail!("g must be a cuda tensor"),
|
||||
};
|
||||
let g_offset = g_l.start_offset();
|
||||
|
||||
let (beta_s, beta_l) = beta.storage_and_layout();
|
||||
let beta_s = match &*beta_s {
|
||||
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||
_ => candle::bail!("beta must be a cuda tensor"),
|
||||
};
|
||||
let beta_offset = beta_l.start_offset();
|
||||
|
||||
let (state_s, state_l) = state.storage_and_layout();
|
||||
let state_s = match &*state_s {
|
||||
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||
_ => candle::bail!("state must be a cuda tensor"),
|
||||
};
|
||||
let state_offset = state_l.start_offset();
|
||||
|
||||
let output_buf = unsafe { dev.alloc::<f32>(bh * seq_len * v_dim) }?;
|
||||
|
||||
let stream = dev.cuda_stream().cu_stream() as i64;
|
||||
|
||||
unsafe {
|
||||
crate::cuda::ffi::gated_delta_rule_recurrence(
|
||||
q_s.slice(q_offset..).device_ptr(q_s.stream()).0 as *const f32,
|
||||
k_s.slice(k_offset..).device_ptr(k_s.stream()).0 as *const f32,
|
||||
v_s.slice(v_offset..).device_ptr(v_s.stream()).0 as *const f32,
|
||||
g_s.slice(g_offset..).device_ptr(g_s.stream()).0 as *const f32,
|
||||
beta_s.slice(beta_offset..).device_ptr(beta_s.stream()).0 as *const f32,
|
||||
state_s.slice(state_offset..).device_ptr(state_s.stream()).0 as *mut f32,
|
||||
output_buf.device_ptr(output_buf.stream()).0 as *mut f32,
|
||||
bh as i32,
|
||||
seq_len as i32,
|
||||
k_dim as i32,
|
||||
v_dim as i32,
|
||||
stream,
|
||||
);
|
||||
}
|
||||
|
||||
// The kernel wrote state in-place via the raw pointer; rewrap
|
||||
// (state tensor's underlying CudaSlice was modified directly)
|
||||
|
||||
let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone());
|
||||
Ok(Tensor::from((
|
||||
candle::Storage::Cuda(output_storage),
|
||||
(bh, seq_len, v_dim),
|
||||
)))
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
#[allow(unused)]
|
||||
pub fn gated_delta_rule_recurrence_cuda(
|
||||
_q: &Tensor,
|
||||
_k: &Tensor,
|
||||
_v: &Tensor,
|
||||
_g: &Tensor,
|
||||
_beta: &Tensor,
|
||||
_state: &mut Tensor,
|
||||
) -> Result<Tensor> {
|
||||
candle_core::bail!("gated_delta_rule_recurrence_cuda requires the cuda feature")
|
||||
}
|
||||
|
||||
/// CUDA-accelerated chunked gated delta rule recurrence (prefill optimization).
|
||||
///
|
||||
/// Processes prefill tokens in 64-token chunks instead of one at a time.
|
||||
/// Same interface as `gated_delta_rule_recurrence_cuda`.
|
||||
///
|
||||
/// Inputs (all contiguous, f32):
|
||||
/// q, k: [BH, S, K] v: [BH, S, V] g, beta: [BH, S]
|
||||
/// state: [BH, K, V] (mutated in place)
|
||||
///
|
||||
/// Returns: output [BH, S, V]
|
||||
#[cfg(feature = "cuda")]
|
||||
pub fn chunked_gated_delta_rule_recurrence_cuda(
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
v: &Tensor,
|
||||
g: &Tensor,
|
||||
beta: &Tensor,
|
||||
state: &mut Tensor,
|
||||
) -> Result<Tensor> {
|
||||
use candle::cuda_backend::cudarc::driver::DevicePtr;
|
||||
use candle_core as candle;
|
||||
|
||||
let (bh, seq_len, k_dim) = q.dims3()?;
|
||||
let v_dim = v.dim(2)?;
|
||||
|
||||
let dev = q.device().as_cuda_device()?;
|
||||
|
||||
let (q_s, q_l) = q.storage_and_layout();
|
||||
let q_s = match &*q_s {
|
||||
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||
_ => candle::bail!("q must be a cuda tensor"),
|
||||
};
|
||||
let q_offset = q_l.start_offset();
|
||||
|
||||
let (k_s, k_l) = k.storage_and_layout();
|
||||
let k_s = match &*k_s {
|
||||
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||
_ => candle::bail!("k must be a cuda tensor"),
|
||||
};
|
||||
let k_offset = k_l.start_offset();
|
||||
|
||||
let (v_s, v_l) = v.storage_and_layout();
|
||||
let v_s = match &*v_s {
|
||||
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||
_ => candle::bail!("v must be a cuda tensor"),
|
||||
};
|
||||
let v_offset = v_l.start_offset();
|
||||
|
||||
let (g_s, g_l) = g.storage_and_layout();
|
||||
let g_s = match &*g_s {
|
||||
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||
_ => candle::bail!("g must be a cuda tensor"),
|
||||
};
|
||||
let g_offset = g_l.start_offset();
|
||||
|
||||
let (beta_s, beta_l) = beta.storage_and_layout();
|
||||
let beta_s = match &*beta_s {
|
||||
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||
_ => candle::bail!("beta must be a cuda tensor"),
|
||||
};
|
||||
let beta_offset = beta_l.start_offset();
|
||||
|
||||
let (state_s, state_l) = state.storage_and_layout();
|
||||
let state_s = match &*state_s {
|
||||
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||
_ => candle::bail!("state must be a cuda tensor"),
|
||||
};
|
||||
let state_offset = state_l.start_offset();
|
||||
|
||||
let output_buf = unsafe { dev.alloc::<f32>(bh * seq_len * v_dim) }?;
|
||||
|
||||
let stream = dev.cuda_stream().cu_stream() as i64;
|
||||
|
||||
unsafe {
|
||||
crate::cuda::ffi::chunked_gated_delta_rule_recurrence(
|
||||
q_s.slice(q_offset..).device_ptr(q_s.stream()).0 as *const f32,
|
||||
k_s.slice(k_offset..).device_ptr(k_s.stream()).0 as *const f32,
|
||||
v_s.slice(v_offset..).device_ptr(v_s.stream()).0 as *const f32,
|
||||
g_s.slice(g_offset..).device_ptr(g_s.stream()).0 as *const f32,
|
||||
beta_s.slice(beta_offset..).device_ptr(beta_s.stream()).0 as *const f32,
|
||||
state_s.slice(state_offset..).device_ptr(state_s.stream()).0 as *mut f32,
|
||||
output_buf.device_ptr(output_buf.stream()).0 as *mut f32,
|
||||
bh as i32,
|
||||
seq_len as i32,
|
||||
k_dim as i32,
|
||||
v_dim as i32,
|
||||
stream,
|
||||
);
|
||||
}
|
||||
|
||||
let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone());
|
||||
Ok(Tensor::from((
|
||||
candle::Storage::Cuda(output_storage),
|
||||
(bh, seq_len, v_dim),
|
||||
)))
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
#[allow(unused)]
|
||||
pub fn chunked_gated_delta_rule_recurrence_cuda(
|
||||
_q: &Tensor,
|
||||
_k: &Tensor,
|
||||
_v: &Tensor,
|
||||
_g: &Tensor,
|
||||
_beta: &Tensor,
|
||||
_state: &mut Tensor,
|
||||
) -> Result<Tensor> {
|
||||
candle_core::bail!("chunked_gated_delta_rule_recurrence_cuda requires the cuda feature")
|
||||
}
|
||||
|
||||
/// CUDA-accelerated causal conv1d (both update and full paths).
|
||||
///
|
||||
/// For update (is_update=true):
|
||||
/// x: [B, conv_dim, 1] weight: [conv_dim, kernel_size]
|
||||
/// conv_state: [B, conv_dim, kernel_size] (mutated in place for update)
|
||||
/// Returns: (output [B, conv_dim, 1], updated conv_state)
|
||||
///
|
||||
/// For full (is_update=false):
|
||||
/// x: [B, conv_dim, S] weight: [conv_dim, kernel_size]
|
||||
/// Returns: (output [B, conv_dim, S], new conv_state [B, conv_dim, kernel_size])
|
||||
#[cfg(feature = "cuda")]
|
||||
pub fn causal_conv1d_cuda(
|
||||
x: &Tensor,
|
||||
weight: &Tensor,
|
||||
conv_state: &Tensor,
|
||||
kernel_size: usize,
|
||||
is_update: bool,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
use candle::cuda_backend::cudarc::driver::DevicePtr;
|
||||
use candle_core as candle;
|
||||
use core::ffi::c_void;
|
||||
fn cuda_fwd<
|
||||
T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
|
||||
>(
|
||||
x: &Tensor,
|
||||
weight: &Tensor,
|
||||
conv_state: &Tensor,
|
||||
kernel_size: usize,
|
||||
is_update: bool,
|
||||
dtype_code: i32,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let dev = x.device().as_cuda_device()?;
|
||||
let (batch_size, conv_dim, seq_len) = x.dims3()?;
|
||||
|
||||
let (x_s, x_l) = x.storage_and_layout();
|
||||
let x_s = match &*x_s {
|
||||
candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
|
||||
_ => candle::bail!("x must be a cuda tensor"),
|
||||
};
|
||||
let x_offset = x_l.start_offset();
|
||||
|
||||
let (w_s, w_l) = weight.storage_and_layout();
|
||||
let w_s = match &*w_s {
|
||||
candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
|
||||
_ => candle::bail!("weight must be a cuda tensor"),
|
||||
};
|
||||
let w_offset = w_l.start_offset();
|
||||
|
||||
let stream = dev.cuda_stream().cu_stream() as i64;
|
||||
|
||||
if is_update {
|
||||
// Clone conv_state so the kernel can mutate it in place
|
||||
let conv_state_new = conv_state.clone();
|
||||
|
||||
let output_buf = unsafe { dev.alloc::<T>(batch_size * conv_dim) }?;
|
||||
|
||||
// Scope the borrow of conv_state_new so we can move it later
|
||||
{
|
||||
let (cs_s, cs_l) = conv_state_new.storage_and_layout();
|
||||
let cs_s = match &*cs_s {
|
||||
candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
|
||||
_ => candle::bail!("conv_state must be a cuda tensor"),
|
||||
};
|
||||
let cs_offset = cs_l.start_offset();
|
||||
|
||||
unsafe {
|
||||
crate::cuda::ffi::causal_conv1d_update(
|
||||
x_s.slice(x_offset..).device_ptr(x_s.stream()).0 as *const c_void,
|
||||
w_s.slice(w_offset..).device_ptr(w_s.stream()).0 as *const c_void,
|
||||
cs_s.slice(cs_offset..).device_ptr(cs_s.stream()).0 as *mut c_void,
|
||||
output_buf.device_ptr(output_buf.stream()).0 as *mut c_void,
|
||||
batch_size as i32,
|
||||
conv_dim as i32,
|
||||
kernel_size as i32,
|
||||
dtype_code,
|
||||
stream,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone());
|
||||
let output = Tensor::from((
|
||||
candle::Storage::Cuda(output_storage),
|
||||
(batch_size, conv_dim, 1usize),
|
||||
));
|
||||
|
||||
Ok((output, conv_state_new))
|
||||
} else {
|
||||
// Full path: allocate new conv_state and output
|
||||
let output_buf = unsafe { dev.alloc::<T>(batch_size * conv_dim * seq_len) }?;
|
||||
let cs_buf = unsafe { dev.alloc::<T>(batch_size * conv_dim * kernel_size) }?;
|
||||
|
||||
unsafe {
|
||||
crate::cuda::ffi::causal_conv1d_full(
|
||||
x_s.slice(x_offset..).device_ptr(x_s.stream()).0 as *const c_void,
|
||||
w_s.slice(w_offset..).device_ptr(w_s.stream()).0 as *const c_void,
|
||||
cs_buf.device_ptr(cs_buf.stream()).0 as *mut c_void,
|
||||
output_buf.device_ptr(output_buf.stream()).0 as *mut c_void,
|
||||
batch_size as i32,
|
||||
conv_dim as i32,
|
||||
seq_len as i32,
|
||||
kernel_size as i32,
|
||||
dtype_code,
|
||||
stream,
|
||||
);
|
||||
}
|
||||
|
||||
let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone());
|
||||
let output = Tensor::from((
|
||||
candle::Storage::Cuda(output_storage),
|
||||
(batch_size, conv_dim, seq_len),
|
||||
));
|
||||
|
||||
let cs_storage = candle::CudaStorage::wrap_cuda_slice(cs_buf, dev.clone());
|
||||
let new_conv_state = Tensor::from((
|
||||
candle::Storage::Cuda(cs_storage),
|
||||
(batch_size, conv_dim, kernel_size),
|
||||
));
|
||||
|
||||
Ok((output, new_conv_state))
|
||||
}
|
||||
}
|
||||
|
||||
match x.dtype() {
|
||||
DType::F16 => cuda_fwd::<half::f16>(x, weight, conv_state, kernel_size, is_update, 0),
|
||||
DType::BF16 => cuda_fwd::<half::bf16>(x, weight, conv_state, kernel_size, is_update, 1),
|
||||
other => candle_core::bail!("causal_conv1d_cuda only supports f16/bf16, got {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
#[allow(unused)]
|
||||
pub fn causal_conv1d_cuda(
|
||||
_x: &Tensor,
|
||||
_weight: &Tensor,
|
||||
_conv_state: &Tensor,
|
||||
_kernel_size: usize,
|
||||
_is_update: bool,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
candle_core::bail!("causal_conv1d_cuda requires the cuda feature")
|
||||
}
|
||||
|
||||
/// CUDA-accelerated fused GDN gating computation.
|
||||
///
|
||||
/// Computes: beta = sigmoid(b), g = -exp(a_log) * softplus(a + dt_bias)
|
||||
///
|
||||
/// b, a: [total_elements] in f16/bf16
|
||||
/// a_log, dt_bias: [num_heads] in f32
|
||||
///
|
||||
/// Returns: (beta, g) in original dtype
|
||||
#[cfg(feature = "cuda")]
|
||||
pub fn fused_gdn_gating_cuda(
|
||||
b: &Tensor,
|
||||
a: &Tensor,
|
||||
a_log: &Tensor,
|
||||
dt_bias: &Tensor,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
use candle::cuda_backend::cudarc::driver::DevicePtr;
|
||||
use candle_core as candle;
|
||||
use core::ffi::c_void;
|
||||
|
||||
fn cuda_fwd<
|
||||
T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
|
||||
>(
|
||||
b: &Tensor,
|
||||
a: &Tensor,
|
||||
a_log: &Tensor,
|
||||
dt_bias: &Tensor,
|
||||
dtype_code: i32,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let total_elements = b.elem_count();
|
||||
let num_heads = a_log.elem_count();
|
||||
let shape = b.shape().clone();
|
||||
let dev = b.device().as_cuda_device()?;
|
||||
|
||||
let (b_s, b_l) = b.storage_and_layout();
|
||||
let b_s = match &*b_s {
|
||||
candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
|
||||
_ => candle::bail!("b must be a cuda tensor"),
|
||||
};
|
||||
let b_offset = b_l.start_offset();
|
||||
|
||||
let (a_s, a_l) = a.storage_and_layout();
|
||||
let a_s = match &*a_s {
|
||||
candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
|
||||
_ => candle::bail!("a must be a cuda tensor"),
|
||||
};
|
||||
let a_offset = a_l.start_offset();
|
||||
|
||||
let (alog_s, alog_l) = a_log.storage_and_layout();
|
||||
let alog_s = match &*alog_s {
|
||||
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||
_ => candle::bail!("a_log must be a cuda tensor"),
|
||||
};
|
||||
let alog_offset = alog_l.start_offset();
|
||||
|
||||
let (dtb_s, dtb_l) = dt_bias.storage_and_layout();
|
||||
let dtb_s = match &*dtb_s {
|
||||
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||
_ => candle::bail!("dt_bias must be a cuda tensor"),
|
||||
};
|
||||
let dtb_offset = dtb_l.start_offset();
|
||||
|
||||
let beta_buf = unsafe { dev.alloc::<T>(total_elements) }?;
|
||||
let g_buf = unsafe { dev.alloc::<T>(total_elements) }?;
|
||||
|
||||
let stream = dev.cuda_stream().cu_stream() as i64;
|
||||
|
||||
unsafe {
|
||||
crate::cuda::ffi::fused_gdn_gating(
|
||||
b_s.slice(b_offset..).device_ptr(b_s.stream()).0 as *const c_void,
|
||||
a_s.slice(a_offset..).device_ptr(a_s.stream()).0 as *const c_void,
|
||||
alog_s.slice(alog_offset..).device_ptr(alog_s.stream()).0 as *const f32,
|
||||
dtb_s.slice(dtb_offset..).device_ptr(dtb_s.stream()).0 as *const f32,
|
||||
beta_buf.device_ptr(beta_buf.stream()).0 as *mut c_void,
|
||||
g_buf.device_ptr(g_buf.stream()).0 as *mut c_void,
|
||||
total_elements as i32,
|
||||
num_heads as i32,
|
||||
dtype_code,
|
||||
stream,
|
||||
);
|
||||
}
|
||||
|
||||
let beta_storage = candle::CudaStorage::wrap_cuda_slice(beta_buf, dev.clone());
|
||||
let beta = Tensor::from((candle::Storage::Cuda(beta_storage), shape.clone()));
|
||||
|
||||
let g_storage = candle::CudaStorage::wrap_cuda_slice(g_buf, dev.clone());
|
||||
let g = Tensor::from((candle::Storage::Cuda(g_storage), shape));
|
||||
|
||||
Ok((beta, g))
|
||||
}
|
||||
|
||||
match b.dtype() {
|
||||
DType::F16 => cuda_fwd::<half::f16>(b, a, a_log, dt_bias, 0),
|
||||
DType::BF16 => cuda_fwd::<half::bf16>(b, a, a_log, dt_bias, 1),
|
||||
other => candle_core::bail!(
|
||||
"fused_gdn_gating_cuda only supports f16/bf16, got {:?}",
|
||||
other
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
#[allow(unused)]
|
||||
pub fn fused_gdn_gating_cuda(
|
||||
_b: &Tensor,
|
||||
_a: &Tensor,
|
||||
_a_log: &Tensor,
|
||||
_dt_bias: &Tensor,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
candle_core::bail!("fused_gdn_gating_cuda requires the cuda feature")
|
||||
}
|
||||
15
crates/neuron/src/cuda/mod.rs
Normal file
15
crates/neuron/src/cuda/mod.rs
Normal file
@@ -0,0 +1,15 @@
|
||||
//! CUDA kernels and their Rust wrappers.
|
||||
//!
|
||||
//! Currently scoped to what we need for Qwen3-Next (`qwen3_5`)
|
||||
//! inference performance — the Gated DeltaNet kernels ported from
|
||||
//! `EricLBuehler/mistral.rs` (MIT). Each kernel lives in a `.cu`
|
||||
//! file alongside this module; `build.rs` compiles them all into a
|
||||
//! static lib via `cudaforge` and links it under the `cuda` feature.
|
||||
//!
|
||||
//! When we absorb more upstream kernels (MoE GEMM, top-k, Mamba SSM,
|
||||
//! etc.) they land here in their own `.cu` + `.rs` pairs.
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
pub mod ffi;
|
||||
#[cfg(feature = "cuda")]
|
||||
pub mod gdn;
|
||||
23
crates/neuron/src/harness/arch/mod.rs
Normal file
23
crates/neuron/src/harness/arch/mod.rs
Normal file
@@ -0,0 +1,23 @@
|
||||
//! Custom architecture implementations.
|
||||
//!
|
||||
//! When candle-transformers ships a model family unchanged
|
||||
//! (`models::llama`, `models::qwen3`, `models::qwen3_moe`, etc.), the
|
||||
//! handler in `harness/candle.rs` just wraps the upstream type in a
|
||||
//! `ModelArch` variant.
|
||||
//!
|
||||
//! When candle has nothing for the architecture and we have to write
|
||||
//! it from scratch — Qwen3-Next / Qwen3.6 (`qwen3_5`) being the
|
||||
//! motivating example — the implementation lands here, one file per
|
||||
//! architecture.
|
||||
//!
|
||||
//! Each architecture module is expected to expose:
|
||||
//! - A `Config` type deserialised from the model's `config.json`
|
||||
//! (some architectures nest the real hyperparams under `text_config`,
|
||||
//! in which case the module owns the unwrapping).
|
||||
//! - A `ForCausalLM` struct with `new`, `forward(&mut self, x, offset)
|
||||
//! -> Result<Tensor>`, and `clear_kv_cache(&mut self)`.
|
||||
//!
|
||||
//! TP-aware analogues live in `harness/tp/tp_<family>.rs` and follow
|
||||
//! the pattern set by `tp_qwen3.rs`.
|
||||
|
||||
pub mod qwen3_5;
|
||||
117
crates/neuron/src/harness/arch/qwen3_5/decoder.rs
Normal file
117
crates/neuron/src/harness/arch/qwen3_5/decoder.rs
Normal file
@@ -0,0 +1,117 @@
|
||||
//! Qwen3-Next decoder layer.
|
||||
//!
|
||||
//! Standard pre-norm transformer block (LN → attention → residual →
|
||||
//! LN → MLP → residual) where the attention slot dispatches on the
|
||||
//! per-layer `layer_types[i]` value in the config:
|
||||
//!
|
||||
//! - `"full_attention"` → [`Qwen3_5Attention`] (GQA causal + output
|
||||
//! gate + RoPE + KV cache).
|
||||
//! - `"linear_attention"` → [`GatedDeltaNet`] (recurrent delta rule +
|
||||
//! causal conv + per-head state).
|
||||
//!
|
||||
//! In Qwen3.6-27B every 4th layer is full_attention; the rest are
|
||||
//! linear_attention. `full_attention_interval` in the config is a
|
||||
//! hint; `layer_types` is authoritative.
|
||||
|
||||
use anyhow::Result;
|
||||
use candle_core::{Module, Tensor};
|
||||
use candle_nn::var_builder::ShardedVarBuilder;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::TextConfig;
|
||||
use super::full_attn::Qwen3_5Attention;
|
||||
use super::linear_attn::GatedDeltaNet;
|
||||
use super::mlp::Qwen3_5MLP;
|
||||
use super::rmsnorm::Qwen3_5RmsNorm;
|
||||
use super::rope::RotaryEmbedding;
|
||||
|
||||
/// One of the two attention flavours sitting in a decoder layer's
|
||||
/// attention slot. Full-attention layers need the rotary table and
|
||||
/// take an attention mask; linear-attention layers carry their own
|
||||
/// recurrent state and ignore the mask.
|
||||
enum AttentionKind {
|
||||
Full(Qwen3_5Attention),
|
||||
Linear(GatedDeltaNet),
|
||||
}
|
||||
|
||||
pub struct Qwen3_5DecoderLayer {
|
||||
input_layernorm: Qwen3_5RmsNorm,
|
||||
post_attention_layernorm: Qwen3_5RmsNorm,
|
||||
mlp: Qwen3_5MLP,
|
||||
attention: AttentionKind,
|
||||
}
|
||||
|
||||
impl Qwen3_5DecoderLayer {
|
||||
pub fn load(
|
||||
cfg: &TextConfig,
|
||||
rotary: Arc<RotaryEmbedding>,
|
||||
layer_idx: usize,
|
||||
vb: &ShardedVarBuilder,
|
||||
) -> Result<Self> {
|
||||
let layer_type = cfg
|
||||
.layer_types
|
||||
.get(layer_idx)
|
||||
.map(String::as_str)
|
||||
.ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"layer_types[{layer_idx}] missing (have {} entries)",
|
||||
cfg.layer_types.len()
|
||||
)
|
||||
})?;
|
||||
|
||||
let attention = match layer_type {
|
||||
"full_attention" => {
|
||||
AttentionKind::Full(Qwen3_5Attention::load(cfg, rotary, &vb.pp("self_attn"))?)
|
||||
}
|
||||
"linear_attention" => {
|
||||
AttentionKind::Linear(GatedDeltaNet::load(cfg, &vb.pp("linear_attn"))?)
|
||||
}
|
||||
other => anyhow::bail!(
|
||||
"unknown layer_type '{other}' for layer {layer_idx} (expected \
|
||||
'full_attention' or 'linear_attention')"
|
||||
),
|
||||
};
|
||||
|
||||
let mlp = Qwen3_5MLP::load(cfg, &vb.pp("mlp"))?;
|
||||
let input_layernorm =
|
||||
Qwen3_5RmsNorm::load(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?;
|
||||
let post_attention_layernorm = Qwen3_5RmsNorm::load(
|
||||
&vb.pp("post_attention_layernorm"),
|
||||
cfg.hidden_size,
|
||||
cfg.rms_norm_eps,
|
||||
)?;
|
||||
|
||||
Ok(Self {
|
||||
input_layernorm,
|
||||
post_attention_layernorm,
|
||||
mlp,
|
||||
attention,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&mut self,
|
||||
x: &Tensor,
|
||||
attn_mask: Option<&Tensor>,
|
||||
offset: usize,
|
||||
) -> candle_core::Result<Tensor> {
|
||||
let h = self.input_layernorm.forward(x)?;
|
||||
let attn_out = match &mut self.attention {
|
||||
AttentionKind::Full(attn) => attn.forward(&h, attn_mask, offset)?,
|
||||
// Linear attention ignores attn_mask + offset; its causal
|
||||
// structure is baked into the recurrent state lifecycle.
|
||||
AttentionKind::Linear(net) => net.forward(&h)?,
|
||||
};
|
||||
let x = (x + attn_out)?;
|
||||
let h2 = self.post_attention_layernorm.forward(&x)?;
|
||||
let h2 = self.mlp.forward(&h2)?;
|
||||
x + h2
|
||||
}
|
||||
|
||||
pub fn clear_kv_cache(&mut self) {
|
||||
match &mut self.attention {
|
||||
AttentionKind::Full(attn) => attn.clear_kv_cache(),
|
||||
AttentionKind::Linear(net) => net.clear_kv_cache(),
|
||||
}
|
||||
}
|
||||
}
|
||||
179
crates/neuron/src/harness/arch/qwen3_5/full_attn.rs
Normal file
179
crates/neuron/src/harness/arch/qwen3_5/full_attn.rs
Normal file
@@ -0,0 +1,179 @@
|
||||
//! Qwen3-Next's `full_attention` layer.
|
||||
//!
|
||||
//! Standard GQA causal attention with two Qwen3-Next-specific quirks:
|
||||
//!
|
||||
//! 1. **Output gate (`attn_output_gate=True`).** `q_proj` is widened
|
||||
//! to `num_heads * head_dim * 2`. The second half is reshaped to
|
||||
//! `(B, L, num_heads * head_dim)` and fed through a sigmoid; the
|
||||
//! attention output is pointwise-multiplied by this gate before
|
||||
//! `o_proj`. Effectively a per-head per-position attenuation on
|
||||
//! the attention output.
|
||||
//!
|
||||
//! 2. **`(1 + w) * x` RmsNorm** on q and k (see `rmsnorm::Qwen3_5RmsNorm`).
|
||||
//! candle_nn's RmsNorm applies `w * x`; the upstream Qwen3-Next
|
||||
//! checkpoints expect the `(1 + w)` form.
|
||||
//!
|
||||
//! Otherwise: GQA with `num_attention_heads / num_key_value_heads`
|
||||
//! repeat, q_norm + k_norm on the head dim, GLM-style rotary (see
|
||||
//! `rope::RotaryEmbedding`), and the usual causal mask.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use candle_core::{Module, Tensor};
|
||||
use candle_nn::Linear;
|
||||
use candle_nn::kv_cache::ConcatKvCache;
|
||||
use candle_nn::var_builder::ShardedVarBuilder;
|
||||
use candle_transformers::utils::repeat_kv;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::TextConfig;
|
||||
use super::rmsnorm::Qwen3_5RmsNorm;
|
||||
use super::rope::RotaryEmbedding;
|
||||
|
||||
pub struct Qwen3_5Attention {
|
||||
q_proj: Linear,
|
||||
k_proj: Linear,
|
||||
v_proj: Linear,
|
||||
o_proj: Linear,
|
||||
q_norm: Qwen3_5RmsNorm,
|
||||
k_norm: Qwen3_5RmsNorm,
|
||||
num_heads: usize,
|
||||
num_kv_heads: usize,
|
||||
num_kv_groups: usize,
|
||||
head_dim: usize,
|
||||
hidden_size: usize,
|
||||
rotary: Arc<RotaryEmbedding>,
|
||||
kv_cache: ConcatKvCache,
|
||||
}
|
||||
|
||||
impl Qwen3_5Attention {
|
||||
pub fn load(
|
||||
cfg: &TextConfig,
|
||||
rotary: Arc<RotaryEmbedding>,
|
||||
vb: &ShardedVarBuilder,
|
||||
) -> Result<Self> {
|
||||
let head_dim = cfg.head_dim;
|
||||
let num_heads = cfg.num_attention_heads;
|
||||
let num_kv_heads = cfg.num_key_value_heads;
|
||||
if num_kv_heads == 0 || !num_heads.is_multiple_of(num_kv_heads) {
|
||||
anyhow::bail!(
|
||||
"num_attention_heads ({num_heads}) must be a positive multiple of \
|
||||
num_key_value_heads ({num_kv_heads})"
|
||||
);
|
||||
}
|
||||
let num_kv_groups = num_heads / num_kv_heads;
|
||||
|
||||
// q_proj is 2x wide: the extra `num_heads * head_dim` slice is
|
||||
// the gate (see attn_output_gate notes above).
|
||||
let q_proj = load_linear_no_bias(vb, "q_proj", cfg.hidden_size, num_heads * head_dim * 2)?;
|
||||
let k_proj = load_linear_no_bias(vb, "k_proj", cfg.hidden_size, num_kv_heads * head_dim)?;
|
||||
let v_proj = load_linear_no_bias(vb, "v_proj", cfg.hidden_size, num_kv_heads * head_dim)?;
|
||||
let o_proj = load_linear_no_bias(vb, "o_proj", num_heads * head_dim, cfg.hidden_size)?;
|
||||
|
||||
let q_norm = Qwen3_5RmsNorm::load(&vb.pp("q_norm"), head_dim, cfg.rms_norm_eps)?;
|
||||
let k_norm = Qwen3_5RmsNorm::load(&vb.pp("k_norm"), head_dim, cfg.rms_norm_eps)?;
|
||||
|
||||
let hidden_size = head_dim * num_heads;
|
||||
let kv_cache = ConcatKvCache::new(2);
|
||||
|
||||
Ok(Self {
|
||||
q_proj,
|
||||
k_proj,
|
||||
v_proj,
|
||||
o_proj,
|
||||
q_norm,
|
||||
k_norm,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
num_kv_groups,
|
||||
head_dim,
|
||||
hidden_size,
|
||||
rotary,
|
||||
kv_cache,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&mut self,
|
||||
x: &Tensor,
|
||||
attn_mask: Option<&Tensor>,
|
||||
offset: usize,
|
||||
) -> candle_core::Result<Tensor> {
|
||||
let (b, l, _) = x.dims3()?;
|
||||
|
||||
// 1. q_proj — widened output, split into (query, gate).
|
||||
let q_raw = self
|
||||
.q_proj
|
||||
.forward(x)?
|
||||
.reshape((b, l, self.num_heads, self.head_dim * 2))?;
|
||||
let q = q_raw.narrow(3, 0, self.head_dim)?;
|
||||
let gate = q_raw.narrow(3, self.head_dim, self.head_dim)?;
|
||||
// Flatten the gate's head dim back into hidden_size for the
|
||||
// post-attention pointwise multiply.
|
||||
let gate = gate
|
||||
.contiguous()?
|
||||
.reshape((b, l, self.num_heads * self.head_dim))?;
|
||||
|
||||
// 2. q_norm + k_norm + reshape to (B, H, L, D).
|
||||
let q = self.q_norm.forward(&q.contiguous()?)?;
|
||||
let q = q.transpose(1, 2)?.contiguous()?; // (B, H, L, D)
|
||||
|
||||
let k = self
|
||||
.k_proj
|
||||
.forward(x)?
|
||||
.reshape((b, l, self.num_kv_heads, self.head_dim))?;
|
||||
let k = self.k_norm.forward(&k.contiguous()?)?;
|
||||
let k = k.transpose(1, 2)?.contiguous()?;
|
||||
|
||||
let v = self
|
||||
.v_proj
|
||||
.forward(x)?
|
||||
.reshape((b, l, self.num_kv_heads, self.head_dim))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
|
||||
// 3. RoPE on q, k.
|
||||
let (q, k) = self.rotary.apply(&q, &k, offset)?;
|
||||
|
||||
// 4. KV cache.
|
||||
let (k, v) = self.kv_cache.append(&k, &v)?;
|
||||
|
||||
// 5. GQA repeat (cheap shape op).
|
||||
let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
|
||||
let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;
|
||||
|
||||
// 6. Scaled dot-product + causal mask.
|
||||
let scale = 1.0_f64 / (self.head_dim as f64).sqrt();
|
||||
let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
|
||||
if let Some(m) = attn_mask {
|
||||
scores = scores.broadcast_add(m)?;
|
||||
}
|
||||
let probs = candle_nn::ops::softmax_last_dim(&scores)?;
|
||||
let ctx = probs.matmul(&v)?; // (B, H, L, D)
|
||||
|
||||
// 7. Reshape back, apply the output gate, project.
|
||||
let ctx = ctx
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?
|
||||
.reshape((b, l, self.hidden_size))?;
|
||||
let gate_sig = candle_nn::ops::sigmoid(&gate)?;
|
||||
let gated = (ctx * gate_sig)?;
|
||||
self.o_proj.forward(&gated)
|
||||
}
|
||||
|
||||
pub fn clear_kv_cache(&mut self) {
|
||||
self.kv_cache.reset();
|
||||
}
|
||||
}
|
||||
|
||||
fn load_linear_no_bias(
|
||||
vb: &ShardedVarBuilder,
|
||||
name: &str,
|
||||
in_dim: usize,
|
||||
out_dim: usize,
|
||||
) -> Result<Linear> {
|
||||
let weight = vb
|
||||
.pp(name)
|
||||
.get((out_dim, in_dim), "weight")
|
||||
.with_context(|| format!("load '{}/{name}/weight'", vb.prefix()))?;
|
||||
Ok(Linear::new(weight, None))
|
||||
}
|
||||
793
crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs
Normal file
793
crates/neuron/src/harness/arch/qwen3_5/linear_attn.rs
Normal file
@@ -0,0 +1,793 @@
|
||||
//! Qwen3-Next's `linear_attention` layer: Gated DeltaNet.
|
||||
//!
|
||||
//! The recurrent linear-attention block that occupies 3 out of every 4
|
||||
//! decoder layers in Qwen3.6 (`layer_types[i] == "linear_attention"`).
|
||||
//! Implemented against the reference Python in
|
||||
//! `huggingface/transformers/src/transformers/models/qwen3_5/modeling_qwen3_5.py`
|
||||
//! (class `Qwen3_5GatedDeltaNet`).
|
||||
//!
|
||||
//! ## Block structure
|
||||
//!
|
||||
//! ```text
|
||||
//! x ── in_proj_qkv ── transpose ─► (B, conv_dim, L)
|
||||
//! │
|
||||
//! ┌──────────────── conv_state ──┤ prepend cached state (decode)
|
||||
//! ▼
|
||||
//! depthwise causal Conv1d (k=4) → SiLU
|
||||
//! │
|
||||
//! └─ split → q (k_dim), k (k_dim), v (v_dim) ─► per-head reshape
|
||||
//!
|
||||
//! x ── in_proj_z ────────────────► z (gate for the output RMSNorm)
|
||||
//! x ── in_proj_b ── sigmoid ─────► beta (per-head per-token update rate)
|
||||
//! x ── in_proj_a ── softplus ────► g (decay; see eqn below)
|
||||
//!
|
||||
//! g = -exp(A_log) * softplus(a + dt_bias) # discretisation
|
||||
//! beta = sigmoid(b)
|
||||
//!
|
||||
//! (q, k) ─── L2norm ─── delta rule loop ──── core_attn_out
|
||||
//! (per-token, per-head):
|
||||
//! state *= exp(g_t)
|
||||
//! mem = state^T · k_t
|
||||
//! delta = (v_t - mem) * beta_t
|
||||
//! state += outer(k_t, delta)
|
||||
//! out_t = state^T · q_t
|
||||
//!
|
||||
//! core_attn_out ── RMSNormGated(z) ── reshape ── out_proj ── y
|
||||
//! ```
|
||||
//!
|
||||
//! ## State
|
||||
//!
|
||||
//! Two tensors persist across decode steps:
|
||||
//! - `conv_state`: `(B, conv_dim, conv_kernel_size)` — left-padded
|
||||
//! tail of the input to the depthwise conv, so the next causal
|
||||
//! window has the right left-context.
|
||||
//! - `recurrent_state`: `(B, num_v_heads, head_k_dim, head_v_dim)` —
|
||||
//! the delta-rule outer-product memory.
|
||||
//!
|
||||
//! Both are cleared via [`GatedDeltaNet::clear_kv_cache`] at the start
|
||||
//! of every new request.
|
||||
//!
|
||||
//! ## Performance note
|
||||
//!
|
||||
//! This impl is the **recurrent** delta-rule for both prefill and
|
||||
//! decode — i.e. the algorithm in `torch_recurrent_gated_delta_rule`.
|
||||
//! Correctness-first. The chunked algorithm (chunk_size=64) in
|
||||
//! `torch_chunk_gated_delta_rule` is a perf optimisation for long
|
||||
//! prefill; can be added later without changing the surface.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use candle_core::{Module, Tensor};
|
||||
use candle_nn::Linear;
|
||||
use candle_nn::var_builder::ShardedVarBuilder;
|
||||
|
||||
#[cfg(test)]
|
||||
use super::RopeParameters;
|
||||
use super::TextConfig;
|
||||
use super::rmsnorm::{Qwen3_5RmsNormGated, l2norm};
|
||||
|
||||
/// Per-rank, per-layer state for the linear-attention block.
|
||||
///
|
||||
/// `conv_state` is left-padded with zeros on first use; `recurrent_state`
|
||||
/// is initialised lazily to zeros once we know the batch size.
|
||||
#[derive(Default)]
|
||||
pub struct GatedDeltaNetState {
|
||||
pub conv_state: Option<Tensor>,
|
||||
pub recurrent_state: Option<Tensor>,
|
||||
}
|
||||
|
||||
pub struct GatedDeltaNet {
|
||||
// Projections.
|
||||
in_proj_qkv: Linear,
|
||||
in_proj_z: Linear,
|
||||
in_proj_b: Linear,
|
||||
in_proj_a: Linear,
|
||||
out_proj: Linear,
|
||||
|
||||
// Depthwise causal Conv1d weight; shape (conv_dim, 1, kernel_size).
|
||||
// No bias (Python sets bias=False).
|
||||
conv1d_weight: Tensor,
|
||||
|
||||
// Per-head discretisation params.
|
||||
dt_bias: Tensor,
|
||||
a_log: Tensor,
|
||||
|
||||
// Output norm + gate.
|
||||
norm: Qwen3_5RmsNormGated,
|
||||
|
||||
// Shape hyperparams (cached for forward).
|
||||
num_v_heads: usize,
|
||||
num_k_heads: usize,
|
||||
head_k_dim: usize,
|
||||
head_v_dim: usize,
|
||||
key_dim: usize,
|
||||
value_dim: usize,
|
||||
conv_dim: usize,
|
||||
conv_kernel_size: usize,
|
||||
|
||||
// Recurrent state held inline. Each request resets via
|
||||
// `clear_kv_cache`; otherwise the state persists across forwards
|
||||
// and the per-token offset advances naturally.
|
||||
state: GatedDeltaNetState,
|
||||
}
|
||||
|
||||
impl GatedDeltaNet {
|
||||
pub fn load(cfg: &TextConfig, vb: &ShardedVarBuilder) -> Result<Self> {
|
||||
let num_v_heads = cfg.linear_num_value_heads;
|
||||
let num_k_heads = cfg.linear_num_key_heads;
|
||||
let head_k_dim = cfg.linear_key_head_dim;
|
||||
let head_v_dim = cfg.linear_value_head_dim;
|
||||
let conv_kernel_size = cfg.linear_conv_kernel_dim;
|
||||
|
||||
if num_v_heads == 0 || num_k_heads == 0 {
|
||||
anyhow::bail!(
|
||||
"Qwen3-Next linear_num_*_heads must be set; got v={num_v_heads}, k={num_k_heads}"
|
||||
);
|
||||
}
|
||||
if !num_v_heads.is_multiple_of(num_k_heads) {
|
||||
anyhow::bail!(
|
||||
"linear_num_value_heads ({num_v_heads}) must be a multiple of \
|
||||
linear_num_key_heads ({num_k_heads}) for GQA-style head expansion"
|
||||
);
|
||||
}
|
||||
|
||||
let key_dim = head_k_dim * num_k_heads;
|
||||
let value_dim = head_v_dim * num_v_heads;
|
||||
let conv_dim = key_dim * 2 + value_dim;
|
||||
|
||||
// ----- Linear projections (all `bias=False` in the reference). -----
|
||||
let in_proj_qkv = load_linear_no_bias(vb, "in_proj_qkv", cfg.hidden_size, conv_dim)?;
|
||||
let in_proj_z = load_linear_no_bias(vb, "in_proj_z", cfg.hidden_size, value_dim)?;
|
||||
let in_proj_b = load_linear_no_bias(vb, "in_proj_b", cfg.hidden_size, num_v_heads)?;
|
||||
let in_proj_a = load_linear_no_bias(vb, "in_proj_a", cfg.hidden_size, num_v_heads)?;
|
||||
let out_proj = load_linear_no_bias(vb, "out_proj", value_dim, cfg.hidden_size)?;
|
||||
|
||||
// ----- Conv1d weight (depthwise, bias=False). -----
|
||||
let conv1d_weight = vb
|
||||
.pp("conv1d")
|
||||
.get((conv_dim, 1, conv_kernel_size), "weight")
|
||||
.with_context(|| format!("load '{}/conv1d/weight'", vb.prefix()))?;
|
||||
|
||||
// ----- dt_bias + A_log: per-head 1D params. -----
|
||||
let dt_bias = vb
|
||||
.get(num_v_heads, "dt_bias")
|
||||
.with_context(|| format!("load '{}/dt_bias'", vb.prefix()))?;
|
||||
let a_log = vb
|
||||
.get(num_v_heads, "A_log")
|
||||
.with_context(|| format!("load '{}/A_log'", vb.prefix()))?;
|
||||
|
||||
// ----- Output gated RMSNorm (per-head_v_dim). -----
|
||||
let norm = Qwen3_5RmsNormGated::load(&vb.pp("norm"), head_v_dim, cfg.rms_norm_eps)?;
|
||||
|
||||
Ok(Self {
|
||||
in_proj_qkv,
|
||||
in_proj_z,
|
||||
in_proj_b,
|
||||
in_proj_a,
|
||||
out_proj,
|
||||
conv1d_weight,
|
||||
dt_bias,
|
||||
a_log,
|
||||
norm,
|
||||
num_v_heads,
|
||||
num_k_heads,
|
||||
head_k_dim,
|
||||
head_v_dim,
|
||||
key_dim,
|
||||
value_dim,
|
||||
conv_dim,
|
||||
conv_kernel_size,
|
||||
state: GatedDeltaNetState::default(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn clear_kv_cache(&mut self) {
|
||||
self.state = GatedDeltaNetState::default();
|
||||
}
|
||||
|
||||
/// `x` shape: `(B, L, hidden_size)`. Returns the same shape.
|
||||
pub fn forward(&mut self, x: &Tensor) -> candle_core::Result<Tensor> {
|
||||
let (batch_size, seq_len, _) = x.dims3()?;
|
||||
let dtype = x.dtype();
|
||||
let device = x.device().clone();
|
||||
|
||||
// ----- Projections. -----
|
||||
// mixed_qkv: (B, L, conv_dim)
|
||||
let mixed_qkv = self.in_proj_qkv.forward(x)?;
|
||||
// (B, conv_dim, L) for the conv1d.
|
||||
let mixed_qkv_chw = mixed_qkv.transpose(1, 2)?.contiguous()?;
|
||||
|
||||
// z: (B, L, value_dim) → (B, L, num_v_heads, head_v_dim)
|
||||
let z = self.in_proj_z.forward(x)?.reshape((
|
||||
batch_size,
|
||||
seq_len,
|
||||
self.num_v_heads,
|
||||
self.head_v_dim,
|
||||
))?;
|
||||
|
||||
// b, a: (B, L, num_v_heads)
|
||||
let b = self.in_proj_b.forward(x)?;
|
||||
let a = self.in_proj_a.forward(x)?;
|
||||
|
||||
// ----- Depthwise causal Conv1d + SiLU (with state continuation). -----
|
||||
// Dispatches to a cuda kernel that fuses conv1d + silu when
|
||||
// available; falls back to candle's `conv1d` + `silu` on cpu.
|
||||
let (conv_out, new_state) = run_causal_conv1d(
|
||||
&mixed_qkv_chw,
|
||||
&self.conv1d_weight,
|
||||
self.state.conv_state.take(),
|
||||
batch_size,
|
||||
self.conv_dim,
|
||||
seq_len,
|
||||
self.conv_kernel_size,
|
||||
)?;
|
||||
self.state.conv_state = Some(new_state);
|
||||
// Back to (B, L, conv_dim).
|
||||
let mixed_qkv = conv_out.transpose(1, 2)?.contiguous()?;
|
||||
|
||||
// ----- Split into q, k, v. -----
|
||||
let q = mixed_qkv.narrow(2, 0, self.key_dim)?;
|
||||
let k = mixed_qkv.narrow(2, self.key_dim, self.key_dim)?;
|
||||
let v = mixed_qkv.narrow(2, 2 * self.key_dim, self.value_dim)?;
|
||||
|
||||
let q = q.reshape((batch_size, seq_len, self.num_k_heads, self.head_k_dim))?;
|
||||
let k = k.reshape((batch_size, seq_len, self.num_k_heads, self.head_k_dim))?;
|
||||
let v = v.reshape((batch_size, seq_len, self.num_v_heads, self.head_v_dim))?;
|
||||
|
||||
// ----- beta + g (per-head, per-token gates). -----
|
||||
// Fused on cuda; per-op Rust on cpu. Both paths produce:
|
||||
// beta = sigmoid(b)
|
||||
// g = -exp(A_log) * softplus(a + dt_bias)
|
||||
let (beta, g) = run_fused_gating(&b, &a, &self.a_log, &self.dt_bias)?;
|
||||
|
||||
// ----- GQA-style key expansion if num_v_heads > num_k_heads. -----
|
||||
let (q, k) = if self.num_v_heads > self.num_k_heads {
|
||||
let rep = self.num_v_heads / self.num_k_heads;
|
||||
(
|
||||
repeat_interleave(&q, rep, 2)?,
|
||||
repeat_interleave(&k, rep, 2)?,
|
||||
)
|
||||
} else {
|
||||
(q, k)
|
||||
};
|
||||
|
||||
// ----- L2-norm on q, k (use_qk_l2norm_in_kernel=True in ref). -----
|
||||
let q = l2norm(&q, 1e-6)?;
|
||||
let k = l2norm(&k, 1e-6)?;
|
||||
|
||||
// ----- Recurrent delta rule. -----
|
||||
// Inputs: q, k (B, L, H, D_k); v (B, L, H, D_v); g (B, L, H); beta (B, L, H).
|
||||
// The reference transposes to (B, H, L, D) before the loop. We
|
||||
// do the same — it makes per-token indexing trivial.
|
||||
let q = q.transpose(1, 2)?.contiguous()?; // (B, H, L, D_k)
|
||||
let k = k.transpose(1, 2)?.contiguous()?;
|
||||
let v = v.transpose(1, 2)?.contiguous()?; // (B, H, L, D_v)
|
||||
let g = g.transpose(1, 2)?.contiguous()?; // (B, H, L)
|
||||
let beta = beta.transpose(1, 2)?.contiguous()?; // (B, H, L)
|
||||
|
||||
// Pre-scale q by 1/sqrt(D_k) once. Everything goes to f32 here
|
||||
// since the delta rule mixes broadcast_mul ops that candle won't
|
||||
// accept across mixed dtypes. On the cuda gating path both beta
|
||||
// and g come back in model dtype; on the cpu path g is already
|
||||
// f32 — both casts are cheap idempotent ops.
|
||||
let scale = 1.0_f64 / (self.head_k_dim as f64).sqrt();
|
||||
let q = (q.to_dtype(candle_core::DType::F32)? * scale)?;
|
||||
let k = k.to_dtype(candle_core::DType::F32)?;
|
||||
let v = v.to_dtype(candle_core::DType::F32)?;
|
||||
let g = g.to_dtype(candle_core::DType::F32)?;
|
||||
let beta = beta.to_dtype(candle_core::DType::F32)?;
|
||||
|
||||
// Initialise the recurrent state from cache or zeros.
|
||||
let state_init = match self.state.recurrent_state.take() {
|
||||
Some(s) => s.to_dtype(candle_core::DType::F32)?,
|
||||
None => Tensor::zeros(
|
||||
(
|
||||
batch_size,
|
||||
self.num_v_heads,
|
||||
self.head_k_dim,
|
||||
self.head_v_dim,
|
||||
),
|
||||
candle_core::DType::F32,
|
||||
&device,
|
||||
)?,
|
||||
};
|
||||
|
||||
// The delta-rule body: cuda-accelerated `gated_delta_rule_recurrence`
|
||||
// kernel when we have a cuda device + the kernels are linked in,
|
||||
// pure-Rust per-token fallback otherwise.
|
||||
let (core_attn_out, new_state) = run_delta_rule(
|
||||
&q,
|
||||
&k,
|
||||
&v,
|
||||
&g,
|
||||
&beta,
|
||||
state_init,
|
||||
batch_size,
|
||||
self.num_v_heads,
|
||||
seq_len,
|
||||
self.head_k_dim,
|
||||
self.head_v_dim,
|
||||
)?;
|
||||
// Stash the updated recurrent state for the next call.
|
||||
self.state.recurrent_state = Some(new_state.to_dtype(dtype)?);
|
||||
|
||||
// core_attn_out: (B, H, L, D_v) → (B, L, H, D_v) → (B*L*H, D_v).
|
||||
let core_attn_out = core_attn_out.transpose(1, 2)?.contiguous()?; // (B, L, H, D_v)
|
||||
let core_attn_out = core_attn_out.to_dtype(dtype)?;
|
||||
let core_attn_flat =
|
||||
core_attn_out.reshape((batch_size * seq_len * self.num_v_heads, self.head_v_dim))?;
|
||||
let z_flat = z.reshape((batch_size * seq_len * self.num_v_heads, self.head_v_dim))?;
|
||||
|
||||
// RMSNormGated: (out * silu(z) * weight) with the norm.
|
||||
let normed = self.norm.forward(&core_attn_flat, &z_flat)?;
|
||||
let normed = normed.reshape((batch_size, seq_len, self.num_v_heads * self.head_v_dim))?;
|
||||
|
||||
// Output projection: (B, L, value_dim) → (B, L, hidden_size).
|
||||
self.out_proj.forward(&normed)
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the per-token delta-rule recurrence.
|
||||
///
|
||||
/// `q`, `k`: `(B, H, L, D_k)` (F32). `v`: `(B, H, L, D_v)`. `g`,
|
||||
/// `beta`: `(B, H, L)`. `state`: `(B, H, D_k, D_v)`.
|
||||
///
|
||||
/// Returns `(core_attn_out: (B, H, L, D_v), state: (B, H, D_k, D_v))`,
|
||||
/// both F32. Caller is responsible for cast back to model dtype.
|
||||
///
|
||||
/// Cuda path: dispatches to the `gated_delta_rule_recurrence` kernel
|
||||
/// ported from `EricLBuehler/mistral.rs::mistralrs-core/src/cuda/gdn.cu`.
|
||||
/// All five inputs must be cuda f32 tensors. The kernel is V-tiled
|
||||
/// with compile-time BK; one block per (V-tile, batch*head) and one
|
||||
/// thread per V-column. Each thread holds BK state floats in
|
||||
/// registers — eliminates the launch-overhead floor we hit with
|
||||
/// candle's per-op dispatch (was ~12s/token on Qwen3.6-27B).
|
||||
///
|
||||
/// CPU path: pure-Rust per-token loop. Correct, slow.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn run_delta_rule(
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
v: &Tensor,
|
||||
g: &Tensor,
|
||||
beta: &Tensor,
|
||||
state: Tensor,
|
||||
batch_size: usize,
|
||||
num_heads: usize,
|
||||
seq_len: usize,
|
||||
head_k_dim: usize,
|
||||
head_v_dim: usize,
|
||||
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
// Only dispatch to the kernel if the inputs are on a CUDA
|
||||
// device — CPU tests fall back to the Rust loop below.
|
||||
if q.device().is_cuda() {
|
||||
return run_delta_rule_cuda(
|
||||
q, k, v, g, beta, state, batch_size, num_heads, seq_len, head_k_dim, head_v_dim,
|
||||
);
|
||||
}
|
||||
}
|
||||
let _ = (batch_size, num_heads, head_k_dim, head_v_dim);
|
||||
run_delta_rule_rust(q, k, v, g, beta, state, seq_len)
|
||||
}
|
||||
|
||||
/// CUDA path. Flattens (B, H, ...) → (BH, ...) at the kernel boundary
|
||||
/// (the kernel uses BH = batch*heads as its outer batch axis) and
|
||||
/// reshapes the kernel's outputs back to (B, H, ...) for the caller.
|
||||
#[cfg(feature = "cuda")]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_delta_rule_cuda(
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
v: &Tensor,
|
||||
g: &Tensor,
|
||||
beta: &Tensor,
|
||||
state: Tensor,
|
||||
batch_size: usize,
|
||||
num_heads: usize,
|
||||
seq_len: usize,
|
||||
head_k_dim: usize,
|
||||
head_v_dim: usize,
|
||||
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||
let q_bh = q.flatten(0, 1)?.contiguous()?;
|
||||
let k_bh = k.flatten(0, 1)?.contiguous()?;
|
||||
let v_bh = v.flatten(0, 1)?.contiguous()?;
|
||||
let g_bh = g.flatten(0, 1)?.contiguous()?;
|
||||
let beta_bh = beta.flatten(0, 1)?.contiguous()?;
|
||||
let mut state_bh = state.flatten(0, 1)?.contiguous()?;
|
||||
// For long prefills, the chunked kernel (BT=64) processes a chunk
|
||||
// of tokens at a time instead of one-by-one — same delta-rule math,
|
||||
// far fewer block launches. Threshold matches mistralrs.
|
||||
const CHUNK_THRESHOLD: usize = 64;
|
||||
let output_bh = if seq_len >= CHUNK_THRESHOLD {
|
||||
crate::cuda::gdn::chunked_gated_delta_rule_recurrence_cuda(
|
||||
&q_bh,
|
||||
&k_bh,
|
||||
&v_bh,
|
||||
&g_bh,
|
||||
&beta_bh,
|
||||
&mut state_bh,
|
||||
)?
|
||||
} else {
|
||||
crate::cuda::gdn::gated_delta_rule_recurrence_cuda(
|
||||
&q_bh,
|
||||
&k_bh,
|
||||
&v_bh,
|
||||
&g_bh,
|
||||
&beta_bh,
|
||||
&mut state_bh,
|
||||
)?
|
||||
};
|
||||
let core_attn_out = output_bh.reshape((batch_size, num_heads, seq_len, head_v_dim))?;
|
||||
let new_state = state_bh.reshape((batch_size, num_heads, head_k_dim, head_v_dim))?;
|
||||
Ok((core_attn_out, new_state))
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_delta_rule_rust(
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
v: &Tensor,
|
||||
g: &Tensor,
|
||||
beta: &Tensor,
|
||||
mut state: Tensor,
|
||||
seq_len: usize,
|
||||
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||
use candle_core::IndexOp;
|
||||
let mut outputs: Vec<Tensor> = Vec::with_capacity(seq_len);
|
||||
for t in 0..seq_len {
|
||||
let q_t = q.i((.., .., t, ..))?;
|
||||
let k_t = k.i((.., .., t, ..))?;
|
||||
let v_t = v.i((.., .., t, ..))?;
|
||||
let g_t = g.i((.., .., t))?;
|
||||
let beta_t = beta.i((.., .., t))?;
|
||||
let decay = g_t
|
||||
.exp()?
|
||||
.unsqueeze(candle_core::D::Minus1)?
|
||||
.unsqueeze(candle_core::D::Minus1)?;
|
||||
state = state.broadcast_mul(&decay)?;
|
||||
let k_col = k_t.unsqueeze(candle_core::D::Minus1)?;
|
||||
let kv_mem = state.broadcast_mul(&k_col)?.sum(2)?;
|
||||
let beta_col = beta_t.unsqueeze(candle_core::D::Minus1)?;
|
||||
let delta = (v_t - kv_mem)?.broadcast_mul(&beta_col)?;
|
||||
let delta_row = delta.unsqueeze(2)?;
|
||||
let outer = k_col.broadcast_mul(&delta_row)?;
|
||||
state = (state + outer)?;
|
||||
let q_col = q_t.unsqueeze(candle_core::D::Minus1)?;
|
||||
let out_t = state.broadcast_mul(&q_col)?.sum(2)?;
|
||||
outputs.push(out_t.unsqueeze(2)?);
|
||||
}
|
||||
let core_attn_out = Tensor::cat(&outputs, 2)?; // (B, H, L, D_v)
|
||||
Ok((core_attn_out, state))
|
||||
}
|
||||
|
||||
/// Depthwise causal conv1d + SiLU, with rolling `conv_state`.
|
||||
///
|
||||
/// `x`: `(B, conv_dim, L)` model dtype (f16/bf16 on cuda, anything on cpu).
|
||||
/// `weight`: `(conv_dim, 1, kernel_size)` model dtype.
|
||||
/// `conv_state`: `Some((B, conv_dim, kernel_size))` for decode continuation,
|
||||
/// or `None` for fresh prefill.
|
||||
///
|
||||
/// Returns `(conv_out: (B, conv_dim, L), new_conv_state: (B, conv_dim, kernel_size))`.
|
||||
/// SiLU is baked in.
|
||||
///
|
||||
/// Cuda path: dispatches to `causal_conv1d_update` (decode, seq_len=1 with
|
||||
/// existing state) or `causal_conv1d_full` (prefill / first call), both
|
||||
/// ported from mistralrs `gdn.cu`. Each kernel fuses the depthwise conv
|
||||
/// and SiLU activation in one launch — that's ~4× fewer cuda launches per
|
||||
/// linear-attention layer than the candle `conv1d` + `silu` combo.
|
||||
///
|
||||
/// CPU path: the original prepend-narrow-conv1d-silu sequence.
|
||||
pub(crate) fn run_causal_conv1d(
|
||||
x: &Tensor,
|
||||
weight: &Tensor,
|
||||
conv_state: Option<Tensor>,
|
||||
batch_size: usize,
|
||||
conv_dim: usize,
|
||||
seq_len: usize,
|
||||
conv_kernel_size: usize,
|
||||
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
if x.device().is_cuda() {
|
||||
return run_causal_conv1d_cuda(
|
||||
x,
|
||||
weight,
|
||||
conv_state,
|
||||
batch_size,
|
||||
conv_dim,
|
||||
seq_len,
|
||||
conv_kernel_size,
|
||||
);
|
||||
}
|
||||
}
|
||||
run_causal_conv1d_rust(
|
||||
x,
|
||||
weight,
|
||||
conv_state,
|
||||
batch_size,
|
||||
conv_dim,
|
||||
seq_len,
|
||||
conv_kernel_size,
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn run_causal_conv1d_cuda(
|
||||
x: &Tensor,
|
||||
weight: &Tensor,
|
||||
conv_state: Option<Tensor>,
|
||||
batch_size: usize,
|
||||
conv_dim: usize,
|
||||
seq_len: usize,
|
||||
conv_kernel_size: usize,
|
||||
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||
// Kernel expects weight as (conv_dim, kernel_size) — squeeze the
|
||||
// depthwise channel-multiplier dim.
|
||||
let w = weight.squeeze(1)?.to_dtype(x.dtype())?.contiguous()?;
|
||||
|
||||
// Decode path: seq_len == 1 AND we have an existing conv_state.
|
||||
// Otherwise (prefill or fresh-start decode), use the full path which
|
||||
// zero-pads on the left internally.
|
||||
if let Some(cs) = conv_state
|
||||
&& seq_len == 1
|
||||
{
|
||||
let cs = cs.contiguous()?;
|
||||
let (output, new_conv_state) =
|
||||
crate::cuda::gdn::causal_conv1d_cuda(x, &w, &cs, conv_kernel_size, true)?;
|
||||
return Ok((output, new_conv_state));
|
||||
}
|
||||
|
||||
// Prefill / fresh-start: the kernel ignores any prior conv_state and
|
||||
// zero-pads. If we had a non-zero prior state and >1 input tokens
|
||||
// (multi-turn continuation), we'd need to fall back to Rust. Match
|
||||
// mistralrs's behaviour: fresh prefill always.
|
||||
let device = x.device().clone();
|
||||
let zeros_cs = Tensor::zeros((batch_size, conv_dim, conv_kernel_size), x.dtype(), &device)?;
|
||||
let (output, new_conv_state) =
|
||||
crate::cuda::gdn::causal_conv1d_cuda(x, &w, &zeros_cs, conv_kernel_size, false)?;
|
||||
Ok((output, new_conv_state))
|
||||
}
|
||||
|
||||
/// Fused GDN gating: computes `beta = sigmoid(b)` and
|
||||
/// `g = -exp(a_log) * softplus(a + dt_bias)` together.
|
||||
///
|
||||
/// `b`, `a`: `(B, L, num_heads)` model dtype.
|
||||
/// `a_log`, `dt_bias`: `(num_heads,)` model dtype (cast to f32 internally).
|
||||
///
|
||||
/// Returns `(beta, g)` both in model dtype on the cuda path, both in f32
|
||||
/// on the cpu fallback. The caller casts to f32 before the delta rule.
|
||||
///
|
||||
/// Cuda path: dispatches to `fused_gdn_gating_cuda` — one kernel
|
||||
/// replaces sigmoid + neg(exp) + softplus + broadcast_mul (≈10 candle
|
||||
/// launches per layer).
|
||||
pub(crate) fn run_fused_gating(
|
||||
b: &Tensor,
|
||||
a: &Tensor,
|
||||
a_log: &Tensor,
|
||||
dt_bias: &Tensor,
|
||||
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
if b.device().is_cuda() {
|
||||
let a_log_f32 = a_log.to_dtype(candle_core::DType::F32)?.contiguous()?;
|
||||
let dt_bias_f32 = dt_bias.to_dtype(candle_core::DType::F32)?.contiguous()?;
|
||||
return crate::cuda::gdn::fused_gdn_gating_cuda(b, a, &a_log_f32, &dt_bias_f32);
|
||||
}
|
||||
}
|
||||
run_fused_gating_rust(b, a, a_log, dt_bias)
|
||||
}
|
||||
|
||||
fn run_fused_gating_rust(
|
||||
b: &Tensor,
|
||||
a: &Tensor,
|
||||
a_log: &Tensor,
|
||||
dt_bias: &Tensor,
|
||||
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||
let beta = candle_nn::ops::sigmoid(b)?;
|
||||
let a_log_f32 = a_log.to_dtype(candle_core::DType::F32)?;
|
||||
let neg_a_exp = a_log_f32.exp()?.neg()?;
|
||||
let dt_b_f32 = dt_bias.to_dtype(candle_core::DType::F32)?;
|
||||
let a_f32 = a.to_dtype(candle_core::DType::F32)?;
|
||||
let a_plus_dt = a_f32.broadcast_add(&dt_b_f32)?;
|
||||
let softplus_val = softplus(&a_plus_dt)?;
|
||||
let neg_a_exp_b = neg_a_exp.unsqueeze(0)?.unsqueeze(0)?;
|
||||
let g = neg_a_exp_b.broadcast_mul(&softplus_val)?;
|
||||
Ok((beta, g))
|
||||
}
|
||||
|
||||
fn run_causal_conv1d_rust(
|
||||
x: &Tensor,
|
||||
weight: &Tensor,
|
||||
conv_state: Option<Tensor>,
|
||||
batch_size: usize,
|
||||
conv_dim: usize,
|
||||
seq_len: usize,
|
||||
conv_kernel_size: usize,
|
||||
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||
let dtype = x.dtype();
|
||||
let device = x.device().clone();
|
||||
|
||||
let prepended = match &conv_state {
|
||||
Some(prev) => Tensor::cat(&[prev, x], 2)?,
|
||||
None => x.clone(),
|
||||
};
|
||||
let prep_len = prepended.dims()[2];
|
||||
|
||||
let new_state = if prep_len >= conv_kernel_size {
|
||||
prepended.narrow(2, prep_len - conv_kernel_size, conv_kernel_size)?
|
||||
} else {
|
||||
let pad = Tensor::zeros(
|
||||
(batch_size, conv_dim, conv_kernel_size - prep_len),
|
||||
dtype,
|
||||
&device,
|
||||
)?;
|
||||
Tensor::cat(&[&pad, &prepended], 2)?
|
||||
};
|
||||
|
||||
let conv_out = prepended.conv1d(weight, conv_kernel_size - 1, 1, 1, conv_dim)?;
|
||||
let conv_out = conv_out.narrow(2, 0, prep_len)?;
|
||||
let conv_out = candle_nn::ops::silu(&conv_out)?;
|
||||
let conv_out = conv_out.narrow(2, prep_len - seq_len, seq_len)?;
|
||||
Ok((conv_out, new_state))
|
||||
}
|
||||
|
||||
/// Load a no-bias linear from the ShardedVarBuilder. Weight shape is
|
||||
/// the standard `[out, in]` order.
|
||||
fn load_linear_no_bias(
|
||||
vb: &ShardedVarBuilder,
|
||||
name: &str,
|
||||
in_dim: usize,
|
||||
out_dim: usize,
|
||||
) -> Result<Linear> {
|
||||
let weight = vb
|
||||
.pp(name)
|
||||
.get((out_dim, in_dim), "weight")
|
||||
.with_context(|| format!("load '{}/{name}/weight'", vb.prefix()))?;
|
||||
Ok(Linear::new(weight, None))
|
||||
}
|
||||
|
||||
/// Numerically-stable `softplus(x) = ln(1 + exp(x))`. Matches PyTorch's
|
||||
/// `F.softplus` default (beta=1, threshold=20: for large positive x,
|
||||
/// returns x as-is to avoid overflow in the exp).
|
||||
pub(crate) fn softplus(x: &Tensor) -> candle_core::Result<Tensor> {
|
||||
let threshold = 20.0_f64;
|
||||
let big = x.ge(threshold)?; // Tensor<u8> mask
|
||||
let safe = x.minimum(&x.affine(0.0, 0.0)?.affine(0.0, threshold)?)?; // min(x, threshold)
|
||||
let small = ((safe.exp()? + 1.0_f64)?).log()?;
|
||||
// Select x where big, else small.
|
||||
big.where_cond(x, &small)
|
||||
}
|
||||
|
||||
/// `repeat_interleave` along a single dim. Candle has no built-in for
|
||||
/// this; emulate with unsqueeze + expand + reshape.
|
||||
pub(crate) fn repeat_interleave(
|
||||
x: &Tensor,
|
||||
repeats: usize,
|
||||
dim: usize,
|
||||
) -> candle_core::Result<Tensor> {
|
||||
if repeats == 1 {
|
||||
return Ok(x.clone());
|
||||
}
|
||||
let mut shape = x.dims().to_vec();
|
||||
let orig = shape[dim];
|
||||
shape.insert(dim + 1, repeats);
|
||||
let mut expanded_shape = shape.clone();
|
||||
expanded_shape[dim + 1] = repeats;
|
||||
let x = x.unsqueeze(dim + 1)?;
|
||||
let x = x.expand(expanded_shape)?;
|
||||
let mut out_shape = x.dims().to_vec();
|
||||
out_shape.remove(dim + 1);
|
||||
out_shape[dim] = orig * repeats;
|
||||
x.contiguous()?.reshape(out_shape)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use candle_core::{DType, Device};
|
||||
|
||||
#[test]
|
||||
fn softplus_small_x() {
|
||||
// softplus(0) = ln(2) ≈ 0.6931
|
||||
let x = Tensor::new(&[0.0_f32], &Device::Cpu).unwrap();
|
||||
let out: Vec<f32> = softplus(&x).unwrap().to_vec1().unwrap();
|
||||
assert!((out[0] - 2.0_f32.ln()).abs() < 1e-4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn softplus_large_x_returns_x() {
|
||||
// For x = 30, softplus(x) ≈ x (the threshold branch).
|
||||
let x = Tensor::new(&[30.0_f32], &Device::Cpu).unwrap();
|
||||
let out: Vec<f32> = softplus(&x).unwrap().to_vec1().unwrap();
|
||||
assert!((out[0] - 30.0).abs() < 1e-4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn repeat_interleave_doubles_dim() {
|
||||
let x = Tensor::new(&[[1.0_f32, 2.0], [3.0, 4.0]], &Device::Cpu).unwrap(); // shape (2, 2)
|
||||
let out = repeat_interleave(&x, 2, 1).unwrap(); // each col duplicated
|
||||
let v: Vec<Vec<f32>> = out.to_vec2().unwrap();
|
||||
// Row 0: 1, 1, 2, 2
|
||||
// Row 1: 3, 3, 4, 4
|
||||
assert_eq!(v[0], vec![1.0, 1.0, 2.0, 2.0]);
|
||||
assert_eq!(v[1], vec![3.0, 3.0, 4.0, 4.0]);
|
||||
}
|
||||
|
||||
/// Sanity: the recurrent path produces a finite tensor of the right
|
||||
/// shape on tiny dimensions. Doesn't validate numerical correctness
|
||||
/// against the Python reference — that would need a fixed-weight
|
||||
/// fixture to compare against. Catches structural mistakes
|
||||
/// (broadcasting shapes, off-by-one slices) early.
|
||||
#[test]
|
||||
fn forward_smoke_with_tiny_dimensions() {
|
||||
let dev = Device::Cpu;
|
||||
let dtype = DType::F32;
|
||||
let (b, l) = (1, 3);
|
||||
let cfg = TextConfig {
|
||||
vocab_size: 100,
|
||||
hidden_size: 16,
|
||||
intermediate_size: 32,
|
||||
num_hidden_layers: 1,
|
||||
num_attention_heads: 4,
|
||||
num_key_value_heads: 1,
|
||||
head_dim: 4,
|
||||
max_position_embeddings: 32,
|
||||
rope_parameters: RopeParameters {
|
||||
rope_theta: 10000.0,
|
||||
partial_rotary_factor: 1.0,
|
||||
rope_type: None,
|
||||
},
|
||||
rms_norm_eps: 1e-6,
|
||||
tie_word_embeddings: false,
|
||||
attn_output_gate: true,
|
||||
layer_types: vec!["linear_attention".into()],
|
||||
full_attention_interval: Some(4),
|
||||
hidden_act: "silu".into(),
|
||||
linear_num_value_heads: 4,
|
||||
linear_num_key_heads: 2,
|
||||
linear_key_head_dim: 4,
|
||||
linear_value_head_dim: 4,
|
||||
linear_conv_kernel_dim: 4,
|
||||
};
|
||||
|
||||
// Build a synthetic VarBuilder with all-zeros weights.
|
||||
// Easier path: skip the load and construct GatedDeltaNet
|
||||
// manually by hand-rolling the Linear/Tensor inputs.
|
||||
let zeros = |shape: &[usize]| Tensor::zeros(shape, dtype, &dev).unwrap();
|
||||
let key_dim = cfg.linear_key_head_dim * cfg.linear_num_key_heads;
|
||||
let value_dim = cfg.linear_value_head_dim * cfg.linear_num_value_heads;
|
||||
let conv_dim = key_dim * 2 + value_dim;
|
||||
let mut net = GatedDeltaNet {
|
||||
in_proj_qkv: Linear::new(zeros(&[conv_dim, cfg.hidden_size]), None),
|
||||
in_proj_z: Linear::new(zeros(&[value_dim, cfg.hidden_size]), None),
|
||||
in_proj_b: Linear::new(zeros(&[cfg.linear_num_value_heads, cfg.hidden_size]), None),
|
||||
in_proj_a: Linear::new(zeros(&[cfg.linear_num_value_heads, cfg.hidden_size]), None),
|
||||
out_proj: Linear::new(zeros(&[cfg.hidden_size, value_dim]), None),
|
||||
conv1d_weight: zeros(&[conv_dim, 1, cfg.linear_conv_kernel_dim]),
|
||||
dt_bias: zeros(&[cfg.linear_num_value_heads]),
|
||||
a_log: zeros(&[cfg.linear_num_value_heads]),
|
||||
norm: {
|
||||
let weight = Tensor::ones(&[cfg.linear_value_head_dim], dtype, &dev).unwrap();
|
||||
Qwen3_5RmsNormGated::from_weight(weight, cfg.rms_norm_eps)
|
||||
},
|
||||
num_v_heads: cfg.linear_num_value_heads,
|
||||
num_k_heads: cfg.linear_num_key_heads,
|
||||
head_k_dim: cfg.linear_key_head_dim,
|
||||
head_v_dim: cfg.linear_value_head_dim,
|
||||
key_dim,
|
||||
value_dim,
|
||||
conv_dim,
|
||||
conv_kernel_size: cfg.linear_conv_kernel_dim,
|
||||
state: GatedDeltaNetState::default(),
|
||||
};
|
||||
|
||||
let x = Tensor::ones(&[b, l, cfg.hidden_size], dtype, &dev).unwrap();
|
||||
let y = net.forward(&x).unwrap();
|
||||
assert_eq!(y.dims(), &[b, l, cfg.hidden_size]);
|
||||
// All zero weights → output should be zero. Confirms no NaN/Inf
|
||||
// poisoning from the f32 promotions.
|
||||
let v: Vec<f32> = y.flatten_all().unwrap().to_vec1().unwrap();
|
||||
assert!(v.iter().all(|x| x.is_finite()));
|
||||
}
|
||||
}
|
||||
53
crates/neuron/src/harness/arch/qwen3_5/mlp.rs
Normal file
53
crates/neuron/src/harness/arch/qwen3_5/mlp.rs
Normal file
@@ -0,0 +1,53 @@
|
||||
//! SwiGLU MLP block for Qwen3-Next.
|
||||
//!
|
||||
//! Identical to plain Qwen3's MLP: `down(silu(gate(x)) * up(x))` with
|
||||
//! no bias on any of the three projections.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use candle_core::{Module, Tensor};
|
||||
use candle_nn::Linear;
|
||||
use candle_nn::var_builder::ShardedVarBuilder;
|
||||
|
||||
use super::TextConfig;
|
||||
|
||||
pub struct Qwen3_5MLP {
|
||||
gate_proj: Linear,
|
||||
up_proj: Linear,
|
||||
down_proj: Linear,
|
||||
}
|
||||
|
||||
impl Qwen3_5MLP {
|
||||
pub fn load(cfg: &TextConfig, vb: &ShardedVarBuilder) -> Result<Self> {
|
||||
let h = cfg.hidden_size;
|
||||
let i = cfg.intermediate_size;
|
||||
let gate_proj = load_linear_no_bias(vb, "gate_proj", h, i)?;
|
||||
let up_proj = load_linear_no_bias(vb, "up_proj", h, i)?;
|
||||
let down_proj = load_linear_no_bias(vb, "down_proj", i, h)?;
|
||||
Ok(Self {
|
||||
gate_proj,
|
||||
up_proj,
|
||||
down_proj,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Qwen3_5MLP {
|
||||
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
|
||||
let lhs = candle_nn::ops::silu(&self.gate_proj.forward(x)?)?;
|
||||
let rhs = self.up_proj.forward(x)?;
|
||||
self.down_proj.forward(&(lhs * rhs)?)
|
||||
}
|
||||
}
|
||||
|
||||
fn load_linear_no_bias(
|
||||
vb: &ShardedVarBuilder,
|
||||
name: &str,
|
||||
in_dim: usize,
|
||||
out_dim: usize,
|
||||
) -> Result<Linear> {
|
||||
let weight = vb
|
||||
.pp(name)
|
||||
.get((out_dim, in_dim), "weight")
|
||||
.with_context(|| format!("load '{}/{name}/weight'", vb.prefix()))?;
|
||||
Ok(Linear::new(weight, None))
|
||||
}
|
||||
397
crates/neuron/src/harness/arch/qwen3_5/mod.rs
Normal file
397
crates/neuron/src/harness/arch/qwen3_5/mod.rs
Normal file
@@ -0,0 +1,397 @@
|
||||
//! Qwen3-Next (`model_type = "qwen3_5"`) architecture — Qwen3.6's
|
||||
//! upstream architecture revision.
|
||||
//!
|
||||
//! ## Naming
|
||||
//!
|
||||
//! The model release this targets is `Qwen/Qwen3.6-*` but the
|
||||
//! architecture name in HuggingFace's `config.json` is `qwen3_5`.
|
||||
//! mistralrs calls the same architecture `qwen3_next`; that label
|
||||
//! ages poorly the next time Qwen ship a new arch, so we key on the
|
||||
//! canonical `qwen3_5` from the model's own config.
|
||||
//!
|
||||
//! ## Status
|
||||
//!
|
||||
//! **Single-GPU dense path is real**. Both attention flavours
|
||||
//! (`full_attention` with the output-gated GQA causal attention and
|
||||
//! `linear_attention` with the Gated DeltaNet recurrent block) are
|
||||
//! implemented. The model loads from upstream safetensors via the
|
||||
//! existing `load_arch_dense` dispatch and runs forward end to end.
|
||||
//!
|
||||
//! Numerical correctness vs the reference Python is **not yet
|
||||
//! validated** — the structural code path is right, weight tensor
|
||||
//! names match the upstream layout, shapes flow through cleanly, but
|
||||
//! the Tbilisi probe (and any other downstream test) is the next
|
||||
//! step. Likely places a bug would surface:
|
||||
//! - Per-rank vs per-token-position offsets in the recurrent delta
|
||||
//! rule (`linear_attn.rs`).
|
||||
//! - Off-by-one in the conv state continuation across decode steps.
|
||||
//! - RoPE phase mismatch from MRoPE simplification (we treat the
|
||||
//! three position grids as collapsed, which is correct only for
|
||||
//! text-only inference).
|
||||
//!
|
||||
//! ## Submodules
|
||||
//!
|
||||
//! - [`rmsnorm`] — `Qwen3_5RmsNorm` (`(1+w)*x` variant), the
|
||||
//! `Qwen3_5RmsNormGated` used after the delta rule, and the
|
||||
//! `l2norm` helper.
|
||||
//! - [`rope`] — text-side rotary embedding (mrope simplified, GLM
|
||||
//! rotate-half).
|
||||
//! - [`mlp`] — SwiGLU MLP (gate/up/down, no bias).
|
||||
//! - [`full_attn`] — `Qwen3_5Attention` with the output-gate
|
||||
//! widening on `q_proj`.
|
||||
//! - [`linear_attn`] — `GatedDeltaNet` recurrent delta-rule block
|
||||
//! (causal depthwise Conv1d → silu → split → L2norm → per-token
|
||||
//! delta rule → RMSNormGated → out_proj).
|
||||
//! - [`decoder`] — `Qwen3_5DecoderLayer` dispatching to one of the
|
||||
//! two attention flavours per layer index.
|
||||
//!
|
||||
//! ## Open work
|
||||
//!
|
||||
//! - **TP variant.** `harness/tp/tp_qwen3_5.rs` is the next step.
|
||||
//! Sharding strategy diverges by layer type:
|
||||
//! - Full-attention layers: column-parallel q/k/v (including the
|
||||
//! gate half of `q_proj`) + row-parallel `o_proj`, mirroring
|
||||
//! `tp_qwen3.rs`.
|
||||
//! - Linear-attention layers: the recurrent state is per-V-head, so
|
||||
//! V-head-dimension sharding works cleanly — split `num_v_heads`
|
||||
//! across ranks (`num_v_heads / world_size` per rank), shard
|
||||
//! `in_proj_qkv` / `in_proj_z` / `in_proj_b` / `in_proj_a` along
|
||||
//! the V-head dim, and row-parallel `out_proj`. The `A_log` /
|
||||
//! `dt_bias` per-head params shard with the heads.
|
||||
//!
|
||||
//! - **Chunked delta-rule prefill.** `linear_attn.rs` runs the
|
||||
//! per-token recurrent path for prefill too — correct but O(L).
|
||||
//! Porting `torch_chunk_gated_delta_rule` (chunk_size=64) speeds
|
||||
//! prefill substantially with no surface change.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use candle_core::{DType, Device, IndexOp, Module, Tensor};
|
||||
use candle_nn::Embedding;
|
||||
use candle_nn::Linear;
|
||||
use candle_nn::var_builder::ShardedVarBuilder;
|
||||
use serde::Deserialize;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub mod decoder;
|
||||
pub mod full_attn;
|
||||
pub mod linear_attn;
|
||||
pub mod mlp;
|
||||
pub mod rmsnorm;
|
||||
pub mod rope;
|
||||
|
||||
use decoder::Qwen3_5DecoderLayer;
|
||||
use rmsnorm::Qwen3_5RmsNorm;
|
||||
use rope::RotaryEmbedding;
|
||||
|
||||
/// `model_type` we deserialise from `config.json`. Const so the
|
||||
/// dispatch in `candle.rs::load_arch_dense` can pattern-match without
|
||||
/// magic strings.
|
||||
pub const MODEL_TYPE: &str = "qwen3_5";
|
||||
|
||||
/// Top-level shape of Qwen3-Next's `config.json`. The real
|
||||
/// hyperparameters live in `text_config`; the rest is multimodal /
|
||||
/// tokeniser glue we don't need for the language-model forward.
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct Config {
|
||||
/// Always `"qwen3_5"` for this architecture. Kept on the struct
|
||||
/// so the (eventual) dispatch / logging code can show it without
|
||||
/// re-parsing the JSON.
|
||||
pub model_type: String,
|
||||
/// The text-side hyperparameters. Everything we actually need.
|
||||
pub text_config: TextConfig,
|
||||
}
|
||||
|
||||
/// Inner config (the `text_config` block). Mirrors the Qwen3 layout
|
||||
/// but with the extras Qwen3-Next adds (`attn_output_gate`,
|
||||
/// `layer_types`, `full_attention_interval`, larger `head_dim`).
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct TextConfig {
|
||||
pub vocab_size: usize,
|
||||
pub hidden_size: usize,
|
||||
pub intermediate_size: usize,
|
||||
pub num_hidden_layers: usize,
|
||||
pub num_attention_heads: usize,
|
||||
pub num_key_value_heads: usize,
|
||||
pub head_dim: usize,
|
||||
pub max_position_embeddings: usize,
|
||||
/// Nested RoPE settings. Qwen3-Next puts `rope_theta` and
|
||||
/// `partial_rotary_factor` inside this block rather than at the
|
||||
/// top level — important because the partial rotary means only
|
||||
/// `head_dim * partial_rotary_factor` dims get RoPE applied (the
|
||||
/// rest pass through unchanged).
|
||||
pub rope_parameters: RopeParameters,
|
||||
pub rms_norm_eps: f64,
|
||||
#[serde(default)]
|
||||
pub tie_word_embeddings: bool,
|
||||
|
||||
/// New in Qwen3-Next: a sigmoid gate multiplied into the attention
|
||||
/// output before the o_proj. The Python reference applies it
|
||||
/// pointwise after softmax+matmul.
|
||||
#[serde(default)]
|
||||
pub attn_output_gate: bool,
|
||||
|
||||
/// One entry per decoder layer; values are `"full_attention"` or
|
||||
/// `"linear_attention"`. Length must equal `num_hidden_layers`.
|
||||
/// `full_attention_interval` is a derived hint (every 4th layer
|
||||
/// by default) — `layer_types` is authoritative.
|
||||
#[serde(default)]
|
||||
pub layer_types: Vec<String>,
|
||||
|
||||
/// Hint for the layer-type pattern (defaults to 4). Kept for
|
||||
/// logging / validation; the forward dispatches on `layer_types`.
|
||||
#[serde(default)]
|
||||
pub full_attention_interval: Option<usize>,
|
||||
|
||||
/// Hidden activation (`"silu"` for Qwen3-Next). Used by the MLP
|
||||
/// and the linear-attention conv1d.
|
||||
#[serde(default = "default_hidden_act")]
|
||||
pub hidden_act: String,
|
||||
|
||||
// --- Gated DeltaNet (linear-attention) hyperparams -----------------
|
||||
/// Per-layer linear-attention V-head count (Qwen3.6-27B: 48).
|
||||
/// More V-heads than K-heads is fine — query/key get
|
||||
/// `repeat_interleave`'d to match before the delta rule.
|
||||
#[serde(default)]
|
||||
pub linear_num_value_heads: usize,
|
||||
/// Per-layer linear-attention K-head count (Qwen3.6-27B: 16).
|
||||
#[serde(default)]
|
||||
pub linear_num_key_heads: usize,
|
||||
/// Per-head key dimension for the linear-attention path
|
||||
/// (Qwen3.6-27B: 128). Separate from `head_dim` which the
|
||||
/// full-attention layers use.
|
||||
#[serde(default)]
|
||||
pub linear_key_head_dim: usize,
|
||||
/// Per-head value dimension for the linear-attention path
|
||||
/// (Qwen3.6-27B: 128).
|
||||
#[serde(default)]
|
||||
pub linear_value_head_dim: usize,
|
||||
/// Causal Conv1d kernel size used before the delta rule
|
||||
/// (Qwen3.6-27B: 4).
|
||||
#[serde(default)]
|
||||
pub linear_conv_kernel_dim: usize,
|
||||
}
|
||||
|
||||
fn default_hidden_act() -> String {
|
||||
"silu".into()
|
||||
}
|
||||
|
||||
/// Nested `rope_parameters` block from a Qwen3-Next `config.json`.
|
||||
/// `mrope_section` and `mrope_interleaved` are accepted via the
|
||||
/// `#[serde(default)]` flatten-tolerance below but ignored — we treat
|
||||
/// MRoPE as plain RoPE for text-only inference (the three position
|
||||
/// grids carry identical ids when there's no vision input, so the
|
||||
/// interleaving is a no-op).
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct RopeParameters {
|
||||
/// Base for the inverse-frequency computation. Qwen3.6: 10_000_000.
|
||||
#[serde(default = "default_rope_theta")]
|
||||
pub rope_theta: f64,
|
||||
/// Fraction of `head_dim` that gets the rotation applied. The
|
||||
/// remaining `head_dim * (1 - partial_rotary_factor)` dims pass
|
||||
/// through unchanged. Qwen3.6 / Qwen3.5: 0.25.
|
||||
#[serde(default = "default_partial_rotary_factor")]
|
||||
pub partial_rotary_factor: f32,
|
||||
/// `"default"` for the standard inv_freq RoPE; other values (e.g.
|
||||
/// `"linear"`, `"dynamic"`) are upstream-supported but not yet
|
||||
/// implemented here.
|
||||
#[serde(default)]
|
||||
pub rope_type: Option<String>,
|
||||
}
|
||||
|
||||
fn default_rope_theta() -> f64 {
|
||||
10_000.0
|
||||
}
|
||||
|
||||
fn default_partial_rotary_factor() -> f32 {
|
||||
1.0
|
||||
}
|
||||
|
||||
/// Qwen3-Next base transformer (embedding + decoder stack + final
|
||||
/// norm). Public so a TP variant in `harness/tp/tp_qwen3_5.rs` can
|
||||
/// also build on it later — for now only `Qwen3_5ForCausalLM` is the
|
||||
/// loaded handle.
|
||||
pub struct Qwen3_5Model {
|
||||
embed_tokens: Embedding,
|
||||
layers: Vec<Qwen3_5DecoderLayer>,
|
||||
norm: Qwen3_5RmsNorm,
|
||||
device: Device,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl Qwen3_5Model {
|
||||
pub fn load(cfg: &TextConfig, vb: &ShardedVarBuilder) -> Result<Self> {
|
||||
let dtype = vb.dtype();
|
||||
let device = vb.device().clone();
|
||||
|
||||
// Qwen3-Next is a multimodal architecture whose text core lives
|
||||
// under `model.language_model.*` — sibling to `model.visual.*`
|
||||
// (the vision tower) and to top-level `lm_head` / `mtp.*`.
|
||||
// Every text-side tensor in the safetensors files is under
|
||||
// this prefix; we ignore the vision and MTP weights for
|
||||
// language-model inference.
|
||||
let text_vb = vb.pp("model.language_model");
|
||||
|
||||
let embed_vb = text_vb.pp("embed_tokens");
|
||||
let embed_weight = embed_vb
|
||||
.get((cfg.vocab_size, cfg.hidden_size), "weight")
|
||||
.with_context(|| format!("load '{}/weight'", embed_vb.prefix()))?;
|
||||
let embed_tokens = Embedding::new(embed_weight, cfg.hidden_size);
|
||||
|
||||
let rotary = Arc::new(RotaryEmbedding::new(dtype, cfg, &device)?);
|
||||
|
||||
if cfg.layer_types.len() != cfg.num_hidden_layers {
|
||||
anyhow::bail!(
|
||||
"config.text_config.layer_types must have num_hidden_layers ({}) entries; \
|
||||
got {}",
|
||||
cfg.num_hidden_layers,
|
||||
cfg.layer_types.len()
|
||||
);
|
||||
}
|
||||
|
||||
let vb_l = text_vb.pp("layers");
|
||||
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||
for i in 0..cfg.num_hidden_layers {
|
||||
layers.push(Qwen3_5DecoderLayer::load(
|
||||
cfg,
|
||||
rotary.clone(),
|
||||
i,
|
||||
&vb_l.pp(i),
|
||||
)?);
|
||||
}
|
||||
|
||||
let norm = Qwen3_5RmsNorm::load(&text_vb.pp("norm"), cfg.hidden_size, cfg.rms_norm_eps)?;
|
||||
|
||||
Ok(Self {
|
||||
embed_tokens,
|
||||
layers,
|
||||
norm,
|
||||
device,
|
||||
dtype,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn embed_weight(&self) -> &Tensor {
|
||||
self.embed_tokens.embeddings()
|
||||
}
|
||||
|
||||
pub fn clear_kv_cache(&mut self) {
|
||||
for l in &mut self.layers {
|
||||
l.clear_kv_cache();
|
||||
}
|
||||
}
|
||||
|
||||
fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> candle_core::Result<Tensor> {
|
||||
let minf = f32::NEG_INFINITY;
|
||||
let mask: Vec<_> = (0..tgt)
|
||||
.flat_map(|i| (0..(tgt + offset)).map(move |j| if j <= i + offset { 0. } else { minf }))
|
||||
.collect();
|
||||
Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype)
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
|
||||
let (b, l) = input.dims2()?;
|
||||
let mut h = self.embed_tokens.forward(input)?;
|
||||
// Causal mask only needed for L > 1 prefill; full-attention
|
||||
// layers consume it via broadcast_add. Linear-attention layers
|
||||
// ignore the mask.
|
||||
let causal = if l == 1 {
|
||||
None
|
||||
} else {
|
||||
Some(self.causal_mask(b, l, offset)?)
|
||||
};
|
||||
for layer in &mut self.layers {
|
||||
h = layer.forward(&h, causal.as_ref(), offset)?;
|
||||
}
|
||||
self.norm.forward(&h)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Qwen3_5ForCausalLM {
|
||||
base: Qwen3_5Model,
|
||||
lm_head: Linear,
|
||||
}
|
||||
|
||||
impl Qwen3_5ForCausalLM {
|
||||
pub fn new(config: Config, vb: ShardedVarBuilder) -> Result<Self> {
|
||||
let cfg = &config.text_config;
|
||||
let base = Qwen3_5Model::load(cfg, &vb)?;
|
||||
let lm_head = if cfg.tie_word_embeddings {
|
||||
Linear::new(base.embed_weight().clone(), None)
|
||||
} else {
|
||||
let weight = vb
|
||||
.pp("lm_head")
|
||||
.get((cfg.vocab_size, cfg.hidden_size), "weight")
|
||||
.with_context(|| format!("load '{}/lm_head/weight'", vb.prefix()))?;
|
||||
Linear::new(weight, None)
|
||||
};
|
||||
Ok(Self { base, lm_head })
|
||||
}
|
||||
|
||||
/// `input`: token-id tensor of shape `(B, L)`. Returns logits at
|
||||
/// the last position, shape `(B, 1, vocab_size)` — same contract
|
||||
/// as `qwen3::ModelForCausalLM::forward` so the harness's
|
||||
/// `squeeze_to_vocab` helper handles both uniformly.
|
||||
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
|
||||
let (_, l) = input.dims2()?;
|
||||
let hidden = self.base.forward(input, offset)?;
|
||||
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
|
||||
}
|
||||
|
||||
pub fn clear_kv_cache(&mut self) {
|
||||
self.base.clear_kv_cache();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
/// Confirms we can deserialise the real upstream config shape.
|
||||
/// Sample taken from `Qwen/Qwen3.6-27B/config.json`, trimmed to
|
||||
/// the fields the architecture cares about. Note `rope_theta` and
|
||||
/// `partial_rotary_factor` are nested under `rope_parameters` —
|
||||
/// Qwen3-Next does NOT have a top-level `rope_theta`.
|
||||
#[test]
|
||||
fn config_deserialises_the_real_qwen3_6_shape() {
|
||||
let raw = r#"{
|
||||
"architectures": ["Qwen3_5ForConditionalGeneration"],
|
||||
"model_type": "qwen3_5",
|
||||
"image_token_id": 248056,
|
||||
"language_model_only": false,
|
||||
"text_config": {
|
||||
"vocab_size": 248064,
|
||||
"hidden_size": 5120,
|
||||
"intermediate_size": 17408,
|
||||
"num_hidden_layers": 64,
|
||||
"num_attention_heads": 64,
|
||||
"num_key_value_heads": 8,
|
||||
"head_dim": 256,
|
||||
"max_position_embeddings": 32768,
|
||||
"rope_parameters": {
|
||||
"mrope_interleaved": true,
|
||||
"mrope_section": [11, 11, 10],
|
||||
"partial_rotary_factor": 0.25,
|
||||
"rope_theta": 10000000,
|
||||
"rope_type": "default"
|
||||
},
|
||||
"rms_norm_eps": 1e-6,
|
||||
"tie_word_embeddings": false,
|
||||
"attn_output_gate": true,
|
||||
"full_attention_interval": 4,
|
||||
"layer_types": [
|
||||
"linear_attention", "linear_attention",
|
||||
"linear_attention", "full_attention"
|
||||
]
|
||||
}
|
||||
}"#;
|
||||
let cfg: Config = serde_json::from_str(raw).expect("parse Qwen3.6 config");
|
||||
assert_eq!(cfg.model_type, "qwen3_5");
|
||||
assert_eq!(cfg.text_config.hidden_size, 5120);
|
||||
assert_eq!(cfg.text_config.head_dim, 256);
|
||||
assert!(cfg.text_config.attn_output_gate);
|
||||
assert_eq!(cfg.text_config.full_attention_interval, Some(4));
|
||||
assert_eq!(cfg.text_config.layer_types.len(), 4);
|
||||
assert_eq!(cfg.text_config.rope_parameters.rope_theta, 10_000_000.0);
|
||||
assert!((cfg.text_config.rope_parameters.partial_rotary_factor - 0.25).abs() < 1e-6);
|
||||
}
|
||||
}
|
||||
161
crates/neuron/src/harness/arch/qwen3_5/rmsnorm.rs
Normal file
161
crates/neuron/src/harness/arch/qwen3_5/rmsnorm.rs
Normal file
@@ -0,0 +1,161 @@
|
||||
//! Norm primitives for Qwen3-Next.
|
||||
//!
|
||||
//! Two reasons we can't reuse `candle_nn::RmsNorm` directly:
|
||||
//!
|
||||
//! 1. **`(1.0 + weight)` scaling.** Qwen3-Next's `Qwen3_5RMSNorm`
|
||||
//! initialises `weight` to zeros and applies `(1.0 + weight)` to
|
||||
//! the normalised vector. `candle_nn::RmsNorm` applies `weight`
|
||||
//! directly. The two are equivalent only when the operator has
|
||||
//! pre-shifted the weights — the upstream checkpoints have not. See
|
||||
//! `huggingface/transformers#29402` for the upstream PR that
|
||||
//! introduced the `(1 + w)` form to recover from the zero-init.
|
||||
//!
|
||||
//! 2. **Gated variant.** The linear-attention layer post-normalises
|
||||
//! its output by an RMSNorm *gated* with a per-element SiLU on
|
||||
//! a sibling `z` projection — fused for numerical reasons (the
|
||||
//! norm's float32 promotion has to happen before the SiLU
|
||||
//! multiply). Not a single existing candle op.
|
||||
//!
|
||||
//! Both ops accept inputs in any compute dtype; promotion to f32 for
|
||||
//! the variance calculation matches the Python reference.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use candle_core::{D, Module, Tensor};
|
||||
use candle_nn::var_builder::ShardedVarBuilder;
|
||||
|
||||
/// L2-normalise along the last dim with a small epsilon. Matches the
|
||||
/// `l2norm` helper in `transformers/models/qwen3_5/modeling_qwen3_5.py`
|
||||
/// — `x * rsqrt(sum(x*x) + eps)`. The linear-attention path uses this
|
||||
/// on Q and K before the delta rule when
|
||||
/// `use_qk_l2norm_in_kernel=True` (which Qwen3-Next always sets).
|
||||
pub fn l2norm(x: &Tensor, eps: f32) -> candle_core::Result<Tensor> {
|
||||
let dtype = x.dtype();
|
||||
let x_f32 = x.to_dtype(candle_core::DType::F32)?;
|
||||
let sq = x_f32.sqr()?;
|
||||
let sum = sq.sum_keepdim(D::Minus1)?;
|
||||
let inv = (sum + eps as f64)?.sqrt()?.recip()?;
|
||||
x_f32.broadcast_mul(&inv)?.to_dtype(dtype)
|
||||
}
|
||||
|
||||
/// Qwen3-Next's RMSNorm. Stores the raw weight tensor; forward applies
|
||||
/// `(1.0 + weight) * x_normed`.
|
||||
pub struct Qwen3_5RmsNorm {
|
||||
weight: Tensor,
|
||||
eps: f32,
|
||||
size: usize,
|
||||
}
|
||||
|
||||
impl Qwen3_5RmsNorm {
|
||||
/// Load `weight` from the ShardedVarBuilder. `vb` should already be
|
||||
/// `.pp(...)`-ed to the norm's tensor prefix.
|
||||
pub fn load(vb: &ShardedVarBuilder, size: usize, eps: f64) -> Result<Self> {
|
||||
let weight = vb
|
||||
.get(size, "weight")
|
||||
.with_context(|| format!("load '{}/weight'", vb.prefix()))?;
|
||||
Ok(Self {
|
||||
weight,
|
||||
eps: eps as f32,
|
||||
size,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Qwen3_5RmsNorm {
|
||||
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
|
||||
let dtype = x.dtype();
|
||||
let x_f32 = x.to_dtype(candle_core::DType::F32)?;
|
||||
let var = x_f32.sqr()?.mean_keepdim(D::Minus1)?;
|
||||
let normed = x_f32.broadcast_mul(&(var + self.eps as f64)?.sqrt()?.recip()?)?;
|
||||
// Promote weight to f32 and shift by 1.0 *before* multiplying.
|
||||
// Doing the (1 + w) operation in fp16 lands at -inf for the
|
||||
// bottom-of-range weights at load time.
|
||||
let w_f32 = self.weight.to_dtype(candle_core::DType::F32)?;
|
||||
let scale = (w_f32 + 1.0_f64)?;
|
||||
normed.broadcast_mul(&scale)?.to_dtype(dtype)
|
||||
}
|
||||
}
|
||||
|
||||
/// Gated RMSNorm used at the tail of `Qwen3_5GatedDeltaNet`. Equivalent
|
||||
/// to `x_normed * weight * silu(gate)` but with both the norm and the
|
||||
/// gate evaluated in float32 to avoid mid-pipeline underflow.
|
||||
///
|
||||
/// Note: unlike `Qwen3_5RmsNorm`, this variant matches the Python
|
||||
/// reference's `Qwen3_5RMSNormGated` which uses `weight` directly (not
|
||||
/// `1.0 + weight`).
|
||||
pub struct Qwen3_5RmsNormGated {
|
||||
weight: Tensor,
|
||||
eps: f32,
|
||||
size: usize,
|
||||
}
|
||||
|
||||
impl Qwen3_5RmsNormGated {
|
||||
pub fn load(vb: &ShardedVarBuilder, size: usize, eps: f64) -> Result<Self> {
|
||||
let weight = vb
|
||||
.get(size, "weight")
|
||||
.with_context(|| format!("load '{}/weight'", vb.prefix()))?;
|
||||
Ok(Self {
|
||||
weight,
|
||||
eps: eps as f32,
|
||||
size,
|
||||
})
|
||||
}
|
||||
|
||||
/// Direct constructor — used by unit tests that build a layer
|
||||
/// without going through a VarBuilder.
|
||||
#[cfg(test)]
|
||||
pub(crate) fn from_weight(weight: Tensor, eps: f64) -> Self {
|
||||
let size = weight.dims()[0];
|
||||
Self {
|
||||
weight,
|
||||
eps: eps as f32,
|
||||
size,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
|
||||
/// `x` and `gate` share the same last-dim shape (`size`).
|
||||
pub fn forward(&self, x: &Tensor, gate: &Tensor) -> candle_core::Result<Tensor> {
|
||||
let dtype = x.dtype();
|
||||
let x_f32 = x.to_dtype(candle_core::DType::F32)?;
|
||||
let var = x_f32.sqr()?.mean_keepdim(D::Minus1)?;
|
||||
let normed = x_f32.broadcast_mul(&(var + self.eps as f64)?.sqrt()?.recip()?)?;
|
||||
let w = self.weight.to_dtype(candle_core::DType::F32)?;
|
||||
let out = normed.broadcast_mul(&w)?;
|
||||
// SiLU on the float32 gate, multiply back into the normed
|
||||
// tensor, then cast to the model dtype.
|
||||
let g = gate.to_dtype(candle_core::DType::F32)?;
|
||||
let silu_gate = candle_nn::ops::silu(&g)?;
|
||||
(out * silu_gate)?.to_dtype(dtype)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use candle_core::Device;
|
||||
|
||||
#[test]
|
||||
fn l2norm_matches_hand_calc() {
|
||||
let x = Tensor::new(&[3.0_f32, 4.0_f32], &Device::Cpu).unwrap();
|
||||
let out = l2norm(&x, 1e-6).unwrap();
|
||||
let v: Vec<f32> = out.to_vec1().unwrap();
|
||||
// |x| = 5, so x/|x| = [0.6, 0.8] (eps is tiny).
|
||||
assert!((v[0] - 0.6).abs() < 1e-4);
|
||||
assert!((v[1] - 0.8).abs() < 1e-4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn l2norm_zero_vector_is_safe_via_epsilon() {
|
||||
let x = Tensor::new(&[0.0_f32, 0.0_f32], &Device::Cpu).unwrap();
|
||||
let out = l2norm(&x, 1e-6).unwrap();
|
||||
let v: Vec<f32> = out.to_vec1().unwrap();
|
||||
assert!(v.iter().all(|x| x.is_finite()));
|
||||
}
|
||||
}
|
||||
114
crates/neuron/src/harness/arch/qwen3_5/rope.rs
Normal file
114
crates/neuron/src/harness/arch/qwen3_5/rope.rs
Normal file
@@ -0,0 +1,114 @@
|
||||
//! Rotary position embedding for Qwen3-Next's full-attention layers.
|
||||
//!
|
||||
//! Qwen3.6 ships with MRoPE (multimodal RoPE) machinery in the
|
||||
//! reference Python — three position grids interleaved per
|
||||
//! `mrope_section`. For text-only inference all three grids carry the
|
||||
//! same position ids and the interleave is a no-op, so this module
|
||||
//! implements the plain (non-mrope) flavour: the standard inv_freq
|
||||
//! cosine/sine tables driven by `rope_theta` and `head_dim`.
|
||||
//!
|
||||
//! Rotation flavour: **GLM-style** rotate-half (the second half of the
|
||||
//! head dim is negated and swapped into the first). The reference
|
||||
//! Python uses `apply_rotary_pos_emb` with `rotate_half`; candle's
|
||||
//! `rope_slow` is the matching helper.
|
||||
|
||||
use anyhow::Result;
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
|
||||
use super::TextConfig;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RotaryEmbedding {
|
||||
sin: Tensor,
|
||||
cos: Tensor,
|
||||
/// Number of dims at the head's leading edge that the rotation
|
||||
/// covers. The remaining `head_dim - rotary_dim` dims pass through
|
||||
/// unchanged. Qwen3-Next uses `partial_rotary_factor = 0.25`, so
|
||||
/// for `head_dim = 256` only 64 dims rotate.
|
||||
rotary_dim: usize,
|
||||
head_dim: usize,
|
||||
}
|
||||
|
||||
impl RotaryEmbedding {
|
||||
pub fn new(dtype: DType, cfg: &TextConfig, dev: &Device) -> Result<Self> {
|
||||
let head_dim = cfg.head_dim;
|
||||
let rope = &cfg.rope_parameters;
|
||||
let rotary_dim = (head_dim as f32 * rope.partial_rotary_factor) as usize;
|
||||
if !rotary_dim.is_multiple_of(2) {
|
||||
anyhow::bail!(
|
||||
"rotary_dim = head_dim * partial_rotary_factor = {head_dim} * {} = {rotary_dim} \
|
||||
must be even (cos/sin are paired)",
|
||||
rope.partial_rotary_factor
|
||||
);
|
||||
}
|
||||
if rotary_dim == 0 {
|
||||
anyhow::bail!(
|
||||
"rotary_dim = 0 (partial_rotary_factor = {} too small)",
|
||||
rope.partial_rotary_factor
|
||||
);
|
||||
}
|
||||
let max_seq_len = cfg.max_position_embeddings;
|
||||
let inv_freq: Vec<f32> = (0..rotary_dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / rope.rope_theta.powf(i as f64 / rotary_dim as f64) as f32)
|
||||
.collect();
|
||||
let n = inv_freq.len();
|
||||
let inv_freq = Tensor::from_vec(inv_freq, (1, n), dev)?.to_dtype(DType::F32)?;
|
||||
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
||||
.to_dtype(DType::F32)?
|
||||
.reshape((max_seq_len, 1))?;
|
||||
let freqs = t.matmul(&inv_freq)?;
|
||||
Ok(Self {
|
||||
sin: freqs.sin()?.to_dtype(dtype)?,
|
||||
cos: freqs.cos()?.to_dtype(dtype)?,
|
||||
rotary_dim,
|
||||
head_dim,
|
||||
})
|
||||
}
|
||||
|
||||
/// Apply RoPE to q, k.
|
||||
///
|
||||
/// `q`, `k` shape: `(B, H, L, head_dim)`. `offset` is the index
|
||||
/// into the cached cos/sin table — the position of the first token
|
||||
/// in the current step.
|
||||
///
|
||||
/// When `rotary_dim < head_dim` the rotation is applied only to the
|
||||
/// first `rotary_dim` dims of each head; the tail passes through
|
||||
/// unchanged (matches the reference Python's
|
||||
/// `apply_rotary_pos_emb` with non-trivial `partial_rotary_factor`).
|
||||
pub fn apply(
|
||||
&self,
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
offset: usize,
|
||||
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||
let (_, _, seq_len, head_dim_in) = q.dims4()?;
|
||||
debug_assert_eq!(head_dim_in, self.head_dim, "q head_dim mismatch");
|
||||
let cos = self.cos.narrow(0, offset, seq_len)?;
|
||||
let sin = self.sin.narrow(0, offset, seq_len)?;
|
||||
if self.rotary_dim == self.head_dim {
|
||||
// Full rotation.
|
||||
let q_embed = candle_nn::rotary_emb::rope_slow(&q.contiguous()?, &cos, &sin)?;
|
||||
let k_embed = candle_nn::rotary_emb::rope_slow(&k.contiguous()?, &cos, &sin)?;
|
||||
Ok((q_embed, k_embed))
|
||||
} else {
|
||||
// Partial rotation: narrow → rotate → cat the untouched tail.
|
||||
let tail = self.head_dim - self.rotary_dim;
|
||||
let q_rot = q
|
||||
.narrow(candle_core::D::Minus1, 0, self.rotary_dim)?
|
||||
.contiguous()?;
|
||||
let q_pass = q.narrow(candle_core::D::Minus1, self.rotary_dim, tail)?;
|
||||
let k_rot = k
|
||||
.narrow(candle_core::D::Minus1, 0, self.rotary_dim)?
|
||||
.contiguous()?;
|
||||
let k_pass = k.narrow(candle_core::D::Minus1, self.rotary_dim, tail)?;
|
||||
let q_rotated = candle_nn::rotary_emb::rope_slow(&q_rot, &cos, &sin)?;
|
||||
let k_rotated = candle_nn::rotary_emb::rope_slow(&k_rot, &cos, &sin)?;
|
||||
let q_embed =
|
||||
Tensor::cat(&[&q_rotated, &q_pass.contiguous()?], candle_core::D::Minus1)?;
|
||||
let k_embed =
|
||||
Tensor::cat(&[&k_rotated, &k_pass.contiguous()?], candle_core::D::Minus1)?;
|
||||
Ok((q_embed, k_embed))
|
||||
}
|
||||
}
|
||||
}
|
||||
4026
crates/neuron/src/harness/candle.rs
Normal file
4026
crates/neuron/src/harness/candle.rs
Normal file
File diff suppressed because it is too large
Load Diff
392
crates/neuron/src/harness/chat_template.rs
Normal file
392
crates/neuron/src/harness/chat_template.rs
Normal file
@@ -0,0 +1,392 @@
|
||||
//! Chat-template rendering for the model-supplied Jinja templates
|
||||
//! HuggingFace tokenizers ship in `tokenizer_config.json`.
|
||||
//!
|
||||
//! ## Background
|
||||
//!
|
||||
//! Every modern open-weight model bundles a `chat_template` field
|
||||
//! in its `tokenizer_config.json` — a Jinja2 template string that
|
||||
//! converts a sequence of `{role, content}` messages into the
|
||||
//! exact prompt the model was trained on. Examples:
|
||||
//!
|
||||
//! - Qwen3-Coder: `<|im_start|>{role}\n{content}<|im_end|>\n…`
|
||||
//! with conditional `enable_thinking` handling that injects an
|
||||
//! empty `<think>\n\n</think>` block when set false.
|
||||
//! - DeepSeek-R1: similar im_start framing with different special-
|
||||
//! token names.
|
||||
//! - Mistral / Magistral: a `[INST]` / `[/INST]` framing.
|
||||
//! - Claude / Llama: another shape again.
|
||||
//!
|
||||
//! Rendering the model's own template is the only way to get the
|
||||
//! *exact* prompt format the model was trained on plus the
|
||||
//! model-specific kwargs (`enable_thinking`, `tools`, …) without
|
||||
//! hardcoding per-model logic. The alternative — neuron's previous
|
||||
//! `format_qwen3_prompt` — was a hardcoded Qwen3 ChatML glue that
|
||||
//! ignored kwargs entirely.
|
||||
//!
|
||||
//! ## Scope
|
||||
//!
|
||||
//! This module is request-side only: it builds the prompt string
|
||||
//! the tokenizer ingests before inference. The reasoning- and
|
||||
//! tool-call-marker token routing (issues #6, #8) is response-side
|
||||
//! and stays in `wire::openai_chat` / the streaming inference
|
||||
//! loops.
|
||||
//!
|
||||
//! ## Fallback
|
||||
//!
|
||||
//! When the model's `tokenizer_config.json` is missing, doesn't
|
||||
//! parse, lacks a `chat_template`, or renders an error, the caller
|
||||
//! falls back to `format_qwen3_prompt`. The
|
||||
//! `NEURON_USE_CHAT_TEMPLATE=false` env var is a global kill
|
||||
//! switch — if a deploy goes sideways and the renderer is to
|
||||
//! blame, an operator can flip the env and restart neuron without
|
||||
//! shipping a new build.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use cortex_core::openai::{ChatMessage, MessageContent};
|
||||
use minijinja::Environment;
|
||||
use serde_json::Value;
|
||||
use std::path::Path;
|
||||
|
||||
/// Environment variable that, when set to `false`/`0`/`no`,
|
||||
/// forces every model to skip its `chat_template` and fall back
|
||||
/// to `format_qwen3_prompt`. Default (unset) is "use chat
|
||||
/// templates where available".
|
||||
pub const KILL_SWITCH_ENV: &str = "NEURON_USE_CHAT_TEMPLATE";
|
||||
|
||||
/// Read the global kill switch. `true` means chat templates are
|
||||
/// enabled; `false` forces the fallback path everywhere.
|
||||
pub fn chat_templates_enabled() -> bool {
|
||||
match std::env::var(KILL_SWITCH_ENV).ok().as_deref() {
|
||||
Some(s) => !matches!(
|
||||
s.trim().to_ascii_lowercase().as_str(),
|
||||
"false" | "0" | "no" | "off"
|
||||
),
|
||||
None => true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Convenience: probe for `tokenizer_config.json` in the same
|
||||
/// directory the tokenizer was loaded from. Both files come from
|
||||
/// the same HuggingFace snapshot in the hf-hub cache, so the
|
||||
/// sibling path is reliable.
|
||||
pub fn load_chat_template_alongside(tokenizer_json_path: &Path) -> Option<String> {
|
||||
let parent = tokenizer_json_path.parent()?;
|
||||
let config_path = parent.join("tokenizer_config.json");
|
||||
load_chat_template_from(&config_path)
|
||||
}
|
||||
|
||||
/// Best-effort load of `chat_template` from a HuggingFace
|
||||
/// `tokenizer_config.json`. Returns `None` when the file is
|
||||
/// absent, doesn't parse, or lacks the `chat_template` field —
|
||||
/// in all of those cases the caller falls back to
|
||||
/// `format_qwen3_prompt`. Warnings are logged so an operator can
|
||||
/// see why the fallback fired.
|
||||
pub fn load_chat_template_from(path: &Path) -> Option<String> {
|
||||
let text = match std::fs::read_to_string(path) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
tracing::debug!(
|
||||
path = %path.display(),
|
||||
error = %e,
|
||||
"chat_template: tokenizer_config.json absent or unreadable; falling back"
|
||||
);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
let value: Value = match serde_json::from_str(&text) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
path = %path.display(),
|
||||
error = %e,
|
||||
"chat_template: tokenizer_config.json failed to parse; falling back"
|
||||
);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
// Some tokenizer_config.json files carry `chat_template` as an
|
||||
// array of `{name, template}` objects (multi-template models —
|
||||
// tool-use variant, default variant). For now we pick the first
|
||||
// entry; future iterations could honour a name hint.
|
||||
match value.get("chat_template") {
|
||||
Some(Value::String(s)) => Some(s.clone()),
|
||||
Some(Value::Array(arr)) => {
|
||||
for entry in arr {
|
||||
if let Some(t) = entry.get("template").and_then(|v| v.as_str()) {
|
||||
return Some(t.to_string());
|
||||
}
|
||||
}
|
||||
tracing::warn!(
|
||||
path = %path.display(),
|
||||
"chat_template: array form had no usable template entry; falling back"
|
||||
);
|
||||
None
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Render the chat template into the prompt the model expects.
|
||||
///
|
||||
/// `template` is the raw Jinja string from `tokenizer_config.json`.
|
||||
/// `messages` is the conversation in order. `kwargs` is the
|
||||
/// `chat_template_kwargs` object the client supplied on the
|
||||
/// request (or `Value::Null` when absent). The function expands
|
||||
/// the kwargs into the Jinja context alongside the standard
|
||||
/// `messages` and `add_generation_prompt` variables HF templates
|
||||
/// expect.
|
||||
///
|
||||
/// `tools` is the request's `tools` array (or `Value::Null`).
|
||||
/// Some chat templates iterate it to emit native tool definitions
|
||||
/// (Qwen3-Coder's tool-use template, Mistral's [TOOL_DEFINITIONS]
|
||||
/// frame). We forward whatever the client sent without
|
||||
/// interpretation.
|
||||
pub fn render_chat_template(
|
||||
template: &str,
|
||||
messages: &[ChatMessage],
|
||||
tools: &Value,
|
||||
kwargs: &Value,
|
||||
) -> Result<String> {
|
||||
let mut env = Environment::new();
|
||||
// Compile the template against a fixed name so error messages
|
||||
// surface "chat_template" rather than `<template>`.
|
||||
env.add_template("chat_template", template)
|
||||
.context("compile chat_template")?;
|
||||
let tmpl = env.get_template("chat_template").unwrap();
|
||||
|
||||
// Convert our internal ChatMessage shape into the
|
||||
// `[{role, content}]` shape HF templates iterate. Text content
|
||||
// becomes a string; Parts becomes an array of content blocks.
|
||||
// The HF templates handle both shapes via `content is string`
|
||||
// checks or content-array iteration.
|
||||
let messages_json: Vec<Value> = messages
|
||||
.iter()
|
||||
.map(|m| {
|
||||
let content_value = match &m.content {
|
||||
MessageContent::Text(s) => Value::String(s.clone()),
|
||||
MessageContent::Parts(parts) => Value::Array(parts.clone()),
|
||||
};
|
||||
let mut obj = serde_json::Map::new();
|
||||
obj.insert("role".into(), Value::String(m.role.clone()));
|
||||
obj.insert("content".into(), content_value);
|
||||
// Forward extras (e.g. tool_calls on assistant turns,
|
||||
// tool_call_id on tool result turns). HF templates that
|
||||
// need them read e.g. `message.tool_calls`.
|
||||
if let Value::Object(extras) = &m.extra {
|
||||
for (k, v) in extras {
|
||||
obj.insert(k.clone(), v.clone());
|
||||
}
|
||||
}
|
||||
Value::Object(obj)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Build the kwargs context. Add base bindings the template
|
||||
// expects (`messages`, `add_generation_prompt`, `tools`) plus
|
||||
// anything the caller passed in `chat_template_kwargs`. Caller
|
||||
// kwargs override the defaults so `add_generation_prompt: false`
|
||||
// from the request actually wins.
|
||||
let mut ctx_map = serde_json::Map::new();
|
||||
ctx_map.insert("messages".into(), Value::Array(messages_json));
|
||||
ctx_map.insert("add_generation_prompt".into(), Value::Bool(true));
|
||||
if !tools.is_null() {
|
||||
ctx_map.insert("tools".into(), tools.clone());
|
||||
}
|
||||
if let Value::Object(kwargs_obj) = kwargs {
|
||||
for (k, v) in kwargs_obj {
|
||||
ctx_map.insert(k.clone(), v.clone());
|
||||
}
|
||||
}
|
||||
// `Template::render` takes any Serialize value; serde_json's
|
||||
// `Value` implements it natively, so we pass the assembled
|
||||
// context object directly without going through the
|
||||
// `context!` macro (which expects minijinja-native values).
|
||||
tmpl.render(Value::Object(ctx_map))
|
||||
.context("render chat_template")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
fn user_msg(text: &str) -> ChatMessage {
|
||||
ChatMessage {
|
||||
role: "user".into(),
|
||||
content: MessageContent::Text(text.into()),
|
||||
extra: Value::Object(Default::default()),
|
||||
}
|
||||
}
|
||||
|
||||
fn assistant_msg(text: &str) -> ChatMessage {
|
||||
ChatMessage {
|
||||
role: "assistant".into(),
|
||||
content: MessageContent::Text(text.into()),
|
||||
extra: Value::Object(Default::default()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Minimal Qwen3-style template — enough surface to confirm
|
||||
/// our renderer threads role + content correctly without
|
||||
/// loading a real model's tokenizer_config.json.
|
||||
const QWEN3_LIKE: &str = "{%- for message in messages -%}\
|
||||
<|im_start|>{{ message.role }}\n{{ message.content }}<|im_end|>\n\
|
||||
{%- endfor -%}\
|
||||
{%- if add_generation_prompt -%}<|im_start|>assistant\n{%- endif -%}";
|
||||
|
||||
#[test]
|
||||
fn renders_basic_conversation() {
|
||||
let prompt = render_chat_template(
|
||||
QWEN3_LIKE,
|
||||
&[user_msg("hello"), assistant_msg("hi"), user_msg("bye")],
|
||||
&Value::Null,
|
||||
&Value::Null,
|
||||
)
|
||||
.unwrap();
|
||||
// Structural assertions — the exact whitespace produced
|
||||
// by a given template is a Jinja-trim concern that varies
|
||||
// per real chat_template. What matters is that every
|
||||
// turn's role + content thread through in order, and that
|
||||
// the generation cue lands at the end.
|
||||
assert!(
|
||||
prompt.contains("<|im_start|>user\nhello<|im_end|>"),
|
||||
"first user turn missing: {prompt}"
|
||||
);
|
||||
assert!(
|
||||
prompt.contains("<|im_start|>assistant\nhi<|im_end|>"),
|
||||
"assistant turn missing: {prompt}"
|
||||
);
|
||||
assert!(
|
||||
prompt.contains("<|im_start|>user\nbye<|im_end|>"),
|
||||
"second user turn missing: {prompt}"
|
||||
);
|
||||
assert!(
|
||||
prompt.ends_with("<|im_start|>assistant")
|
||||
|| prompt.ends_with("<|im_start|>assistant\n"),
|
||||
"generation cue missing at end: {prompt}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn kwargs_are_threaded_into_template_context() {
|
||||
// Replica of Qwen3's enable_thinking branch in
|
||||
// simplified form. When the kwarg is false, the model's
|
||||
// template injects an empty `<think>...</think>` block
|
||||
// before the generation cue — pre-filling the model's
|
||||
// reasoning slot with "no thinking" so the model emits
|
||||
// the answer directly.
|
||||
let template = "{%- if enable_thinking is defined and enable_thinking is false -%}\
|
||||
NO_THINK\
|
||||
{%- else -%}\
|
||||
THINK_OK\
|
||||
{%- endif -%}";
|
||||
let r_disabled = render_chat_template(
|
||||
template,
|
||||
&[],
|
||||
&Value::Null,
|
||||
&json!({ "enable_thinking": false }),
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(r_disabled, "NO_THINK");
|
||||
let r_default = render_chat_template(template, &[], &Value::Null, &Value::Null).unwrap();
|
||||
assert_eq!(r_default, "THINK_OK");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn missing_template_field_returns_none() {
|
||||
let tmp = std::env::temp_dir().join("neuron-test-tokenizer-missing-field.json");
|
||||
std::fs::write(&tmp, r#"{"some_other_field": 1}"#).unwrap();
|
||||
assert!(load_chat_template_from(&tmp).is_none());
|
||||
let _ = std::fs::remove_file(tmp);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_template_from_string_field() {
|
||||
let tmp = std::env::temp_dir().join("neuron-test-tokenizer-string.json");
|
||||
std::fs::write(
|
||||
&tmp,
|
||||
r#"{"chat_template": "hello {{ messages[0].content }}"}"#,
|
||||
)
|
||||
.unwrap();
|
||||
let t = load_chat_template_from(&tmp).expect("template loaded");
|
||||
assert!(t.contains("messages[0].content"));
|
||||
let _ = std::fs::remove_file(tmp);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_template_from_array_form() {
|
||||
// Some HF models ship `chat_template` as `[{name, template}, ...]`.
|
||||
let tmp = std::env::temp_dir().join("neuron-test-tokenizer-array.json");
|
||||
std::fs::write(
|
||||
&tmp,
|
||||
r#"{"chat_template": [{"name": "default", "template": "ARR"}]}"#,
|
||||
)
|
||||
.unwrap();
|
||||
let t = load_chat_template_from(&tmp).expect("template loaded");
|
||||
assert_eq!(t, "ARR");
|
||||
let _ = std::fs::remove_file(tmp);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn missing_file_returns_none_quietly() {
|
||||
let absent = std::path::PathBuf::from("/definitely/not/a/real/path.json");
|
||||
assert!(load_chat_template_from(&absent).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unparseable_returns_none() {
|
||||
let tmp = std::env::temp_dir().join("neuron-test-tokenizer-garbage.json");
|
||||
std::fs::write(&tmp, b"{not valid json").unwrap();
|
||||
assert!(load_chat_template_from(&tmp).is_none());
|
||||
let _ = std::fs::remove_file(tmp);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn kill_switch_recognises_truthy_falsy_values() {
|
||||
// Test against the actual env var so callers see the
|
||||
// same behaviour as production. Serialise via a
|
||||
// mutex — see path_util.rs for the pattern.
|
||||
use std::sync::Mutex;
|
||||
static LOCK: Mutex<()> = Mutex::new(());
|
||||
let _g = LOCK.lock().unwrap();
|
||||
let prior = std::env::var(KILL_SWITCH_ENV).ok();
|
||||
unsafe {
|
||||
std::env::remove_var(KILL_SWITCH_ENV);
|
||||
}
|
||||
assert!(chat_templates_enabled());
|
||||
for value in ["false", "0", "no", "off", "FALSE", " no "] {
|
||||
unsafe { std::env::set_var(KILL_SWITCH_ENV, value) };
|
||||
assert!(!chat_templates_enabled(), "value {value:?} should disable");
|
||||
}
|
||||
for value in ["true", "1", "yes", ""] {
|
||||
unsafe { std::env::set_var(KILL_SWITCH_ENV, value) };
|
||||
assert!(chat_templates_enabled(), "value {value:?} should enable");
|
||||
}
|
||||
unsafe {
|
||||
match prior {
|
||||
Some(p) => std::env::set_var(KILL_SWITCH_ENV, p),
|
||||
None => std::env::remove_var(KILL_SWITCH_ENV),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn message_extras_thread_through_for_tool_calls() {
|
||||
// HF templates read assistant.tool_calls and tool
|
||||
// turns' tool_call_id. Confirm our extras flatten into
|
||||
// the message object the template iterates.
|
||||
let mut extras = serde_json::Map::new();
|
||||
extras.insert(
|
||||
"tool_calls".into(),
|
||||
json!([{"id": "t1", "function": {"name": "x", "arguments": "{}"}}]),
|
||||
);
|
||||
let msg = ChatMessage {
|
||||
role: "assistant".into(),
|
||||
content: MessageContent::Text(String::new()),
|
||||
extra: Value::Object(extras),
|
||||
};
|
||||
let template = "{{ messages[0].tool_calls[0].id }}";
|
||||
let rendered = render_chat_template(template, &[msg], &Value::Null, &Value::Null).unwrap();
|
||||
assert_eq!(rendered, "t1");
|
||||
}
|
||||
}
|
||||
810
crates/neuron/src/harness/device_worker/dispatch.rs
Normal file
810
crates/neuron/src/harness/device_worker/dispatch.rs
Normal file
@@ -0,0 +1,810 @@
|
||||
//! Synchronous dispatch loop running on the device worker thread.
|
||||
//!
|
||||
//! `run()` is the thread's entry point. It binds the CUDA context for
|
||||
//! its device on startup, then pulls `Job`s off the channel one at a
|
||||
//! time and runs the corresponding handler. The handlers are
|
||||
//! synchronous by design — the only async on this thread is the
|
||||
//! one-line `oneshot::Sender::send` call to ship the reply back, which
|
||||
//! is non-blocking.
|
||||
//!
|
||||
//! Phase 2 handles QueryVram, TransferIn, DropArch, ClearKv,
|
||||
//! ForwardLogits, Shutdown. Phase 3 will add the TP variants
|
||||
//! (NcclInit, NcclSanity, TpLoadShard, TpForward, TpClearKv) and the
|
||||
//! ARCH model state in this state slab will gain a companion
|
||||
//! `tp_models: HashMap<TpHandle, Box<TpLeaderModel>>`.
|
||||
|
||||
use crate::harness::candle::ModelArch;
|
||||
#[cfg(feature = "cuda")]
|
||||
use crate::harness::device_worker::jobs::TpHandle;
|
||||
use crate::harness::device_worker::jobs::{ArchHandle, Job};
|
||||
#[cfg(feature = "cuda")]
|
||||
use crate::harness::tp::TpLeaderModel;
|
||||
use crate::harness::tp::nccl_state::NcclState;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::mpsc::Receiver;
|
||||
|
||||
/// Per-thread state owned by the worker. On CUDA builds the `Arc<CudaContext>`
|
||||
/// is created and bound at thread startup; on CPU builds the struct
|
||||
/// is mostly empty.
|
||||
struct DeviceWorkerState {
|
||||
#[allow(dead_code)]
|
||||
device_index: u32,
|
||||
/// Candle `Device` constructed at startup. Used by handlers (e.g.
|
||||
/// `ForwardLogits`) to build input tensors against the right
|
||||
/// device. Falls back to `Device::Cpu` if CUDA init fails.
|
||||
device: candle_core::Device,
|
||||
/// Boxed `ModelArch` slab. Indexed by an opaque `ArchHandle` minted
|
||||
/// by `TransferIn`. The Box means the entry's address is stable
|
||||
/// across HashMap rehashes (relevant only when we later hand out
|
||||
/// `&mut ModelArch` references — for Phase 2 every handler runs
|
||||
/// `&mut` via `get_mut`, no long-lived borrows).
|
||||
models: HashMap<ArchHandle, Box<ModelArch>>,
|
||||
/// Counter for minting fresh `ArchHandle`s. Each `TransferIn`
|
||||
/// increments and returns the new value. Wraps at u64::MAX after
|
||||
/// ~10^19 model loads — not a practical concern.
|
||||
next_handle: u64,
|
||||
/// Leader's NCCL state. Populated by `Job::NcclInit`; the
|
||||
/// underlying `Comm`'s libnccl handle lives bound to this thread
|
||||
/// for its entire lifetime. Subprocess workers maintain their own
|
||||
/// `NcclState` in their own processes — that's not visible from
|
||||
/// here.
|
||||
#[allow(dead_code)] // Read only via methods on NcclState
|
||||
nccl: NcclState,
|
||||
/// TP leader model slab. Same lifecycle as `models`; separate
|
||||
/// namespace so `ArchHandle` and `TpHandle` can't collide.
|
||||
#[cfg(feature = "cuda")]
|
||||
tp_models: HashMap<TpHandle, Box<TpLeaderModel>>,
|
||||
/// Counter for minting fresh `TpHandle`s.
|
||||
#[cfg(feature = "cuda")]
|
||||
next_tp_handle: u64,
|
||||
#[cfg(feature = "cuda")]
|
||||
#[allow(dead_code)]
|
||||
/// `None` only if `CudaContext::new()` failed — in that case the
|
||||
/// thread still runs so the handle's lifecycle stays uniform, but
|
||||
/// every job that touches CUDA falls through to a zero reply with
|
||||
/// a log warning.
|
||||
ctx: Option<Arc<candle_core::cuda::cudarc::driver::CudaContext>>,
|
||||
}
|
||||
|
||||
/// Worker thread entry point. Runs until `Job::Shutdown` arrives or
|
||||
/// the channel sender is dropped (which happens when the last
|
||||
/// `DeviceWorkerHandle` `Arc` is dropped without an explicit
|
||||
/// `shutdown()`).
|
||||
pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool>) {
|
||||
let mut state = init_state(device_index);
|
||||
tracing::info!(device_index, "device worker started");
|
||||
|
||||
while let Ok(job) = rx.recv() {
|
||||
// Shutdown is processed unconditionally so a poisoned worker
|
||||
// still exits when asked. Matching by reference first so we
|
||||
// can fall through to the consume-match below.
|
||||
if matches!(&job, Job::Shutdown) {
|
||||
break;
|
||||
}
|
||||
if poisoned.load(Ordering::Acquire) {
|
||||
// Drain-only mode: reply with a poisoned error without
|
||||
// touching CUDA. Phase 1/2 never set the flag from the
|
||||
// dispatch loop itself (no driver errors classified yet),
|
||||
// but tests use `DeviceWorkerHandle::set_poisoned()` to
|
||||
// simulate this state.
|
||||
drain_poisoned(job, device_index);
|
||||
continue;
|
||||
}
|
||||
match job {
|
||||
Job::QueryVram { reply } => {
|
||||
let result = query_vram(&state);
|
||||
// If the caller dropped its receiver (request cancelled,
|
||||
// gateway timed out) the send fails — fine, we just
|
||||
// discard the reply.
|
||||
let _ = reply.send(result);
|
||||
}
|
||||
Job::LoadGguf {
|
||||
gguf_path,
|
||||
model_id,
|
||||
reply,
|
||||
} => {
|
||||
let result = load_gguf_inner(&state.device, &gguf_path, &model_id)
|
||||
.map(|arch| insert_arch(&mut state, Box::new(arch)));
|
||||
let _ = reply.send(result);
|
||||
}
|
||||
Job::LoadDense {
|
||||
config_path,
|
||||
safetensors_paths,
|
||||
model_id,
|
||||
reply,
|
||||
} => {
|
||||
let result =
|
||||
load_dense_inner(&state.device, &config_path, &safetensors_paths, &model_id)
|
||||
.map(|arch| insert_arch(&mut state, Box::new(arch)));
|
||||
let _ = reply.send(result);
|
||||
}
|
||||
Job::DropArch { handle, reply } => {
|
||||
let removed = state.models.remove(&handle);
|
||||
let was_present = removed.is_some();
|
||||
// Explicit drop on this thread — runs the Box<ModelArch>
|
||||
// Drop with the CUDA context bound here, which frees
|
||||
// all device tensors on the right context. The Drop is
|
||||
// implicit on the `removed` value going out of scope at
|
||||
// the end of the arm; calling drop() explicitly just
|
||||
// makes the intent visible.
|
||||
drop(removed);
|
||||
tracing::debug!(
|
||||
device_index,
|
||||
handle = handle.0,
|
||||
was_present,
|
||||
slab_size = state.models.len(),
|
||||
"device worker: model dropped"
|
||||
);
|
||||
let _ = reply.send(());
|
||||
}
|
||||
Job::ClearKv { handle, reply } => {
|
||||
let result = match state.models.get_mut(&handle) {
|
||||
Some(arch) => arch.clear_kv_cache(),
|
||||
None => Err(anyhow::anyhow!("ClearKv: no model for handle {}", handle.0)),
|
||||
};
|
||||
if result.is_ok() {
|
||||
trim_device_pool(&state);
|
||||
}
|
||||
let _ = reply.send(result);
|
||||
}
|
||||
Job::ForwardLogits {
|
||||
handle,
|
||||
tokens,
|
||||
offset,
|
||||
reply,
|
||||
} => {
|
||||
let result = forward_logits(&mut state, handle, &tokens, offset);
|
||||
let _ = reply.send(result);
|
||||
}
|
||||
Job::NcclInit {
|
||||
cfg,
|
||||
comm_id_hex,
|
||||
reply,
|
||||
} => {
|
||||
let resp = state.nccl.init(cfg, &comm_id_hex);
|
||||
let _ = reply.send(resp);
|
||||
}
|
||||
Job::NcclSanity { reply } => {
|
||||
let resp = state.nccl.sanity_check();
|
||||
let _ = reply.send(resp);
|
||||
}
|
||||
#[cfg(feature = "cuda")]
|
||||
Job::TpLoadShard {
|
||||
model_id,
|
||||
config_json,
|
||||
safetensors_paths,
|
||||
dtype,
|
||||
quant,
|
||||
world_size,
|
||||
reply,
|
||||
} => {
|
||||
let result = tp_load_shard_inner(
|
||||
&mut state,
|
||||
&model_id,
|
||||
&config_json,
|
||||
&safetensors_paths,
|
||||
dtype,
|
||||
quant.as_deref(),
|
||||
world_size,
|
||||
);
|
||||
let _ = reply.send(result);
|
||||
}
|
||||
#[cfg(feature = "cuda")]
|
||||
Job::DropTp { handle, reply } => {
|
||||
let removed = state.tp_models.remove(&handle);
|
||||
let was_present = removed.is_some();
|
||||
drop(removed);
|
||||
tracing::debug!(
|
||||
device_index,
|
||||
tp_handle = handle.0,
|
||||
was_present,
|
||||
slab_size = state.tp_models.len(),
|
||||
"device worker: TP model dropped"
|
||||
);
|
||||
let _ = reply.send(());
|
||||
}
|
||||
#[cfg(feature = "cuda")]
|
||||
Job::TpClearKv { handle, reply } => {
|
||||
let result = match state.tp_models.get_mut(&handle) {
|
||||
Some(model) => {
|
||||
model.clear_kv_cache();
|
||||
Ok(())
|
||||
}
|
||||
None => Err(anyhow::anyhow!(
|
||||
"TpClearKv: no TP model for handle {}",
|
||||
handle.0
|
||||
)),
|
||||
};
|
||||
if result.is_ok() {
|
||||
trim_device_pool(&state);
|
||||
}
|
||||
let _ = reply.send(result);
|
||||
}
|
||||
#[cfg(feature = "cuda")]
|
||||
Job::TpForwardLogits {
|
||||
handle,
|
||||
tokens,
|
||||
offset,
|
||||
reply,
|
||||
} => {
|
||||
let result = tp_forward_logits(&mut state, handle, &tokens, offset);
|
||||
let _ = reply.send(result);
|
||||
}
|
||||
// Handled by the matches!() check above; reaching here
|
||||
// means a Shutdown slipped past which is a bug.
|
||||
Job::Shutdown => unreachable!("Shutdown should break above"),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
let tp_slab_size = state.tp_models.len();
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
let tp_slab_size = 0_usize;
|
||||
tracing::info!(
|
||||
device_index,
|
||||
slab_size = state.models.len(),
|
||||
tp_slab_size,
|
||||
"device worker exiting; dropping remaining models"
|
||||
);
|
||||
// Drops every model in the slab on this thread before the function
|
||||
// returns. Critical for CUDA tensors: dropping on a thread that
|
||||
// doesn't have the context bound is UB. Phase 2 still runs Drop
|
||||
// via the slab going out of scope, which is correct as long as no
|
||||
// pre-poisoned state lurks in here — see the poisoned-mode
|
||||
// semantics in mod.rs for the Phase 3+ refinement.
|
||||
}
|
||||
|
||||
fn init_state(device_index: u32) -> DeviceWorkerState {
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
use candle_core::cuda::cudarc::driver::CudaContext;
|
||||
// Construct a candle Device first — cudarc returns the
|
||||
// primary context for this index on subsequent calls, so
|
||||
// CudaContext::new and Device::new_cuda end up sharing state.
|
||||
let (device, ctx) = match candle_core::Device::new_cuda(device_index as usize) {
|
||||
Ok(device) => match CudaContext::new(device_index as usize) {
|
||||
Ok(ctx) => {
|
||||
if let Err(e) = ctx.bind_to_thread() {
|
||||
tracing::warn!(
|
||||
device_index,
|
||||
error = ?e,
|
||||
"device worker: bind_to_thread failed; \
|
||||
operations will still rebind per-call"
|
||||
);
|
||||
} else {
|
||||
tracing::info!(device_index, "device worker bound CUDA context");
|
||||
}
|
||||
(device, Some(ctx))
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
device_index,
|
||||
error = ?e,
|
||||
"device worker: CudaContext::new failed; \
|
||||
vram queries will return (0, 0), forward will error"
|
||||
);
|
||||
(device, None)
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
device_index,
|
||||
error = %e,
|
||||
"device worker: Device::new_cuda failed; falling back to CPU device"
|
||||
);
|
||||
(candle_core::Device::Cpu, None)
|
||||
}
|
||||
};
|
||||
DeviceWorkerState {
|
||||
device_index,
|
||||
device,
|
||||
models: HashMap::new(),
|
||||
next_handle: 1,
|
||||
nccl: NcclState::new(),
|
||||
tp_models: HashMap::new(),
|
||||
next_tp_handle: 1,
|
||||
ctx,
|
||||
}
|
||||
}
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
{
|
||||
DeviceWorkerState {
|
||||
device_index,
|
||||
device: candle_core::Device::Cpu,
|
||||
models: HashMap::new(),
|
||||
next_handle: 1,
|
||||
nccl: NcclState::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn query_vram(state: &DeviceWorkerState) -> anyhow::Result<(u64, u64)> {
|
||||
use candle_core::cuda::cudarc::driver::result;
|
||||
if state.ctx.is_none() {
|
||||
return Ok((0, 0));
|
||||
}
|
||||
// The context was bound in init_state. cudarc's `mem_get_info`
|
||||
// reads from the current context on the calling thread; since we
|
||||
// bound on startup and we never spawn child threads from this
|
||||
// worker, the binding holds.
|
||||
match result::mem_get_info() {
|
||||
Ok((free, total)) => Ok((
|
||||
(free / (1024 * 1024)) as u64,
|
||||
(total / (1024 * 1024)) as u64,
|
||||
)),
|
||||
Err(e) => Err(anyhow::anyhow!("mem_get_info: {e:?}")),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
fn query_vram(_state: &DeviceWorkerState) -> anyhow::Result<(u64, u64)> {
|
||||
Ok((0, 0))
|
||||
}
|
||||
|
||||
/// Force cudarc's stream-ordered memory pool to release every block it
|
||||
/// is holding back to the system. After `ConcatKvCache::reset()` drops
|
||||
/// its tensors, the underlying `CudaSlice::drop` calls `cuMemFreeAsync`,
|
||||
/// which returns the blocks to the device's default mempool but not to
|
||||
/// the OS — `mem_get_info` still reports them as used. The next
|
||||
/// request's prefill then sees a falsely-small free pool and either
|
||||
/// OOMs or trips cuBLAS into `CUBLAS_STATUS_INTERNAL_ERROR`.
|
||||
///
|
||||
/// Calling `cuMemPoolTrimTo(pool, 0)` after each `clear_kv_cache`
|
||||
/// returns those blocks. We synchronize first so any pending
|
||||
/// `cuMemFreeAsync` operations have settled. Failures are non-fatal:
|
||||
/// the pool may not exist on legacy drivers, or a transient driver
|
||||
/// error may prevent the trim — neither breaks correctness, the next
|
||||
/// request just sees a less-recovered free pool.
|
||||
#[cfg(feature = "cuda")]
|
||||
fn trim_device_pool(state: &DeviceWorkerState) {
|
||||
use candle_core::cuda::cudarc::driver::result::{device, mem_pool};
|
||||
let Some(ctx) = state.ctx.as_ref() else {
|
||||
return;
|
||||
};
|
||||
let (before_free, _) = match query_vram(state) {
|
||||
Ok(v) => v,
|
||||
Err(_) => (0, 0),
|
||||
};
|
||||
if let Err(e) = ctx.synchronize() {
|
||||
tracing::debug!(
|
||||
device_index = state.device_index,
|
||||
error = ?e,
|
||||
"trim_device_pool: synchronize failed; skipping trim"
|
||||
);
|
||||
return;
|
||||
}
|
||||
let dev = ctx.cu_device();
|
||||
let pool = match unsafe { device::get_default_mem_pool(dev) } {
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
tracing::debug!(
|
||||
device_index = state.device_index,
|
||||
error = ?e,
|
||||
"trim_device_pool: get_default_mem_pool failed"
|
||||
);
|
||||
return;
|
||||
}
|
||||
};
|
||||
if let Err(e) = unsafe { mem_pool::trim_to(pool, 0) } {
|
||||
tracing::debug!(
|
||||
device_index = state.device_index,
|
||||
error = ?e,
|
||||
"trim_device_pool: cuMemPoolTrimTo failed"
|
||||
);
|
||||
return;
|
||||
}
|
||||
let (after_free, _) = match query_vram(state) {
|
||||
Ok(v) => v,
|
||||
Err(_) => (0, 0),
|
||||
};
|
||||
let freed_mb = after_free.saturating_sub(before_free);
|
||||
tracing::debug!(
|
||||
device_index = state.device_index,
|
||||
before_free_mb = before_free,
|
||||
after_free_mb = after_free,
|
||||
freed_mb,
|
||||
"trim_device_pool: trimmed pool"
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
fn trim_device_pool(_state: &DeviceWorkerState) {}
|
||||
|
||||
/// Insert a freshly-built `ModelArch` into the slab and mint a fresh
|
||||
/// `ArchHandle`. Used by both `LoadGguf` and `LoadDense` dispatch
|
||||
/// handlers — they differ only in *how* the arch is built; the
|
||||
/// post-construction bookkeeping is identical.
|
||||
fn insert_arch(state: &mut DeviceWorkerState, arch: Box<ModelArch>) -> ArchHandle {
|
||||
let handle = ArchHandle(state.next_handle);
|
||||
state.next_handle = state.next_handle.wrapping_add(1);
|
||||
state.models.insert(handle, arch);
|
||||
tracing::debug!(
|
||||
device_index = state.device_index,
|
||||
handle = handle.0,
|
||||
slab_size = state.models.len(),
|
||||
"device worker: model inserted"
|
||||
);
|
||||
handle
|
||||
}
|
||||
|
||||
/// Load a GGUF (pre-quantized) model on the worker thread. Pulled
|
||||
/// verbatim from the spawn_blocking closure that used to live in
|
||||
/// `CandleHarness::load_arch_gguf`; the only change is that `device`
|
||||
/// is now `state.device` (the worker's permanently-bound device).
|
||||
fn load_gguf_inner(
|
||||
device: &candle_core::Device,
|
||||
gguf_path: &std::path::Path,
|
||||
model_id: &str,
|
||||
) -> anyhow::Result<ModelArch> {
|
||||
use anyhow::Context;
|
||||
use candle_core::DType;
|
||||
use candle_core::quantized::gguf_file;
|
||||
use candle_transformers::models::quantized_llama::ModelWeights as QuantizedLlamaWeights;
|
||||
use candle_transformers::models::quantized_qwen3::ModelWeights as QuantizedQwen3Weights;
|
||||
use candle_transformers::models::quantized_qwen3_moe::GGUFQWenMoE;
|
||||
|
||||
tracing::info!(model = %model_id, path = ?gguf_path, "loading GGUF");
|
||||
let mut file = std::fs::File::open(gguf_path).context("open GGUF file")?;
|
||||
let content =
|
||||
gguf_file::Content::read(&mut file).map_err(|e| anyhow::anyhow!("parse GGUF: {e}"))?;
|
||||
|
||||
let architecture = content
|
||||
.metadata
|
||||
.get("general.architecture")
|
||||
.and_then(|v| v.to_string().ok().cloned())
|
||||
.unwrap_or_default();
|
||||
tracing::info!(architecture = %architecture, "GGUF architecture");
|
||||
|
||||
// The `general.architecture` GGUF metadata key follows
|
||||
// llama.cpp conventions (lowercase, no underscores in some
|
||||
// cases) — `qwen3moe`, not `qwen3_moe`.
|
||||
match architecture.as_str() {
|
||||
"qwen3" => {
|
||||
let weights = QuantizedQwen3Weights::from_gguf(content, &mut file, device)
|
||||
.map_err(|e| anyhow::anyhow!("from_gguf qwen3: {e}"))?;
|
||||
Ok(ModelArch::Qwen3Quantized(weights))
|
||||
}
|
||||
"qwen3moe" => {
|
||||
// GGUFQWenMoE takes an explicit compute dtype alongside
|
||||
// the device — F16 matches the GGUF weights' typical
|
||||
// accumulation precision and gives the best tokens/sec on
|
||||
// consumer cards.
|
||||
let weights = GGUFQWenMoE::from_gguf(content, &mut file, device, DType::F16)
|
||||
.map_err(|e| anyhow::anyhow!("from_gguf qwen3_moe: {e}"))?;
|
||||
Ok(ModelArch::Qwen3MoeQuantized(weights))
|
||||
}
|
||||
"llama" => {
|
||||
let weights = QuantizedLlamaWeights::from_gguf(content, &mut file, device)
|
||||
.map_err(|e| anyhow::anyhow!("from_gguf llama: {e}"))?;
|
||||
Ok(ModelArch::LlamaQuantized(weights))
|
||||
}
|
||||
other => anyhow::bail!(
|
||||
"unsupported GGUF architecture '{other}'; quantized path supports \
|
||||
qwen3, qwen3moe, llama"
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
/// Load a dense safetensors model on the worker thread.
|
||||
fn load_dense_inner(
|
||||
device: &candle_core::Device,
|
||||
config_path: &std::path::Path,
|
||||
safetensors_paths: &[std::path::PathBuf],
|
||||
model_id: &str,
|
||||
) -> anyhow::Result<ModelArch> {
|
||||
use anyhow::Context;
|
||||
use candle_core::DType;
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::llama as llama_dense;
|
||||
use candle_transformers::models::qwen3 as qwen3_dense;
|
||||
use candle_transformers::models::qwen3_moe as qwen3_moe_dense;
|
||||
|
||||
let cfg_text = std::fs::read_to_string(config_path).context("read config.json")?;
|
||||
crate::harness::candle::check_dense_config_supported(&cfg_text, model_id)?;
|
||||
// Peek at model_type to choose the family before the typed
|
||||
// deserialize — each family has its own Config.
|
||||
let model_type = serde_json::from_str::<serde_json::Value>(&cfg_text)
|
||||
.ok()
|
||||
.as_ref()
|
||||
.and_then(|v| v.get("model_type"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
tracing::info!(
|
||||
model = %model_id,
|
||||
model_type = %model_type,
|
||||
shards = safetensors_paths.len(),
|
||||
"loading dense model from safetensors"
|
||||
);
|
||||
|
||||
// bf16 is the canonical distribution dtype for Qwen3 / Llama 3 /
|
||||
// Qwen3 MoE. CUDA on Ada+ has hardware bf16; Ampere has it too.
|
||||
// CPU emulates.
|
||||
let dtype = DType::BF16;
|
||||
// SAFETY: VarBuilder::from_mmaped_safetensors mmaps the files;
|
||||
// mutation by another process while we hold the mapping is UB.
|
||||
// We trust the HF cache is immutable-by-design.
|
||||
let vb = unsafe {
|
||||
VarBuilder::from_mmaped_safetensors(safetensors_paths, dtype, device)
|
||||
.context("build VarBuilder over safetensors")?
|
||||
};
|
||||
|
||||
match model_type.as_str() {
|
||||
"qwen3" => {
|
||||
let cfg: qwen3_dense::Config =
|
||||
serde_json::from_str(&cfg_text).context("parse Qwen3 config.json")?;
|
||||
let model = qwen3_dense::ModelForCausalLM::new(&cfg, vb)
|
||||
.map_err(|e| anyhow::anyhow!("build Qwen3 dense model: {e}"))?;
|
||||
Ok(ModelArch::Qwen3Dense(model))
|
||||
}
|
||||
"qwen3_moe" => {
|
||||
let cfg: qwen3_moe_dense::Config =
|
||||
serde_json::from_str(&cfg_text).context("parse Qwen3 MoE config.json")?;
|
||||
let model = qwen3_moe_dense::ModelForCausalLM::new(&cfg, vb)
|
||||
.map_err(|e| anyhow::anyhow!("build Qwen3 MoE dense model: {e}"))?;
|
||||
Ok(ModelArch::Qwen3MoeDense(model))
|
||||
}
|
||||
"llama" => {
|
||||
let cfg: llama_dense::LlamaConfig =
|
||||
serde_json::from_str(&cfg_text).context("parse Llama config.json")?;
|
||||
let config = cfg.into_config(false);
|
||||
let cache = llama_dense::Cache::new(true, dtype, &config, device)
|
||||
.context("build Llama Cache")?;
|
||||
let model = llama_dense::Llama::load(vb, &config)
|
||||
.map_err(|e| anyhow::anyhow!("build Llama dense model: {e}"))?;
|
||||
Ok(ModelArch::LlamaDense(Box::new(
|
||||
crate::harness::candle::LlamaDense::from_parts(
|
||||
model,
|
||||
cache,
|
||||
config,
|
||||
dtype,
|
||||
device.clone(),
|
||||
),
|
||||
)))
|
||||
}
|
||||
"qwen3_5" => {
|
||||
let cfg: crate::harness::arch::qwen3_5::Config = serde_json::from_str(&cfg_text)
|
||||
.context("parse Qwen3-Next (qwen3_5) config.json")?;
|
||||
let sharded_vb = unsafe {
|
||||
candle_nn::var_builder::ShardedSafeTensors::var_builder(
|
||||
safetensors_paths,
|
||||
dtype,
|
||||
device,
|
||||
)
|
||||
.context("build ShardedVarBuilder for Qwen3-Next")?
|
||||
};
|
||||
let model = crate::harness::arch::qwen3_5::Qwen3_5ForCausalLM::new(cfg, sharded_vb)
|
||||
.context("build Qwen3-Next dense model")?;
|
||||
Ok(ModelArch::Qwen3_5Dense(model))
|
||||
}
|
||||
other => anyhow::bail!(
|
||||
"unrouted supported model_type '{other}' — \
|
||||
DENSE_SUPPORTED_MODEL_TYPES and load_dense_inner \
|
||||
must stay in sync"
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
/// Load the leader's TP shard on the worker thread. Reads the Comm
|
||||
/// directly from `state.nccl`; no cross-thread Arc<Comm> transfer.
|
||||
#[cfg(feature = "cuda")]
|
||||
fn tp_load_shard_inner(
|
||||
state: &mut DeviceWorkerState,
|
||||
model_id: &str,
|
||||
config_json: &str,
|
||||
safetensors_paths: &[std::path::PathBuf],
|
||||
dtype: candle_core::DType,
|
||||
quant: Option<&str>,
|
||||
world_size: u32,
|
||||
) -> anyhow::Result<TpHandle> {
|
||||
use anyhow::Context;
|
||||
use candle_nn::var_builder::ShardedSafeTensors;
|
||||
|
||||
let comm = state.nccl.comm().ok_or_else(|| {
|
||||
anyhow::anyhow!("TpLoadShard: NcclState has no Comm; call NcclInit first")
|
||||
})?;
|
||||
|
||||
let model_type = serde_json::from_str::<serde_json::Value>(config_json)
|
||||
.ok()
|
||||
.as_ref()
|
||||
.and_then(|v| v.get("model_type"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
// SAFETY: same invariant as the single-GPU dense path — the HF
|
||||
// cache files are treated as immutable while the mmap is held.
|
||||
let vb = unsafe {
|
||||
ShardedSafeTensors::var_builder(safetensors_paths, dtype, &state.device)
|
||||
.context("build ShardedVarBuilder over safetensors")?
|
||||
};
|
||||
let mmap = unsafe {
|
||||
candle_core::safetensors::MmapedSafetensors::multi(safetensors_paths)
|
||||
.context("build MmapedSafetensors for leader load")?
|
||||
};
|
||||
|
||||
let loaded = match model_type.as_str() {
|
||||
"qwen3" => {
|
||||
let cfg: crate::harness::tp::tp_qwen3::Config = serde_json::from_str(config_json)
|
||||
.context("parse Qwen3 Config JSON for leader load")?;
|
||||
TpLeaderModel::Qwen3(crate::harness::tp::tp_qwen3::TpQwen3ForCausalLM::load(
|
||||
&cfg, &vb, 0, world_size, comm,
|
||||
)?)
|
||||
}
|
||||
"qwen3_5" => {
|
||||
let cfg: crate::harness::tp::tp_qwen3_5::Config = serde_json::from_str(config_json)
|
||||
.context("parse Qwen3-Next Config JSON for leader load")?;
|
||||
let quant_dtype = crate::harness::tp::worker::parse_quant_string(quant)?;
|
||||
TpLeaderModel::Qwen3_5(crate::harness::tp::tp_qwen3_5::TpQwen3_5ForCausalLM::load(
|
||||
cfg,
|
||||
&vb,
|
||||
&mmap,
|
||||
0,
|
||||
world_size,
|
||||
comm,
|
||||
quant_dtype,
|
||||
)?)
|
||||
}
|
||||
other => anyhow::bail!(
|
||||
"TP dispatch: unsupported model_type '{other}' on leader (supported: qwen3, qwen3_5)"
|
||||
),
|
||||
};
|
||||
|
||||
tracing::info!(
|
||||
rank = 0,
|
||||
model = %model_id,
|
||||
model_type = %model_type,
|
||||
"loaded TP shard (leader)"
|
||||
);
|
||||
|
||||
let handle = TpHandle(state.next_tp_handle);
|
||||
state.next_tp_handle = state.next_tp_handle.wrapping_add(1);
|
||||
state.tp_models.insert(handle, Box::new(loaded));
|
||||
tracing::debug!(
|
||||
device_index = state.device_index,
|
||||
tp_handle = handle.0,
|
||||
slab_size = state.tp_models.len(),
|
||||
"device worker: TP model inserted"
|
||||
);
|
||||
Ok(handle)
|
||||
}
|
||||
|
||||
/// TP-equivalent of [`forward_logits`]: looks up the leader's
|
||||
/// [`TpLeaderModel`] in the slab, runs its forward, copies the
|
||||
/// `[vocab]` logits to a CPU `Vec<f32>`. The leader's `Arc<Comm>`
|
||||
/// clones embedded in the TP layers' AllReduce ops fire from this
|
||||
/// thread — same thread that bound the CUDA context and that holds
|
||||
/// the `Comm` in `state.nccl`.
|
||||
#[cfg(feature = "cuda")]
|
||||
fn tp_forward_logits(
|
||||
state: &mut DeviceWorkerState,
|
||||
handle: TpHandle,
|
||||
tokens: &[u32],
|
||||
offset: usize,
|
||||
) -> anyhow::Result<Vec<f32>> {
|
||||
use candle_core::{DType, Tensor};
|
||||
|
||||
let input = Tensor::new(tokens, &state.device)?.unsqueeze(0)?;
|
||||
|
||||
let model = state
|
||||
.tp_models
|
||||
.get_mut(&handle)
|
||||
.ok_or_else(|| anyhow::anyhow!("TpForwardLogits: no model for handle {}", handle.0))?;
|
||||
|
||||
let logits = model.forward(&input, offset)?;
|
||||
// ForCausalLM forward returns [B, 1, V] after the trailing
|
||||
// .i((.., l - 1.., ..))?.apply(lm_head); squeeze both leading
|
||||
// singleton dims to a rank-1 [V] tensor for sampling.
|
||||
let logits = logits.squeeze(0)?.squeeze(0)?;
|
||||
let logits = logits.to_dtype(DType::F32)?.flatten_all()?;
|
||||
let values = logits.to_vec1::<f32>()?;
|
||||
Ok(values)
|
||||
}
|
||||
|
||||
/// Forward step + copy the `[vocab]` logits to a CPU `Vec<f32>` ready
|
||||
/// for sampling on the async caller. The model's `device()` (CUDA or
|
||||
/// CPU) determines where the kernel runs; this fn doesn't care.
|
||||
///
|
||||
/// On CUDA, the `to_dtype(F32).flatten_all().to_vec1::<f32>()` chain
|
||||
/// triggers the device → host copy. The copy runs synchronously on
|
||||
/// this worker thread; the bound context owns the source allocation
|
||||
/// so the transfer is straightforward.
|
||||
fn forward_logits(
|
||||
state: &mut DeviceWorkerState,
|
||||
handle: ArchHandle,
|
||||
tokens: &[u32],
|
||||
offset: usize,
|
||||
) -> anyhow::Result<Vec<f32>> {
|
||||
use candle_core::{DType, Tensor};
|
||||
|
||||
// Build the input tensor on the worker's own device. cudarc's
|
||||
// primary-context model means `Device::new_cuda(idx)` shares state
|
||||
// with the `CudaContext` we bound at startup, so this is the same
|
||||
// device the ModelArch was loaded against.
|
||||
let input = Tensor::new(tokens, &state.device)?.unsqueeze(0)?;
|
||||
|
||||
let arch = state
|
||||
.models
|
||||
.get_mut(&handle)
|
||||
.ok_or_else(|| anyhow::anyhow!("ForwardLogits: no model for handle {}", handle.0))?;
|
||||
|
||||
let logits = arch.forward(&input, offset)?;
|
||||
// Copy to CPU f32. logits is already `[vocab]` (squeeze_to_vocab
|
||||
// inside ModelArch::forward). The to_dtype handles bf16/f16 →
|
||||
// f32 promotion for the sampler.
|
||||
let logits = logits.to_dtype(DType::F32)?.flatten_all()?;
|
||||
let values = logits.to_vec1::<f32>()?;
|
||||
Ok(values)
|
||||
}
|
||||
|
||||
/// Reply to a job with the poisoned-worker error. Used when the worker
|
||||
/// has flipped into drain-only mode after a CUDA driver error.
|
||||
///
|
||||
/// `Job::Shutdown` is filtered before reaching this fn so the match
|
||||
/// only needs the data-carrying variants. As phases 2–4 add more
|
||||
/// variants the match here grows; every variant must reply with the
|
||||
/// poisoned error so callers never hang waiting for a worker that's
|
||||
/// no longer running CUDA.
|
||||
fn drain_poisoned(job: Job, device_index: u32) {
|
||||
let err = || anyhow::anyhow!("device worker for device {device_index} is poisoned");
|
||||
match job {
|
||||
Job::QueryVram { reply } => {
|
||||
let _ = reply.send(Err(err()));
|
||||
}
|
||||
Job::LoadGguf { reply, .. } => {
|
||||
let _ = reply.send(Err(err()));
|
||||
}
|
||||
Job::LoadDense { reply, .. } => {
|
||||
let _ = reply.send(Err(err()));
|
||||
}
|
||||
Job::DropArch { reply, .. } => {
|
||||
// Drop reply is `()` — no error path. Send the unit so the
|
||||
// caller's await resolves; the model handle is leaked in
|
||||
// the worker's slab, but the whole slab gets `mem::forget`
|
||||
// on shutdown anyway per the poisoned-thread design.
|
||||
let _ = reply.send(());
|
||||
}
|
||||
Job::ClearKv { reply, .. } => {
|
||||
let _ = reply.send(Err(err()));
|
||||
}
|
||||
Job::ForwardLogits { reply, .. } => {
|
||||
let _ = reply.send(Err(err()));
|
||||
}
|
||||
Job::NcclInit { reply, .. } => {
|
||||
let _ = reply.send(crate::harness::tp::rpc::WorkerResponse::Error {
|
||||
kind: "device_worker_poisoned".into(),
|
||||
message: format!("device worker {device_index} poisoned"),
|
||||
});
|
||||
}
|
||||
Job::NcclSanity { reply } => {
|
||||
let _ = reply.send(crate::harness::tp::rpc::WorkerResponse::Error {
|
||||
kind: "device_worker_poisoned".into(),
|
||||
message: format!("device worker {device_index} poisoned"),
|
||||
});
|
||||
}
|
||||
#[cfg(feature = "cuda")]
|
||||
Job::TpLoadShard { reply, .. } => {
|
||||
let _ = reply.send(Err(err()));
|
||||
}
|
||||
#[cfg(feature = "cuda")]
|
||||
Job::DropTp { reply, .. } => {
|
||||
let _ = reply.send(());
|
||||
}
|
||||
#[cfg(feature = "cuda")]
|
||||
Job::TpClearKv { reply, .. } => {
|
||||
let _ = reply.send(Err(err()));
|
||||
}
|
||||
#[cfg(feature = "cuda")]
|
||||
Job::TpForwardLogits { reply, .. } => {
|
||||
let _ = reply.send(Err(err()));
|
||||
}
|
||||
Job::Shutdown => {
|
||||
// Filtered by the matches!() guard in run(); reaching
|
||||
// here would be a logic error.
|
||||
unreachable!("Shutdown is filtered before drain_poisoned");
|
||||
}
|
||||
}
|
||||
}
|
||||
169
crates/neuron/src/harness/device_worker/jobs.rs
Normal file
169
crates/neuron/src/harness/device_worker/jobs.rs
Normal file
@@ -0,0 +1,169 @@
|
||||
//! Job variants accepted by the per-device worker thread.
|
||||
//!
|
||||
//! Each variant carries the inputs the synchronous dispatch handler
|
||||
//! needs plus a `tokio::sync::oneshot::Sender` for the reply. The
|
||||
//! async-side `DeviceWorkerHandle` constructs a job, sends it down the
|
||||
//! `std::sync::mpsc` channel, and `await`s the oneshot for the reply.
|
||||
|
||||
use anyhow::Result;
|
||||
use std::path::PathBuf;
|
||||
use tokio::sync::oneshot;
|
||||
|
||||
/// Opaque handle to a `ModelArch` stored in the worker thread's state
|
||||
/// slab. Cheap to copy; `Send + Sync` so it crosses task boundaries
|
||||
/// freely. The actual `Box<ModelArch>` it points to is owned by the
|
||||
/// worker thread for the duration of the handle's lifetime — the only
|
||||
/// way to drop the model is to send `Job::DropArch { handle }` so the
|
||||
/// `Drop` impl runs on the thread with the bound CUDA context (the
|
||||
/// invariant the whole refactor exists to guarantee).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub struct ArchHandle(pub u64);
|
||||
|
||||
/// Opaque handle to a `TpLeaderModel` stored in the worker thread's
|
||||
/// state slab. Same shape as [`ArchHandle`] but in a separate
|
||||
/// namespace so the two slabs can coexist without ambiguity. Phase 3
|
||||
/// introduces it; Phase 4 may unify the two slabs after the TP forward
|
||||
/// path proves out.
|
||||
#[cfg(feature = "cuda")]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub struct TpHandle(pub u64);
|
||||
|
||||
/// One unit of work for the device worker.
|
||||
///
|
||||
/// Phase 1 had only `QueryVram` and `Shutdown`. Phase 2 adds the
|
||||
/// single-GPU inference primitives: transfer-in a freshly-loaded
|
||||
/// `ModelArch`, drop it, clear its KV cache, and run one forward step
|
||||
/// returning CPU-side logits ready for sampling on the async caller.
|
||||
///
|
||||
/// Sampling stays on the async side intentionally. The worker copies
|
||||
/// logits to CPU (`Vec<f32>`) before reply, so the device-resident
|
||||
/// tensor never escapes the worker thread and the async caller's
|
||||
/// `LogitsProcessor::sample` runs entirely on the CPU candle backend
|
||||
/// — no incidental context binding on a tokio worker thread.
|
||||
pub enum Job {
|
||||
/// Query free / total VRAM on the device. Returns
|
||||
/// `(free_mb, total_mb)`. CPU builds and contexts that failed to
|
||||
/// initialise reply with `(0, 0)` — matches today's
|
||||
/// `device_vram_mb` sentinel so the log field values don't change.
|
||||
QueryVram {
|
||||
reply: oneshot::Sender<Result<(u64, u64)>>,
|
||||
},
|
||||
/// Load a GGUF (pre-quantized) single-GPU model on the worker
|
||||
/// thread. The dispatch handler opens the GGUF file, parses
|
||||
/// metadata, dispatches on `general.architecture`, and inserts
|
||||
/// the resulting `ModelArch` into the slab. Returns the fresh
|
||||
/// `ArchHandle`.
|
||||
LoadGguf {
|
||||
gguf_path: PathBuf,
|
||||
model_id: String,
|
||||
reply: oneshot::Sender<Result<ArchHandle>>,
|
||||
},
|
||||
/// Load a dense safetensors single-GPU model on the worker
|
||||
/// thread. The dispatch handler reads `config.json`, dispatches on
|
||||
/// `model_type`, builds a `VarBuilder` over the mmap'd
|
||||
/// safetensors, and inserts the resulting `ModelArch`.
|
||||
LoadDense {
|
||||
config_path: PathBuf,
|
||||
safetensors_paths: Vec<PathBuf>,
|
||||
model_id: String,
|
||||
reply: oneshot::Sender<Result<ArchHandle>>,
|
||||
},
|
||||
/// Remove the model from the slab and drop it. The `Drop` runs on
|
||||
/// the worker thread so CUDA tensors release their memory on the
|
||||
/// same context that allocated them.
|
||||
DropArch {
|
||||
handle: ArchHandle,
|
||||
reply: oneshot::Sender<()>,
|
||||
},
|
||||
/// Reset the KV cache for this model. Called at the start of every
|
||||
/// chat completion so a new request doesn't attend over the
|
||||
/// previous one's tokens.
|
||||
ClearKv {
|
||||
handle: ArchHandle,
|
||||
reply: oneshot::Sender<Result<()>>,
|
||||
},
|
||||
/// Run one forward step and copy the resulting `[vocab]` logits to
|
||||
/// CPU. The caller takes the returned `Vec<f32>`, wraps it in a
|
||||
/// CPU `Tensor`, and runs `apply_repeat_penalty` + sampling
|
||||
/// without touching the device context. `offset` is the KV-cache
|
||||
/// position before this step (0 for prefill, `prompt_len + i` for
|
||||
/// the i-th decode step).
|
||||
ForwardLogits {
|
||||
handle: ArchHandle,
|
||||
tokens: Vec<u32>,
|
||||
offset: usize,
|
||||
reply: oneshot::Sender<Result<Vec<f32>>>,
|
||||
},
|
||||
/// Initialize the leader's NCCL communicator. The worker's
|
||||
/// `NcclState` mints the `Comm` here so its underlying
|
||||
/// `ncclComm_t` and `CudaContext` live on the same thread as
|
||||
/// every later `Comm::all_reduce` call. Reply is the worker
|
||||
/// response shape used by the subprocess workers (`InitOk` on
|
||||
/// success, `Error` on failure) so the calling
|
||||
/// `WorkerPool::init_nccl` orchestration stays uniform.
|
||||
///
|
||||
/// Available on both cuda and no-cuda builds — the dispatch
|
||||
/// handler calls `NcclState::init` which has a no-cuda stub that
|
||||
/// replies with `cuda_feature_not_enabled`. Keeping the Job
|
||||
/// variant ungated lets `WorkerPool::init_nccl` stay uniform.
|
||||
NcclInit {
|
||||
cfg: crate::harness::tp::worker::WorkerConfig,
|
||||
comm_id_hex: String,
|
||||
reply: oneshot::Sender<crate::harness::tp::rpc::WorkerResponse>,
|
||||
},
|
||||
/// Run NCCL's all_reduce sanity check on the leader's rank 0.
|
||||
/// Same response shape as `NcclInit`; also available on both
|
||||
/// builds via the no-cuda `NcclState::sanity_check` stub.
|
||||
NcclSanity {
|
||||
reply: oneshot::Sender<crate::harness::tp::rpc::WorkerResponse>,
|
||||
},
|
||||
/// Load the leader's TP shard on the worker thread. The dispatch
|
||||
/// handler reads `state.nccl.comm()` directly (no cross-thread
|
||||
/// `Arc<Comm>` transfer, no `SendComm` wrapper) and builds the
|
||||
/// `TpLeaderModel` against that Comm. The model's embedded
|
||||
/// `Arc<Comm>` clones, `CudaContext`, and all per-rank CUDA
|
||||
/// tensors live on this thread for the model's lifetime.
|
||||
/// Inserts into the TP slab and returns the fresh `TpHandle`.
|
||||
#[cfg(feature = "cuda")]
|
||||
TpLoadShard {
|
||||
model_id: String,
|
||||
config_json: String,
|
||||
safetensors_paths: Vec<PathBuf>,
|
||||
dtype: candle_core::DType,
|
||||
quant: Option<String>,
|
||||
world_size: u32,
|
||||
reply: oneshot::Sender<Result<TpHandle>>,
|
||||
},
|
||||
/// Drop the TP leader model on the worker thread. CUDA tensors
|
||||
/// and `Arc<Comm>` clones held inside the model release on the
|
||||
/// thread that allocated them.
|
||||
#[cfg(feature = "cuda")]
|
||||
DropTp {
|
||||
handle: TpHandle,
|
||||
reply: oneshot::Sender<()>,
|
||||
},
|
||||
/// Reset the leader's KV cache for a TP model. Mirrors `ClearKv`
|
||||
/// for single-GPU.
|
||||
#[cfg(feature = "cuda")]
|
||||
TpClearKv {
|
||||
handle: TpHandle,
|
||||
reply: oneshot::Sender<Result<()>>,
|
||||
},
|
||||
/// Run one TP forward step on the leader's shard. Returns CPU-
|
||||
/// side logits as a `Vec<f32>` so the async caller can sample
|
||||
/// without holding a device tensor. The caller is also
|
||||
/// responsible for fan-out to subprocess ranks and drain — only
|
||||
/// the leader's forward moves into the worker thread.
|
||||
#[cfg(feature = "cuda")]
|
||||
TpForwardLogits {
|
||||
handle: TpHandle,
|
||||
tokens: Vec<u32>,
|
||||
offset: usize,
|
||||
reply: oneshot::Sender<Result<Vec<f32>>>,
|
||||
},
|
||||
/// Tell the worker to break its dispatch loop and exit. Any jobs
|
||||
/// queued after this in the channel reply `Err` to their oneshot
|
||||
/// senders (the senders are dropped on the worker's exit, which
|
||||
/// the async-side `Receiver::await` maps to `WorkerError::Gone`).
|
||||
Shutdown,
|
||||
}
|
||||
592
crates/neuron/src/harness/device_worker/mod.rs
Normal file
592
crates/neuron/src/harness/device_worker/mod.rs
Normal file
@@ -0,0 +1,592 @@
|
||||
//! Per-device CUDA worker thread.
|
||||
//!
|
||||
//! One dedicated OS thread per CUDA device the leader uses. The thread
|
||||
//! binds the device's `CudaContext` once at startup and owns it for the
|
||||
//! daemon's lifetime; all GPU operations and VRAM queries for that
|
||||
//! device route through a `std::sync::mpsc` channel into this thread.
|
||||
//! Tensors never escape the thread alive — replies cross the channel
|
||||
//! as plain values (`u32` tokens, `(u64, u64)` mb numbers, `()`).
|
||||
//!
|
||||
//! Rationale, in order of weight:
|
||||
//!
|
||||
//! 1. **Context locality.** cudarc binds the CUDA context per OS thread
|
||||
//! via `cuCtxSetCurrent`. With `tokio::task::spawn_blocking`, the
|
||||
//! blocking thread chosen is arbitrary, so the context gets bound
|
||||
//! onto a different thread each time and `device_vram_mb()` from an
|
||||
//! async task binds it again on the *caller's* thread as a side
|
||||
//! effect. Pinning the context to one named thread ends that.
|
||||
//!
|
||||
//! 2. **Drop safety.** `cudarc::driver::CudaContext`, every `CudaSlice`
|
||||
//! inside a `Tensor`, and every `cudarc::nccl::Comm` call `cuMemFree`
|
||||
//! / `cuCtxDestroy` / `ncclCommDestroy` during `Drop`. These must
|
||||
//! run with the right context current. Owning everything in this
|
||||
//! thread's state slab and dropping it via `Job::DropArch` /
|
||||
//! `Job::Shutdown` is the only safe pattern.
|
||||
//!
|
||||
//! 3. **Poisoning blast radius.** When a CUDA driver error (illegal
|
||||
//! address, OOM cascade) makes the context unrecoverable, today the
|
||||
//! spawn_blocking thread carrying that bad state simply returns to
|
||||
//! tokio's pool — invisible. With the per-device thread, the
|
||||
//! poisoned flag lives on the thread itself; subsequent
|
||||
//! `submit()` calls fast-reject at the channel boundary with a
|
||||
//! clear "device worker is poisoned" error before any further CUDA
|
||||
//! work is attempted.
|
||||
//!
|
||||
//! The TP worker subprocesses (`harness/tp/worker.rs`) are already this
|
||||
//! pattern, just out-of-process. The in-process variant uses the same
|
||||
//! discipline for rank 0.
|
||||
//!
|
||||
//! Phase 1 of the refactor exposes only `Job::QueryVram` + `Job::Shutdown`.
|
||||
//! Forward, kv-cache clear, model load, and NCCL bring-up move in later
|
||||
//! phases. See `/home/grenade/.claude/plans/plan-the-per-device-worker-abstract-micali.md`.
|
||||
|
||||
pub mod dispatch;
|
||||
pub mod jobs;
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::mpsc::{self, Sender};
|
||||
use std::thread::JoinHandle;
|
||||
use tokio::sync::oneshot;
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
pub use jobs::TpHandle;
|
||||
pub use jobs::{ArchHandle, Job};
|
||||
|
||||
/// Errors returned by `DeviceWorkerHandle` submit methods.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum WorkerError {
|
||||
/// The worker's CUDA context was poisoned by an earlier driver
|
||||
/// error. The thread is still alive (dropping it would re-touch
|
||||
/// the broken context); it returns this error for every job
|
||||
/// submitted until the daemon is restarted.
|
||||
#[error(
|
||||
"device worker for device {device_index} is poisoned \
|
||||
(a prior CUDA driver error left the context unrecoverable); \
|
||||
restart the daemon to recover"
|
||||
)]
|
||||
Poisoned { device_index: u32 },
|
||||
/// The worker thread has exited (`Job::Shutdown` was processed or
|
||||
/// the thread panicked). Subsequent `submit()` calls fail here
|
||||
/// rather than blocking forever.
|
||||
#[error("device worker for device {device_index} is no longer running")]
|
||||
Gone { device_index: u32 },
|
||||
/// The dispatched job returned an `Err`. Forwarded verbatim.
|
||||
#[error(transparent)]
|
||||
Job(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
/// Shared handle to a per-device CUDA worker thread.
|
||||
///
|
||||
/// Cloning the `Arc` lets multiple `LoadedModel`s (and `TpLoadedModel`s)
|
||||
/// share the same worker — there's one worker per CUDA device index,
|
||||
/// not one per model.
|
||||
pub struct DeviceWorkerHandle {
|
||||
device_index: u32,
|
||||
tx: Sender<Job>,
|
||||
poisoned: Arc<AtomicBool>,
|
||||
/// `Mutex<Option<JoinHandle>>` so `shutdown()` can take the handle
|
||||
/// out without `&mut self` and so the inevitable `Drop` after
|
||||
/// `shutdown()` doesn't double-join. The mutex is uncontended in
|
||||
/// practice: only one caller ever takes the handle.
|
||||
join: std::sync::Mutex<Option<JoinHandle<()>>>,
|
||||
}
|
||||
|
||||
impl DeviceWorkerHandle {
|
||||
/// Spawn a new worker for the given CUDA device index.
|
||||
///
|
||||
/// The thread is named `cuda-dev-N` so it shows up legibly in
|
||||
/// `top -H`, `pidstat -t`, and gdb backtraces. On CUDA builds, the
|
||||
/// thread binds `CudaContext::new(N)` on startup; on CPU builds
|
||||
/// (`--no-default-features`) the thread runs without a context and
|
||||
/// every job that touches CUDA falls through to a zero return.
|
||||
pub fn spawn(device_index: u32) -> anyhow::Result<Arc<Self>> {
|
||||
let (tx, rx) = mpsc::channel::<Job>();
|
||||
let poisoned = Arc::new(AtomicBool::new(false));
|
||||
let poisoned_for_thread = Arc::clone(&poisoned);
|
||||
let join = std::thread::Builder::new()
|
||||
.name(format!("cuda-dev-{device_index}"))
|
||||
.spawn(move || {
|
||||
dispatch::run(device_index, rx, poisoned_for_thread);
|
||||
})?;
|
||||
Ok(Arc::new(Self {
|
||||
device_index,
|
||||
tx,
|
||||
poisoned,
|
||||
join: std::sync::Mutex::new(Some(join)),
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn device_index(&self) -> u32 {
|
||||
self.device_index
|
||||
}
|
||||
|
||||
pub fn is_poisoned(&self) -> bool {
|
||||
self.poisoned.load(Ordering::Acquire)
|
||||
}
|
||||
|
||||
/// Mark the worker's context as poisoned. Future `submit()` calls
|
||||
/// short-circuit to `WorkerError::Poisoned` before sending. The
|
||||
/// dispatch loop also flips into drain-only mode when it sees this
|
||||
/// flag, so any jobs already in flight on the channel reply with
|
||||
/// the same error without touching CUDA.
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn set_poisoned(&self) {
|
||||
self.poisoned.store(true, Ordering::Release);
|
||||
}
|
||||
|
||||
/// Send `Job::QueryVram`, await the worker's reply.
|
||||
///
|
||||
/// Returns `Ok((free_mb, total_mb))` on success, `Ok((0, 0))` on
|
||||
/// CPU builds or when the device lacks a bound context, or an
|
||||
/// error if the worker is poisoned, gone, or the query itself
|
||||
/// failed inside cudarc.
|
||||
pub async fn query_vram(&self) -> Result<(u64, u64), WorkerError> {
|
||||
if self.poisoned.load(Ordering::Acquire) {
|
||||
return Err(WorkerError::Poisoned {
|
||||
device_index: self.device_index,
|
||||
});
|
||||
}
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(Job::QueryVram { reply: reply_tx })
|
||||
.map_err(|_| WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
})?;
|
||||
match reply_rx.await {
|
||||
Ok(result) => result.map_err(WorkerError::from),
|
||||
Err(_) => Err(WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Load a GGUF (pre-quantized) single-GPU model on the worker
|
||||
/// thread. The hf-hub resolution happens on the async caller; the
|
||||
/// resolved local `gguf_path` plus the spec's model_id are sent
|
||||
/// into the worker which opens, parses, and constructs the
|
||||
/// `ModelArch` on the right thread.
|
||||
pub async fn load_gguf(
|
||||
&self,
|
||||
gguf_path: std::path::PathBuf,
|
||||
model_id: String,
|
||||
) -> Result<ArchHandle, WorkerError> {
|
||||
if self.poisoned.load(Ordering::Acquire) {
|
||||
return Err(WorkerError::Poisoned {
|
||||
device_index: self.device_index,
|
||||
});
|
||||
}
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(Job::LoadGguf {
|
||||
gguf_path,
|
||||
model_id,
|
||||
reply: reply_tx,
|
||||
})
|
||||
.map_err(|_| WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
})?;
|
||||
match reply_rx.await {
|
||||
Ok(result) => result.map_err(WorkerError::from),
|
||||
Err(_) => Err(WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Load a dense safetensors single-GPU model on the worker thread.
|
||||
pub async fn load_dense(
|
||||
&self,
|
||||
config_path: std::path::PathBuf,
|
||||
safetensors_paths: Vec<std::path::PathBuf>,
|
||||
model_id: String,
|
||||
) -> Result<ArchHandle, WorkerError> {
|
||||
if self.poisoned.load(Ordering::Acquire) {
|
||||
return Err(WorkerError::Poisoned {
|
||||
device_index: self.device_index,
|
||||
});
|
||||
}
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(Job::LoadDense {
|
||||
config_path,
|
||||
safetensors_paths,
|
||||
model_id,
|
||||
reply: reply_tx,
|
||||
})
|
||||
.map_err(|_| WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
})?;
|
||||
match reply_rx.await {
|
||||
Ok(result) => result.map_err(WorkerError::from),
|
||||
Err(_) => Err(WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Tell the worker to drop the `ModelArch` for `handle` on the
|
||||
/// worker thread (so CUDA tensors release on the right context).
|
||||
/// Returns `Ok(())` even if the handle wasn't in the slab — Drop
|
||||
/// is idempotent. Reports `Gone` if the worker isn't running.
|
||||
pub async fn drop_arch(&self, handle: ArchHandle) -> Result<(), WorkerError> {
|
||||
// Poisoning doesn't block DropArch — even on a poisoned
|
||||
// context we want callers to unblock and proceed with the
|
||||
// unload bookkeeping. The dispatch handler under poison just
|
||||
// replies `()` without touching the model (the actual Drop
|
||||
// happens via mem::forget at thread exit per the poison
|
||||
// protocol).
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(Job::DropArch {
|
||||
handle,
|
||||
reply: reply_tx,
|
||||
})
|
||||
.map_err(|_| WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
})?;
|
||||
match reply_rx.await {
|
||||
Ok(()) => Ok(()),
|
||||
Err(_) => Err(WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Reset the KV cache for the model at `handle`. Called at the
|
||||
/// start of every chat completion so the new prompt doesn't
|
||||
/// attend over the previous request's tokens.
|
||||
pub async fn clear_kv_cache(&self, handle: ArchHandle) -> Result<(), WorkerError> {
|
||||
if self.poisoned.load(Ordering::Acquire) {
|
||||
return Err(WorkerError::Poisoned {
|
||||
device_index: self.device_index,
|
||||
});
|
||||
}
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(Job::ClearKv {
|
||||
handle,
|
||||
reply: reply_tx,
|
||||
})
|
||||
.map_err(|_| WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
})?;
|
||||
match reply_rx.await {
|
||||
Ok(result) => result.map_err(WorkerError::from),
|
||||
Err(_) => Err(WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Run one forward step and return the resulting `[vocab]` logits
|
||||
/// as a CPU-side `Vec<f32>`. The caller then samples on a CPU
|
||||
/// candle Tensor without ever binding the device context on its
|
||||
/// tokio thread.
|
||||
pub async fn forward_logits(
|
||||
&self,
|
||||
handle: ArchHandle,
|
||||
tokens: Vec<u32>,
|
||||
offset: usize,
|
||||
) -> Result<Vec<f32>, WorkerError> {
|
||||
if self.poisoned.load(Ordering::Acquire) {
|
||||
return Err(WorkerError::Poisoned {
|
||||
device_index: self.device_index,
|
||||
});
|
||||
}
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(Job::ForwardLogits {
|
||||
handle,
|
||||
tokens,
|
||||
offset,
|
||||
reply: reply_tx,
|
||||
})
|
||||
.map_err(|_| WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
})?;
|
||||
match reply_rx.await {
|
||||
Ok(result) => result.map_err(WorkerError::from),
|
||||
Err(_) => Err(WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialise the leader's NCCL communicator. The reply uses
|
||||
/// `WorkerResponse` (same shape subprocess workers use over stdio
|
||||
/// RPC) so `WorkerPool::init_nccl`'s aggregation treats leader +
|
||||
/// subprocess responses uniformly. Available on no-cuda builds
|
||||
/// too — the dispatch handler calls the no-cuda `NcclState::init`
|
||||
/// stub which replies `cuda_feature_not_enabled`.
|
||||
pub async fn nccl_init(
|
||||
&self,
|
||||
cfg: crate::harness::tp::worker::WorkerConfig,
|
||||
comm_id_hex: String,
|
||||
) -> Result<crate::harness::tp::rpc::WorkerResponse, WorkerError> {
|
||||
if self.poisoned.load(Ordering::Acquire) {
|
||||
return Err(WorkerError::Poisoned {
|
||||
device_index: self.device_index,
|
||||
});
|
||||
}
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(Job::NcclInit {
|
||||
cfg,
|
||||
comm_id_hex,
|
||||
reply: reply_tx,
|
||||
})
|
||||
.map_err(|_| WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
})?;
|
||||
reply_rx.await.map_err(|_| WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
})
|
||||
}
|
||||
|
||||
/// Run an NCCL sanity all_reduce on the leader's rank 0.
|
||||
/// Available on no-cuda builds; replies with an error response.
|
||||
pub async fn nccl_sanity(
|
||||
&self,
|
||||
) -> Result<crate::harness::tp::rpc::WorkerResponse, WorkerError> {
|
||||
if self.poisoned.load(Ordering::Acquire) {
|
||||
return Err(WorkerError::Poisoned {
|
||||
device_index: self.device_index,
|
||||
});
|
||||
}
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(Job::NcclSanity { reply: reply_tx })
|
||||
.map_err(|_| WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
})?;
|
||||
reply_rx.await.map_err(|_| WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
})
|
||||
}
|
||||
|
||||
/// Load the leader's TP shard on the worker thread. The dispatch
|
||||
/// handler reads its own `NcclState`'s `Arc<Comm>` directly — no
|
||||
/// cross-thread Comm transfer — and builds the `TpLeaderModel`
|
||||
/// against it. Phase 4 replaces the Phase 3 Clone/TransferIn
|
||||
/// bridge with this single Job.
|
||||
#[cfg(feature = "cuda")]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn tp_load_shard(
|
||||
&self,
|
||||
model_id: String,
|
||||
config_json: String,
|
||||
safetensors_paths: Vec<std::path::PathBuf>,
|
||||
dtype: candle_core::DType,
|
||||
quant: Option<String>,
|
||||
world_size: u32,
|
||||
) -> Result<TpHandle, WorkerError> {
|
||||
if self.poisoned.load(Ordering::Acquire) {
|
||||
return Err(WorkerError::Poisoned {
|
||||
device_index: self.device_index,
|
||||
});
|
||||
}
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(Job::TpLoadShard {
|
||||
model_id,
|
||||
config_json,
|
||||
safetensors_paths,
|
||||
dtype,
|
||||
quant,
|
||||
world_size,
|
||||
reply: reply_tx,
|
||||
})
|
||||
.map_err(|_| WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
})?;
|
||||
match reply_rx.await {
|
||||
Ok(result) => result.map_err(WorkerError::from),
|
||||
Err(_) => Err(WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Drop the TP model at `handle` on the worker thread.
|
||||
#[cfg(feature = "cuda")]
|
||||
pub async fn drop_tp(&self, handle: TpHandle) -> Result<(), WorkerError> {
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(Job::DropTp {
|
||||
handle,
|
||||
reply: reply_tx,
|
||||
})
|
||||
.map_err(|_| WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
})?;
|
||||
match reply_rx.await {
|
||||
Ok(()) => Ok(()),
|
||||
Err(_) => Err(WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Reset the leader's KV cache for a TP model.
|
||||
#[cfg(feature = "cuda")]
|
||||
pub async fn tp_clear_kv(&self, handle: TpHandle) -> Result<(), WorkerError> {
|
||||
if self.poisoned.load(Ordering::Acquire) {
|
||||
return Err(WorkerError::Poisoned {
|
||||
device_index: self.device_index,
|
||||
});
|
||||
}
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(Job::TpClearKv {
|
||||
handle,
|
||||
reply: reply_tx,
|
||||
})
|
||||
.map_err(|_| WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
})?;
|
||||
match reply_rx.await {
|
||||
Ok(result) => result.map_err(WorkerError::from),
|
||||
Err(_) => Err(WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Run one TP forward step on the leader's shard. Returns CPU-side
|
||||
/// logits as `Vec<f32>` ready for sampling. The caller is
|
||||
/// responsible for fan-out / drain of the subprocess workers
|
||||
/// concurrently with this call.
|
||||
#[cfg(feature = "cuda")]
|
||||
pub async fn tp_forward_logits(
|
||||
&self,
|
||||
handle: TpHandle,
|
||||
tokens: Vec<u32>,
|
||||
offset: usize,
|
||||
) -> Result<Vec<f32>, WorkerError> {
|
||||
if self.poisoned.load(Ordering::Acquire) {
|
||||
return Err(WorkerError::Poisoned {
|
||||
device_index: self.device_index,
|
||||
});
|
||||
}
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(Job::TpForwardLogits {
|
||||
handle,
|
||||
tokens,
|
||||
offset,
|
||||
reply: reply_tx,
|
||||
})
|
||||
.map_err(|_| WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
})?;
|
||||
match reply_rx.await {
|
||||
Ok(result) => result.map_err(WorkerError::from),
|
||||
Err(_) => Err(WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Send `Job::Shutdown` and join the thread. Idempotent — calling
|
||||
/// twice is a no-op the second time.
|
||||
pub fn shutdown(&self) -> anyhow::Result<()> {
|
||||
// Best-effort send: if the channel is already closed (thread
|
||||
// exited after a prior shutdown or panic) the send fails and
|
||||
// we fall through to the join which returns the panic, if any.
|
||||
let _ = self.tx.send(Job::Shutdown);
|
||||
let join = self.join.lock().unwrap().take();
|
||||
if let Some(j) = join {
|
||||
j.join()
|
||||
.map_err(|_| anyhow::anyhow!("worker thread panicked during shutdown"))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for DeviceWorkerHandle {
|
||||
fn drop(&mut self) {
|
||||
// Best-effort: send Shutdown so the thread breaks its loop
|
||||
// and exits. We do NOT join here — Drop may run on a tokio
|
||||
// worker thread, and joining a thread that's still processing
|
||||
// the last job would block the runtime. The OS reaps the
|
||||
// thread on detach.
|
||||
let _ = self.tx.send(Job::Shutdown);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::time::Duration;
|
||||
|
||||
#[tokio::test]
|
||||
async fn spawn_query_vram_shutdown() {
|
||||
let handle = DeviceWorkerHandle::spawn(0).expect("spawn ok");
|
||||
// CPU build (the only one CI runs) returns (0, 0) by design;
|
||||
// a CUDA build with a real device would return real values.
|
||||
let result = handle.query_vram().await.expect("query ok");
|
||||
// We assert >= 0 — the field width matters more than the value.
|
||||
let _ = result.0;
|
||||
let _ = result.1;
|
||||
handle.shutdown().expect("shutdown ok");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn thread_is_named_correctly() {
|
||||
// The thread name lets `top -H` / pidstat / gdb show
|
||||
// `cuda-dev-N` instead of an opaque tokio worker name. Verify
|
||||
// by spawning and reading proc-self thread comms — but on
|
||||
// platforms without /proc, just confirm we don't crash.
|
||||
let handle = DeviceWorkerHandle::spawn(7).expect("spawn ok");
|
||||
// Round-trip a job to ensure the thread is alive and processing.
|
||||
handle.query_vram().await.expect("query ok");
|
||||
handle.shutdown().expect("shutdown ok");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn submit_after_shutdown_returns_gone() {
|
||||
let handle = DeviceWorkerHandle::spawn(0).expect("spawn ok");
|
||||
handle.shutdown().expect("shutdown ok");
|
||||
// Channel closed; submit should map to Gone rather than block.
|
||||
let result = handle.query_vram().await;
|
||||
match result {
|
||||
Err(WorkerError::Gone { device_index: 0 }) => {}
|
||||
other => panic!("expected Gone, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn poisoned_flag_short_circuits_submit() {
|
||||
let handle = DeviceWorkerHandle::spawn(0).expect("spawn ok");
|
||||
handle.set_poisoned();
|
||||
let result = handle.query_vram().await;
|
||||
match result {
|
||||
Err(WorkerError::Poisoned { device_index: 0 }) => {}
|
||||
other => panic!("expected Poisoned, got {other:?}"),
|
||||
}
|
||||
// The channel is still alive; shutdown should still succeed.
|
||||
handle.shutdown().expect("shutdown ok");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn shutdown_drains_pending_jobs() {
|
||||
let handle = DeviceWorkerHandle::spawn(0).expect("spawn ok");
|
||||
// Submit many concurrent jobs; they should all complete even
|
||||
// though a Shutdown is racing them.
|
||||
let mut futures = Vec::new();
|
||||
for _ in 0..16 {
|
||||
let h = Arc::clone(&handle);
|
||||
futures.push(tokio::spawn(async move { h.query_vram().await }));
|
||||
}
|
||||
// Small yield to give the senders a chance to actually send
|
||||
// before we issue the shutdown; not strictly necessary because
|
||||
// the channel is FIFO, but makes the test's intent clearer.
|
||||
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||
handle.shutdown().expect("shutdown ok");
|
||||
for f in futures {
|
||||
// Each query should have completed (Ok or Gone, never panic).
|
||||
let _ = f.await.expect("task did not panic");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1 +0,0 @@
|
||||
// llama.cpp harness implementation — Phase 11.
|
||||
@@ -1,163 +0,0 @@
|
||||
//! mistral.rs harness implementation.
|
||||
//!
|
||||
//! Wraps the mistral.rs HTTP API for model lifecycle management
|
||||
//! and optionally manages the process via systemd.
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use cortex_core::harness::{Harness, HarnessConfig, HarnessHealth, ModelInfo, ModelSpec};
|
||||
use reqwest::Client;
|
||||
use serde::Deserialize;
|
||||
|
||||
pub struct MistralRsHarness {
|
||||
endpoint: String,
|
||||
systemd_unit: Option<String>,
|
||||
client: Client,
|
||||
}
|
||||
|
||||
impl MistralRsHarness {
|
||||
pub fn new(endpoint: String, systemd_unit: Option<String>) -> Self {
|
||||
Self {
|
||||
endpoint,
|
||||
systemd_unit,
|
||||
client: Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(30))
|
||||
.build()
|
||||
.expect("failed to build HTTP client"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Response from mistral.rs `GET /v1/models`.
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ModelsResponse {
|
||||
data: Vec<ModelEntry>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ModelEntry {
|
||||
id: String,
|
||||
#[serde(default)]
|
||||
status: Option<String>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Harness for MistralRsHarness {
|
||||
fn name(&self) -> &str {
|
||||
"mistralrs"
|
||||
}
|
||||
|
||||
async fn start(&self, _config: &HarnessConfig) -> Result<()> {
|
||||
let Some(unit) = &self.systemd_unit else {
|
||||
anyhow::bail!("no systemd unit configured for mistralrs harness");
|
||||
};
|
||||
|
||||
let output = tokio::process::Command::new("systemctl")
|
||||
.args(["start", unit])
|
||||
.output()
|
||||
.await?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
anyhow::bail!("systemctl start {unit} failed: {stderr}");
|
||||
}
|
||||
|
||||
// Wait for the health endpoint to respond (up to 30s).
|
||||
let url = format!("{}/health", self.endpoint);
|
||||
for _ in 0..30 {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
if self.client.get(&url).send().await.is_ok() {
|
||||
tracing::info!(unit, "mistralrs started and healthy");
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
anyhow::bail!("mistralrs started but health endpoint did not respond within 30s");
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<()> {
|
||||
let Some(unit) = &self.systemd_unit else {
|
||||
anyhow::bail!("no systemd unit configured for mistralrs harness");
|
||||
};
|
||||
|
||||
let output = tokio::process::Command::new("systemctl")
|
||||
.args(["stop", unit])
|
||||
.output()
|
||||
.await?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
anyhow::bail!("systemctl stop {unit} failed: {stderr}");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn health(&self) -> HarnessHealth {
|
||||
let url = format!("{}/health", self.endpoint);
|
||||
let running = self.client.get(&url).send().await.is_ok();
|
||||
HarnessHealth {
|
||||
name: "mistralrs".into(),
|
||||
running,
|
||||
uptime_secs: None,
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_models(&self) -> Result<Vec<ModelInfo>> {
|
||||
let url = format!("{}/v1/models", self.endpoint);
|
||||
let resp = self.client.get(&url).send().await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
anyhow::bail!("GET /v1/models returned {}", resp.status());
|
||||
}
|
||||
|
||||
let models_resp: ModelsResponse = resp.json().await?;
|
||||
Ok(models_resp
|
||||
.data
|
||||
.into_iter()
|
||||
.map(|m| ModelInfo {
|
||||
id: m.id,
|
||||
harness: "mistralrs".into(),
|
||||
status: m.status.unwrap_or_else(|| "loaded".into()),
|
||||
devices: vec![],
|
||||
vram_used_mb: None,
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
async fn load_model(&self, spec: &ModelSpec) -> Result<()> {
|
||||
let url = format!("{}/v1/models/reload", self.endpoint);
|
||||
let resp = self
|
||||
.client
|
||||
.post(&url)
|
||||
.json(&serde_json::json!({ "model_id": spec.model_id }))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!("POST /v1/models/reload failed: {body}");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn unload_model(&self, model_id: &str) -> Result<()> {
|
||||
let url = format!("{}/v1/models/unload", self.endpoint);
|
||||
let resp = self
|
||||
.client
|
||||
.post(&url)
|
||||
.json(&serde_json::json!({ "model_id": model_id }))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!("POST /v1/models/unload failed: {body}");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn inference_endpoint(&self, _model_id: &str) -> Option<String> {
|
||||
// mistral.rs routes internally by model name in the request body,
|
||||
// so the inference endpoint is always the base URL.
|
||||
Some(self.endpoint.clone())
|
||||
}
|
||||
}
|
||||
@@ -1,15 +1,27 @@
|
||||
//! Harness registry — maps harness names to trait implementations.
|
||||
|
||||
pub mod llamacpp;
|
||||
pub mod mistralrs;
|
||||
pub mod arch;
|
||||
pub mod candle;
|
||||
pub mod chat_template;
|
||||
pub mod device_worker;
|
||||
pub mod preflight;
|
||||
pub mod tp;
|
||||
|
||||
use anyhow::Result;
|
||||
use cortex_core::harness::{Harness, HarnessConfig, ModelInfo, ModelSpec};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Registry of available harness implementations.
|
||||
///
|
||||
/// Holds an `Arc<dyn Harness>` per harness for generic lifecycle dispatch
|
||||
/// (load/unload/list_models). When a candle harness is registered, a typed
|
||||
/// `Arc<CandleHarness>` is also cached so inference routes can bypass the
|
||||
/// dyn-Trait dispatch and reach harness-specific methods (chat completion,
|
||||
/// streaming, etc.).
|
||||
pub struct HarnessRegistry {
|
||||
harnesses: HashMap<String, Box<dyn Harness>>,
|
||||
harnesses: HashMap<String, Arc<dyn Harness>>,
|
||||
candle: Option<Arc<candle::CandleHarness>>,
|
||||
}
|
||||
|
||||
impl Default for HarnessRegistry {
|
||||
@@ -22,10 +34,11 @@ impl HarnessRegistry {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
harnesses: HashMap::new(),
|
||||
candle: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn register(&mut self, harness: Box<dyn Harness>) {
|
||||
pub fn register(&mut self, harness: Arc<dyn Harness>) {
|
||||
self.harnesses.insert(harness.name().to_string(), harness);
|
||||
}
|
||||
|
||||
@@ -34,6 +47,12 @@ impl HarnessRegistry {
|
||||
self.harnesses.keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// Typed handle to the candle harness, if registered. Used by inference
|
||||
/// routes that need methods beyond the `Harness` trait surface.
|
||||
pub fn candle(&self) -> Option<Arc<candle::CandleHarness>> {
|
||||
self.candle.clone()
|
||||
}
|
||||
|
||||
/// List models from all registered harnesses.
|
||||
pub async fn list_all_models(&self) -> Result<Vec<ModelInfo>> {
|
||||
let mut all = Vec::new();
|
||||
@@ -81,19 +100,25 @@ impl HarnessRegistry {
|
||||
}
|
||||
|
||||
/// Build a registry from harness configs.
|
||||
pub fn from_configs(configs: &[HarnessConfig]) -> Self {
|
||||
///
|
||||
/// `bind_url` is the URL where this neuron serves inference (its own
|
||||
/// listen address). In-process harnesses (currently the only kind)
|
||||
/// return this URL from `inference_endpoint`.
|
||||
pub fn from_configs(
|
||||
configs: &[HarnessConfig],
|
||||
bind_url: &str,
|
||||
settings: &crate::config::HarnessSettings,
|
||||
) -> Self {
|
||||
let mut registry = Self::new();
|
||||
for config in configs {
|
||||
match config.name.as_str() {
|
||||
"mistralrs" => {
|
||||
if let Some(endpoint) = &config.endpoint {
|
||||
registry.register(Box::new(mistralrs::MistralRsHarness::new(
|
||||
endpoint.clone(),
|
||||
config.systemd_unit.clone(),
|
||||
)));
|
||||
} else {
|
||||
tracing::warn!("mistralrs harness missing endpoint, skipping");
|
||||
}
|
||||
"candle" => {
|
||||
let harness = Arc::new(candle::CandleHarness::new(
|
||||
bind_url.to_string(),
|
||||
settings.candle.hf_cache.clone(),
|
||||
));
|
||||
registry.candle = Some(Arc::clone(&harness));
|
||||
registry.harnesses.insert("candle".into(), harness);
|
||||
}
|
||||
other => {
|
||||
tracing::warn!(harness = other, "unknown harness type, skipping");
|
||||
|
||||
575
crates/neuron/src/harness/preflight.rs
Normal file
575
crates/neuron/src/harness/preflight.rs
Normal file
@@ -0,0 +1,575 @@
|
||||
//! Placement feasibility check that runs before any device allocation,
|
||||
//! NCCL handshake, or weight download.
|
||||
//!
|
||||
//! The loader path in `candle.rs` historically discovers an
|
||||
//! incompatibility *after* it has already started fetching files —
|
||||
//! "fetch config.json from HauhauCS/...: 404 Not Found" surfaces hours
|
||||
//! after operators set `tensor_parallel = 2` on a GGUF-only repo, with
|
||||
//! no hint about what's actually wrong. Preflight closes that gap:
|
||||
//!
|
||||
//! 1. one `repo.info()` round-trip (siblings listing, no blob fetch)
|
||||
//! 2. classify the repo: GGUF-only, dense safetensors, mixed, empty
|
||||
//! 3. apply the feasibility table against the requested
|
||||
//! `ModelSpec` (tp_size, quant)
|
||||
//! 4. return a structured `PreflightError` the API layer can map to
|
||||
//! 422 + JSON, or `Ok(PlacementPlan)` carrying the decisions the
|
||||
//! downstream load path needs (which GGUF file to fetch, etc.).
|
||||
//!
|
||||
//! Phase 2 of plan-source-aware-loader-preflight. The Phase 1 scheme
|
||||
//! work — `ModelSourceId` and per-scheme `SourceConfig` — is a
|
||||
//! separate PR; preflight runs against the single configured
|
||||
//! HuggingFace source for now and the scheme threading drops in
|
||||
//! cleanly when Phase 1 lands.
|
||||
|
||||
use cortex_core::harness::ModelSpec;
|
||||
use hf_hub::api::tokio::Api;
|
||||
use serde::Serialize;
|
||||
|
||||
/// What the repo's siblings listing tells us about how to load it.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
|
||||
#[serde(tag = "kind", rename_all = "snake_case")]
|
||||
pub enum SourceFormat {
|
||||
/// Only GGUF files present. Single-GPU load path. `quants` is the
|
||||
/// lowercased filename list so the operator can be told what's
|
||||
/// actually available when their `quant=` choice doesn't match.
|
||||
Gguf { quants: Vec<String> },
|
||||
/// Dense safetensors (single-file or sharded via index.json).
|
||||
/// Goes through `load_arch_dense` on single-GPU, or `load_tp` (with
|
||||
/// optional in-situ quantization) when `tensor_parallel > 1`.
|
||||
DenseSafetensors { sharded: bool },
|
||||
/// Both safetensors and GGUF present — prefer the dense path
|
||||
/// because it composes with TP and ISQ. We surface the GGUF
|
||||
/// filenames anyway so operators with a strong preference can
|
||||
/// see they exist.
|
||||
Mixed { gguf_quants: Vec<String> },
|
||||
/// No recognised weight files. Either a tokenizer-only repo
|
||||
/// (e.g. some base-model repos that only host `tokenizer.json` and
|
||||
/// expect the operator to use a `-GGUF` sibling repo) or a
|
||||
/// genuinely empty entry.
|
||||
Empty,
|
||||
}
|
||||
|
||||
/// Output of `preflight` for a load that can proceed. Carries the
|
||||
/// decisions downstream resolve_* paths would otherwise re-derive.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct PlacementPlan {
|
||||
pub model_id: String,
|
||||
pub format: SourceFormat,
|
||||
pub tp_size: u32,
|
||||
/// Filename of the GGUF to fetch, populated when `format` is
|
||||
/// `Gguf` and a single-GPU load was requested. None for the
|
||||
/// dense/TP path.
|
||||
pub picked_quant_file: Option<String>,
|
||||
}
|
||||
|
||||
/// Structured failure modes. Each variant carries the fields the API
|
||||
/// layer needs to produce an actionable 422 body.
|
||||
#[derive(Debug, Clone, Serialize, thiserror::Error)]
|
||||
#[serde(tag = "kind", rename_all = "snake_case")]
|
||||
pub enum PreflightError {
|
||||
/// `repo.info()` failed. Captures the underlying cause as a string
|
||||
/// so the operator log shows whether it's auth, 404, or transport.
|
||||
#[error("failed to fetch repo info for '{model_id}': {cause}")]
|
||||
RepoFetchFailed { model_id: String, cause: String },
|
||||
|
||||
/// The repo exists but has no recognised weight files.
|
||||
#[error(
|
||||
"repo '{model_id}' has no recognised weight files (no .gguf, no .safetensors); \
|
||||
a tokenizer-only repo cannot be loaded directly"
|
||||
)]
|
||||
EmptyRepo { model_id: String },
|
||||
|
||||
/// Operator asked for `tensor_parallel > 1` on a GGUF-only repo.
|
||||
/// The TP path requires safetensors+config for in-situ
|
||||
/// quantization; GGUF-TP isn't implemented (see CLAUDE.md).
|
||||
#[error(
|
||||
"cannot load '{model_id}' with tensor_parallel={tp_size}: repo is GGUF-only \
|
||||
({} .gguf files); TP requires dense safetensors. {suggestion}",
|
||||
gguf_quants.len()
|
||||
)]
|
||||
TpRequiresSafetensors {
|
||||
model_id: String,
|
||||
tp_size: u32,
|
||||
gguf_quants: Vec<String>,
|
||||
suggestion: String,
|
||||
},
|
||||
|
||||
/// Operator asked for a GGUF quant whose substring doesn't match
|
||||
/// any filename in the repo. `nearest` is a best-effort Levenshtein
|
||||
/// suggestion against the available quant names.
|
||||
#[error(
|
||||
"no GGUF file in '{model_id}' matches quant '{requested}'; \
|
||||
available: {available:?}{}",
|
||||
nearest.as_ref().map(|n| format!("; did you mean '{n}'?")).unwrap_or_default()
|
||||
)]
|
||||
QuantNotFound {
|
||||
model_id: String,
|
||||
requested: String,
|
||||
available: Vec<String>,
|
||||
nearest: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
/// Run the placement check.
|
||||
///
|
||||
/// One network round-trip (`repo.info()`); no blob fetches. Returns
|
||||
/// `Ok(PlacementPlan)` when the requested combination is feasible, or
|
||||
/// a structured `PreflightError` describing what's wrong.
|
||||
pub async fn preflight(api: &Api, spec: &ModelSpec) -> Result<PlacementPlan, PreflightError> {
|
||||
let repo = api.model(spec.model_id.clone());
|
||||
let info = repo
|
||||
.info()
|
||||
.await
|
||||
.map_err(|e| PreflightError::RepoFetchFailed {
|
||||
model_id: spec.model_id.clone(),
|
||||
cause: format!("{e}"),
|
||||
})?;
|
||||
|
||||
let filenames: Vec<&str> = info.siblings.iter().map(|s| s.rfilename.as_str()).collect();
|
||||
let format = classify(&filenames);
|
||||
let tp_size = spec.tensor_parallel.unwrap_or(1);
|
||||
|
||||
match (&format, tp_size, spec.quant.as_deref()) {
|
||||
// No weights at all — nothing to do.
|
||||
(SourceFormat::Empty, _, _) => Err(PreflightError::EmptyRepo {
|
||||
model_id: spec.model_id.clone(),
|
||||
}),
|
||||
|
||||
// GGUF-only + TP: not supported. Today's HauhauCS failure.
|
||||
(SourceFormat::Gguf { quants }, tp, _) if tp > 1 => {
|
||||
Err(PreflightError::TpRequiresSafetensors {
|
||||
model_id: spec.model_id.clone(),
|
||||
tp_size: tp,
|
||||
gguf_quants: quants.clone(),
|
||||
suggestion: format!(
|
||||
"Set tensor_parallel=1 and pick a quant from {quants:?}, \
|
||||
or use a dense safetensors release of this model."
|
||||
),
|
||||
})
|
||||
}
|
||||
|
||||
// GGUF-only + single-GPU: pick the file that matches the
|
||||
// operator's quant. Empty quant matches the first GGUF.
|
||||
(SourceFormat::Gguf { quants }, _, requested) => {
|
||||
let picked = pick_gguf_file(&filenames, requested.unwrap_or(""));
|
||||
match picked {
|
||||
Some(fname) => Ok(PlacementPlan {
|
||||
model_id: spec.model_id.clone(),
|
||||
format: format.clone(),
|
||||
tp_size,
|
||||
picked_quant_file: Some(fname),
|
||||
}),
|
||||
None => Err(PreflightError::QuantNotFound {
|
||||
model_id: spec.model_id.clone(),
|
||||
requested: requested.unwrap_or("").to_string(),
|
||||
available: quants.clone(),
|
||||
nearest: nearest_quant(requested.unwrap_or(""), quants),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
// Dense or mixed: dense path handles both single-GPU and TP.
|
||||
// The architecture compatibility check stays where it is —
|
||||
// `check_dense_config_supported` runs once `config.json` is
|
||||
// on disk, since it needs the parsed JSON.
|
||||
(SourceFormat::DenseSafetensors { .. } | SourceFormat::Mixed { .. }, _, _) => {
|
||||
Ok(PlacementPlan {
|
||||
model_id: spec.model_id.clone(),
|
||||
format: format.clone(),
|
||||
tp_size,
|
||||
picked_quant_file: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Classify a siblings file list into a `SourceFormat`. Pulled out so
|
||||
/// the unit tests can exercise it against fixture JSON without
|
||||
/// spinning up an Api.
|
||||
pub fn classify(filenames: &[&str]) -> SourceFormat {
|
||||
let mut gguf_quants: Vec<String> = filenames
|
||||
.iter()
|
||||
.filter(|f| f.to_lowercase().ends_with(".gguf"))
|
||||
.map(|f| f.to_lowercase())
|
||||
.collect();
|
||||
gguf_quants.sort();
|
||||
gguf_quants.dedup();
|
||||
|
||||
let has_safetensors = filenames.iter().any(|f| f.ends_with(".safetensors"));
|
||||
let sharded = filenames
|
||||
.iter()
|
||||
.any(|f| f.ends_with("model.safetensors.index.json"));
|
||||
|
||||
match (has_safetensors, gguf_quants.is_empty()) {
|
||||
(true, true) => SourceFormat::DenseSafetensors { sharded },
|
||||
(true, false) => SourceFormat::Mixed { gguf_quants },
|
||||
(false, false) => SourceFormat::Gguf {
|
||||
quants: gguf_quants,
|
||||
},
|
||||
(false, true) => SourceFormat::Empty,
|
||||
}
|
||||
}
|
||||
|
||||
/// Mirror of the quant-matching logic in `candle.rs::resolve_files` so
|
||||
/// preflight picks the same file the downstream loader would. Empty
|
||||
/// quant returns the first `.gguf` (any quant). Lowercased substring
|
||||
/// match otherwise.
|
||||
fn pick_gguf_file(filenames: &[&str], quant_lc: &str) -> Option<String> {
|
||||
filenames
|
||||
.iter()
|
||||
.filter(|f| f.to_lowercase().ends_with(".gguf"))
|
||||
.find(|f| quant_lc.is_empty() || f.to_lowercase().contains(quant_lc))
|
||||
.map(|f| f.to_string())
|
||||
}
|
||||
|
||||
/// Best-effort suggestion when the operator's quant name doesn't
|
||||
/// substring-match any filename. Extracts the quant-ish token from
|
||||
/// each `.gguf` filename and picks the one with the smallest
|
||||
/// Levenshtein distance to the requested string. Returns None when
|
||||
/// the input is empty or no candidates exist.
|
||||
fn nearest_quant(requested: &str, candidates: &[String]) -> Option<String> {
|
||||
if requested.is_empty() || candidates.is_empty() {
|
||||
return None;
|
||||
}
|
||||
// Pull the "Q6_K_P"/"IQ4_XS"-ish token out of each filename for a
|
||||
// fairer comparison. Filenames look like
|
||||
// `Qwen3.6-27B-Uncensored-HauhauCS-Aggressive-Q6_K_P.gguf`, so the
|
||||
// quant is the last `-`-separated segment before the extension,
|
||||
// lowercased.
|
||||
let tokens: Vec<(String, String)> = candidates
|
||||
.iter()
|
||||
.map(|f| (extract_quant_token(f), f.clone()))
|
||||
.collect();
|
||||
|
||||
let req_lc = requested.to_lowercase();
|
||||
tokens
|
||||
.into_iter()
|
||||
.min_by_key(|(token, _)| levenshtein(&req_lc, token))
|
||||
.map(|(token, _)| token)
|
||||
}
|
||||
|
||||
fn extract_quant_token(filename: &str) -> String {
|
||||
let stem = filename
|
||||
.rsplit_once('.')
|
||||
.map(|(s, _)| s)
|
||||
.unwrap_or(filename);
|
||||
let token = stem.rsplit('-').next().unwrap_or(stem);
|
||||
token.to_lowercase()
|
||||
}
|
||||
|
||||
/// Iterative Levenshtein. Small inputs (quant names are <=12 chars),
|
||||
/// no need for the `levenshtein` crate.
|
||||
fn levenshtein(a: &str, b: &str) -> usize {
|
||||
let a: Vec<char> = a.chars().collect();
|
||||
let b: Vec<char> = b.chars().collect();
|
||||
let (m, n) = (a.len(), b.len());
|
||||
if m == 0 {
|
||||
return n;
|
||||
}
|
||||
if n == 0 {
|
||||
return m;
|
||||
}
|
||||
let mut prev: Vec<usize> = (0..=n).collect();
|
||||
let mut curr = vec![0usize; n + 1];
|
||||
for i in 1..=m {
|
||||
curr[0] = i;
|
||||
for j in 1..=n {
|
||||
let cost = if a[i - 1] == b[j - 1] { 0 } else { 1 };
|
||||
curr[j] = (prev[j] + 1).min(curr[j - 1] + 1).min(prev[j - 1] + cost);
|
||||
}
|
||||
std::mem::swap(&mut prev, &mut curr);
|
||||
}
|
||||
prev[n]
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn spec(model_id: &str, tp: Option<u32>, quant: Option<&str>) -> ModelSpec {
|
||||
ModelSpec {
|
||||
model_id: model_id.into(),
|
||||
harness: "candle".into(),
|
||||
quant: quant.map(String::from),
|
||||
tensor_parallel: tp,
|
||||
devices: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_gguf_only() {
|
||||
let files = [
|
||||
"README.md",
|
||||
".gitattributes",
|
||||
"Qwen3.6-27B-Q6_K_P.gguf",
|
||||
"Qwen3.6-27B-Q4_K_P.gguf",
|
||||
];
|
||||
match classify(&files) {
|
||||
SourceFormat::Gguf { quants } => {
|
||||
assert_eq!(quants.len(), 2);
|
||||
assert!(quants.iter().any(|q| q.contains("q6_k_p")));
|
||||
}
|
||||
other => panic!("expected Gguf, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_dense_sharded() {
|
||||
let files = [
|
||||
"config.json",
|
||||
"tokenizer.json",
|
||||
"model.safetensors.index.json",
|
||||
"model-00001-of-00002.safetensors",
|
||||
"model-00002-of-00002.safetensors",
|
||||
];
|
||||
assert_eq!(
|
||||
classify(&files),
|
||||
SourceFormat::DenseSafetensors { sharded: true }
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_dense_single_file() {
|
||||
let files = ["config.json", "tokenizer.json", "model.safetensors"];
|
||||
assert_eq!(
|
||||
classify(&files),
|
||||
SourceFormat::DenseSafetensors { sharded: false }
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_mixed() {
|
||||
let files = [
|
||||
"config.json",
|
||||
"tokenizer.json",
|
||||
"model.safetensors",
|
||||
"model-Q4_K_M.gguf",
|
||||
];
|
||||
match classify(&files) {
|
||||
SourceFormat::Mixed { gguf_quants } => {
|
||||
assert_eq!(gguf_quants, vec!["model-q4_k_m.gguf"]);
|
||||
}
|
||||
other => panic!("expected Mixed, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_empty() {
|
||||
let files = ["README.md", "tokenizer.json"];
|
||||
assert_eq!(classify(&files), SourceFormat::Empty);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pick_gguf_substring_match() {
|
||||
let files = ["model-Q4_K_M.gguf", "model-Q6_K.gguf", "model-Q8_0.gguf"];
|
||||
assert_eq!(
|
||||
pick_gguf_file(&files, "q6_k"),
|
||||
Some("model-Q6_K.gguf".into())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pick_gguf_empty_returns_first() {
|
||||
let files = ["model-Q4_K_M.gguf", "model-Q6_K.gguf"];
|
||||
assert_eq!(pick_gguf_file(&files, ""), Some("model-Q4_K_M.gguf".into()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pick_gguf_no_match() {
|
||||
let files = ["model-Q4_K_M.gguf", "model-Q6_K.gguf"];
|
||||
assert_eq!(pick_gguf_file(&files, "iq2_xs"), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn nearest_quant_suggests_close_match() {
|
||||
// Today's HauhauCS scenario: operator wrote "q6k", actual
|
||||
// filename token is "q6_k_p". Should suggest the latter.
|
||||
let candidates = vec![
|
||||
"qwen-q4_k_p.gguf".to_string(),
|
||||
"qwen-q5_k_p.gguf".to_string(),
|
||||
"qwen-q6_k_p.gguf".to_string(),
|
||||
"qwen-q8_k_p.gguf".to_string(),
|
||||
];
|
||||
assert_eq!(nearest_quant("q6k", &candidates), Some("q6_k_p".into()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn nearest_quant_empty_input() {
|
||||
assert_eq!(nearest_quant("", &[]), None);
|
||||
assert_eq!(nearest_quant("q6k", &[]), None);
|
||||
assert_eq!(nearest_quant("", &["model-q4.gguf".into()]), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_quant_handles_typical_filenames() {
|
||||
assert_eq!(extract_quant_token("Qwen3.6-27B-Q6_K_P.gguf"), "q6_k_p");
|
||||
assert_eq!(extract_quant_token("model-IQ4_XS.gguf"), "iq4_xs");
|
||||
assert_eq!(extract_quant_token("simple.gguf"), "simple");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn levenshtein_basics() {
|
||||
assert_eq!(levenshtein("", ""), 0);
|
||||
assert_eq!(levenshtein("abc", ""), 3);
|
||||
assert_eq!(levenshtein("", "abc"), 3);
|
||||
assert_eq!(levenshtein("kitten", "sitting"), 3);
|
||||
assert_eq!(levenshtein("q6k", "q6_k_p"), 3);
|
||||
assert_eq!(levenshtein("q6k", "q4_k_p"), 4);
|
||||
}
|
||||
|
||||
// Higher-level preflight tests below exercise the full feasibility
|
||||
// table via a thin wrapper that bypasses the network — we hand it
|
||||
// a pre-built `SourceFormat` and request shape, then drive the
|
||||
// same decision logic. The end-to-end test with a mock HTTP
|
||||
// server lives in tests/preflight.rs (integration).
|
||||
|
||||
/// Mirror of the `match` in `preflight()` but takes a classified
|
||||
/// `SourceFormat` directly. Lets us unit-test the feasibility
|
||||
/// table without making the API trait object-safe / boxable.
|
||||
fn decide(
|
||||
spec: &ModelSpec,
|
||||
format: &SourceFormat,
|
||||
filenames: &[&str],
|
||||
) -> Result<PlacementPlan, PreflightError> {
|
||||
let tp_size = spec.tensor_parallel.unwrap_or(1);
|
||||
match (format, tp_size, spec.quant.as_deref()) {
|
||||
(SourceFormat::Empty, _, _) => Err(PreflightError::EmptyRepo {
|
||||
model_id: spec.model_id.clone(),
|
||||
}),
|
||||
(SourceFormat::Gguf { quants }, tp, _) if tp > 1 => {
|
||||
Err(PreflightError::TpRequiresSafetensors {
|
||||
model_id: spec.model_id.clone(),
|
||||
tp_size: tp,
|
||||
gguf_quants: quants.clone(),
|
||||
suggestion: format!(
|
||||
"Set tensor_parallel=1 and pick a quant from {quants:?}, \
|
||||
or use a dense safetensors release of this model."
|
||||
),
|
||||
})
|
||||
}
|
||||
(SourceFormat::Gguf { quants }, _, requested) => {
|
||||
let picked = pick_gguf_file(filenames, requested.unwrap_or(""));
|
||||
match picked {
|
||||
Some(fname) => Ok(PlacementPlan {
|
||||
model_id: spec.model_id.clone(),
|
||||
format: format.clone(),
|
||||
tp_size,
|
||||
picked_quant_file: Some(fname),
|
||||
}),
|
||||
None => Err(PreflightError::QuantNotFound {
|
||||
model_id: spec.model_id.clone(),
|
||||
requested: requested.unwrap_or("").to_string(),
|
||||
available: quants.clone(),
|
||||
nearest: nearest_quant(requested.unwrap_or(""), quants),
|
||||
}),
|
||||
}
|
||||
}
|
||||
(SourceFormat::DenseSafetensors { .. } | SourceFormat::Mixed { .. }, _, _) => {
|
||||
Ok(PlacementPlan {
|
||||
model_id: spec.model_id.clone(),
|
||||
format: format.clone(),
|
||||
tp_size,
|
||||
picked_quant_file: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn feasibility_gguf_tp_rejected() {
|
||||
let files = ["Qwen-Q6_K_P.gguf", "Qwen-Q4_K_P.gguf"];
|
||||
let fmt = classify(&files);
|
||||
let s = spec("HauhauCS/Qwen3.6", Some(2), Some("q6k"));
|
||||
match decide(&s, &fmt, &files).unwrap_err() {
|
||||
PreflightError::TpRequiresSafetensors {
|
||||
model_id,
|
||||
tp_size,
|
||||
gguf_quants,
|
||||
..
|
||||
} => {
|
||||
assert_eq!(model_id, "HauhauCS/Qwen3.6");
|
||||
assert_eq!(tp_size, 2);
|
||||
assert_eq!(gguf_quants.len(), 2);
|
||||
}
|
||||
other => panic!("expected TpRequiresSafetensors, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn feasibility_gguf_single_gpu_bad_quant() {
|
||||
let files = [
|
||||
"Qwen-Q4_K_P.gguf",
|
||||
"Qwen-Q5_K_P.gguf",
|
||||
"Qwen-Q6_K_P.gguf",
|
||||
"Qwen-Q8_K_P.gguf",
|
||||
];
|
||||
let fmt = classify(&files);
|
||||
let s = spec("HauhauCS/Qwen3.6", Some(1), Some("q6k"));
|
||||
match decide(&s, &fmt, &files).unwrap_err() {
|
||||
PreflightError::QuantNotFound {
|
||||
requested,
|
||||
nearest,
|
||||
available,
|
||||
..
|
||||
} => {
|
||||
assert_eq!(requested, "q6k");
|
||||
assert_eq!(nearest.as_deref(), Some("q6_k_p"));
|
||||
assert_eq!(available.len(), 4);
|
||||
}
|
||||
other => panic!("expected QuantNotFound, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn feasibility_gguf_single_gpu_good_quant() {
|
||||
let files = ["Qwen-Q4_K_M.gguf", "Qwen-Q6_K.gguf"];
|
||||
let fmt = classify(&files);
|
||||
let s = spec("Qwen/Q-GGUF", Some(1), Some("q6_k"));
|
||||
let plan = decide(&s, &fmt, &files).unwrap();
|
||||
assert_eq!(plan.picked_quant_file.as_deref(), Some("Qwen-Q6_K.gguf"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn feasibility_dense_tp_ok() {
|
||||
let files = [
|
||||
"config.json",
|
||||
"tokenizer.json",
|
||||
"model.safetensors.index.json",
|
||||
"model-00001-of-00002.safetensors",
|
||||
];
|
||||
let fmt = classify(&files);
|
||||
let s = spec("Qwen/Q3-30B", Some(2), Some("q5k"));
|
||||
let plan = decide(&s, &fmt, &files).unwrap();
|
||||
assert_eq!(plan.tp_size, 2);
|
||||
assert!(plan.picked_quant_file.is_none());
|
||||
assert!(matches!(
|
||||
plan.format,
|
||||
SourceFormat::DenseSafetensors { sharded: true }
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn feasibility_empty_rejected() {
|
||||
let files = ["README.md", "tokenizer.json"];
|
||||
let fmt = classify(&files);
|
||||
let s = spec("Empty/Repo", Some(1), None);
|
||||
match decide(&s, &fmt, &files).unwrap_err() {
|
||||
PreflightError::EmptyRepo { model_id } => assert_eq!(model_id, "Empty/Repo"),
|
||||
other => panic!("expected EmptyRepo, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn error_serialization_carries_kind_field() {
|
||||
let err = PreflightError::TpRequiresSafetensors {
|
||||
model_id: "x/y".into(),
|
||||
tp_size: 2,
|
||||
gguf_quants: vec!["q6_k_p".into()],
|
||||
suggestion: "...".into(),
|
||||
};
|
||||
let v: serde_json::Value = serde_json::to_value(&err).unwrap();
|
||||
assert_eq!(v["kind"], "tp_requires_safetensors");
|
||||
assert_eq!(v["model_id"], "x/y");
|
||||
assert_eq!(v["tp_size"], 2);
|
||||
}
|
||||
}
|
||||
119
crates/neuron/src/harness/tp/all_reduce.rs
Normal file
119
crates/neuron/src/harness/tp/all_reduce.rs
Normal file
@@ -0,0 +1,119 @@
|
||||
//! `AllReduce` as a candle `CustomOp1` — the bridge between candle's
|
||||
//! `Tensor` graph and `cudarc::nccl::Comm::all_reduce`.
|
||||
//!
|
||||
//! Ported from the canonical
|
||||
//! `candle-examples/examples/llama_multiprocess/model.rs` pattern.
|
||||
//! Row-parallel layers apply this op after their local matmul to sum
|
||||
//! partial outputs across NCCL ranks.
|
||||
//!
|
||||
//! Available only under `--features cuda`; on CPU builds this module
|
||||
//! is empty and row-parallel layers degenerate to local matmul only
|
||||
//! (useful for compile-checking the model code; correctness requires
|
||||
//! cuda).
|
||||
//!
|
||||
//! Thread-safety caveat: NCCL communicators are technically only
|
||||
//! safe to use from a single thread at a time
|
||||
//! (https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html).
|
||||
//! We hold the `AllReduce` behind an `Arc<Comm>` and only issue ops
|
||||
//! against it from the dedicated `spawn_blocking` thread the inference
|
||||
//! pipeline already uses for candle's forward passes.
|
||||
|
||||
#![cfg(feature = "cuda")]
|
||||
|
||||
use candle_core::backend::BackendStorage;
|
||||
use candle_core::{CpuStorage, CudaStorage, CustomOp1, DType, Layout, Result, Shape};
|
||||
use cudarc::nccl::{Comm, ReduceOp};
|
||||
use half::{bf16, f16};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Wraps an NCCL `Comm` so it can be plugged into a candle forward
|
||||
/// graph as a custom op. Each row-parallel layer holds one of these.
|
||||
pub struct AllReduce {
|
||||
comm: Arc<Comm>,
|
||||
}
|
||||
|
||||
// SAFETY: `Comm` contains a raw `ncclComm_t` pointer; NCCL's docs note
|
||||
// that issuing ops against one comm from multiple threads concurrently
|
||||
// is unsafe. We serialise via the single spawn_blocking thread that
|
||||
// drives the model's forward pass. The Send/Sync impl is necessary
|
||||
// because candle's CustomOp1 trait bounds require it; the correctness
|
||||
// invariant is enforced at the call site, not the type level.
|
||||
unsafe impl Send for AllReduce {}
|
||||
unsafe impl Sync for AllReduce {}
|
||||
|
||||
impl AllReduce {
|
||||
pub fn new(comm: Arc<Comm>) -> Self {
|
||||
Self { comm }
|
||||
}
|
||||
|
||||
pub fn comm(&self) -> &Arc<Comm> {
|
||||
&self.comm
|
||||
}
|
||||
}
|
||||
|
||||
impl CustomOp1 for AllReduce {
|
||||
fn name(&self) -> &'static str {
|
||||
"neuron.tp.all_reduce"
|
||||
}
|
||||
|
||||
fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
|
||||
candle_core::bail!("AllReduce custom-op invoked on CPU storage; TP requires CUDA")
|
||||
}
|
||||
|
||||
fn cuda_fwd(&self, s: &CudaStorage, l: &Layout) -> Result<(CudaStorage, Shape)> {
|
||||
// Reject non-contiguous inputs explicitly — copying them
|
||||
// server-side would mask shape bugs (a TP layer feeding a
|
||||
// strided activation into all_reduce is almost certainly a
|
||||
// model construction error).
|
||||
fn require_contiguous<T: cudarc::driver::DeviceRepr>(
|
||||
slice: &cudarc::driver::CudaSlice<T>,
|
||||
l: &Layout,
|
||||
) -> Result<()> {
|
||||
match l.contiguous_offsets() {
|
||||
Some((0, n)) if n == slice.len() => Ok(()),
|
||||
_ => candle_core::bail!(
|
||||
"AllReduce input is non-contiguous: layout={:?}, slice_len={}",
|
||||
l,
|
||||
slice.len()
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
let elem_count = l.shape().elem_count();
|
||||
let dev = s.device().clone();
|
||||
|
||||
let out = match s.dtype() {
|
||||
DType::BF16 => {
|
||||
let src = s.as_cuda_slice::<bf16>()?;
|
||||
require_contiguous(src, l)?;
|
||||
let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }?;
|
||||
self.comm
|
||||
.all_reduce(src, &mut dst, &ReduceOp::Sum)
|
||||
.map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce bf16: {e:?}")))?;
|
||||
CudaStorage::wrap_cuda_slice(dst, dev)
|
||||
}
|
||||
DType::F16 => {
|
||||
let src = s.as_cuda_slice::<f16>()?;
|
||||
require_contiguous(src, l)?;
|
||||
let mut dst = unsafe { dev.alloc::<f16>(elem_count) }?;
|
||||
self.comm
|
||||
.all_reduce(src, &mut dst, &ReduceOp::Sum)
|
||||
.map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce f16: {e:?}")))?;
|
||||
CudaStorage::wrap_cuda_slice(dst, dev)
|
||||
}
|
||||
DType::F32 => {
|
||||
let src = s.as_cuda_slice::<f32>()?;
|
||||
require_contiguous(src, l)?;
|
||||
let mut dst = unsafe { dev.alloc::<f32>(elem_count) }?;
|
||||
self.comm
|
||||
.all_reduce(src, &mut dst, &ReduceOp::Sum)
|
||||
.map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce f32: {e:?}")))?;
|
||||
CudaStorage::wrap_cuda_slice(dst, dev)
|
||||
}
|
||||
dtype => candle_core::bail!(
|
||||
"AllReduce: unsupported dtype {dtype:?}; TP path expects bf16/f16/f32"
|
||||
),
|
||||
};
|
||||
Ok((out, l.shape().clone()))
|
||||
}
|
||||
}
|
||||
213
crates/neuron/src/harness/tp/fused_load.rs
Normal file
213
crates/neuron/src/harness/tp/fused_load.rs
Normal file
@@ -0,0 +1,213 @@
|
||||
//! Direct safetensors readers for fused-region weight tensors.
|
||||
//!
|
||||
//! Qwen3-Next's `in_proj_qkv` and `conv1d` weights are *fused* —
|
||||
//! three regions stored sequentially along dim 0 (`[key_q, key_k,
|
||||
//! value]`). The per-rank shard for each region has unequal size
|
||||
//! (`key_dim/ws` vs `value_dim/ws`), so candle's `ShardedSafeTensors`
|
||||
//! built-in `Shard { dim, rank, world_size }` (uniform split) doesn't
|
||||
//! map to the right slices.
|
||||
//!
|
||||
//! The previous approach loaded the full fused tensor onto the device,
|
||||
//! `narrow`ed the three regions, and `Tensor::cat(...).contiguous()`'d
|
||||
//! the per-rank slice. That left ~100 MB of transient device memory
|
||||
//! per linear-attention layer — 48 layers × 100 MB = ~4.8 GB of
|
||||
//! allocator pressure during load, enough to trigger fragmentation
|
||||
//! OOM on tight-VRAM consumer GPUs.
|
||||
//!
|
||||
//! This module reads the three per-rank byte ranges *directly from
|
||||
//! the safetensors mmap* (host-side), concatenates them into a single
|
||||
//! contiguous byte buffer, and uploads as one device allocation. No
|
||||
//! full-tensor device materialisation.
|
||||
|
||||
use anyhow::{Context, Result, bail};
|
||||
use candle_core::safetensors::MmapedSafetensors;
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
|
||||
/// Read a 2D fused-QKV tensor `[conv_dim, hidden_size]` and return
|
||||
/// this rank's per-region slice as a `[per_rank_conv_dim, hidden_size]`
|
||||
/// device tensor.
|
||||
///
|
||||
/// `tensor_name` must be the fully-qualified safetensors key (e.g.
|
||||
/// `"model.language_model.layers.5.linear_attn.in_proj_qkv.weight"`).
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn load_fused_qkv_2d(
|
||||
mmap: &MmapedSafetensors,
|
||||
tensor_name: &str,
|
||||
hidden_size: usize,
|
||||
key_dim: usize,
|
||||
value_dim: usize,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
target_dtype: DType,
|
||||
device: &Device,
|
||||
) -> Result<Tensor> {
|
||||
let ws = world_size as usize;
|
||||
let r = rank as usize;
|
||||
if !key_dim.is_multiple_of(ws) || !value_dim.is_multiple_of(ws) {
|
||||
bail!(
|
||||
"fused qkv shard: key_dim ({key_dim}) and value_dim ({value_dim}) \
|
||||
must each be divisible by world_size ({ws})"
|
||||
);
|
||||
}
|
||||
let per_rank_key = key_dim / ws;
|
||||
let per_rank_value = value_dim / ws;
|
||||
let per_rank_conv_dim = per_rank_key * 2 + per_rank_value;
|
||||
|
||||
let view = mmap
|
||||
.get(tensor_name)
|
||||
.with_context(|| format!("mmap.get('{tensor_name}') for fused qkv 2D"))?;
|
||||
let view_dtype: DType = view
|
||||
.dtype()
|
||||
.try_into()
|
||||
.with_context(|| format!("safetensors dtype unsupported for '{tensor_name}'"))?;
|
||||
|
||||
let shape = view.shape();
|
||||
if shape.len() != 2 {
|
||||
bail!(
|
||||
"fused qkv tensor '{tensor_name}' has shape {shape:?}, expected 2D \
|
||||
[conv_dim, hidden_size]"
|
||||
);
|
||||
}
|
||||
let conv_dim = key_dim * 2 + value_dim;
|
||||
if shape[0] != conv_dim || shape[1] != hidden_size {
|
||||
bail!(
|
||||
"fused qkv tensor '{tensor_name}' shape {shape:?} \
|
||||
doesn't match expected [{conv_dim}, {hidden_size}]"
|
||||
);
|
||||
}
|
||||
|
||||
let q_bytes = slice_dim0_bytes(&view, r * per_rank_key, per_rank_key, tensor_name, "q")?;
|
||||
let k_bytes = slice_dim0_bytes(
|
||||
&view,
|
||||
key_dim + r * per_rank_key,
|
||||
per_rank_key,
|
||||
tensor_name,
|
||||
"k",
|
||||
)?;
|
||||
let v_bytes = slice_dim0_bytes(
|
||||
&view,
|
||||
2 * key_dim + r * per_rank_value,
|
||||
per_rank_value,
|
||||
tensor_name,
|
||||
"v",
|
||||
)?;
|
||||
|
||||
let mut bytes = Vec::with_capacity(q_bytes.len() + k_bytes.len() + v_bytes.len());
|
||||
bytes.extend_from_slice(&q_bytes);
|
||||
bytes.extend_from_slice(&k_bytes);
|
||||
bytes.extend_from_slice(&v_bytes);
|
||||
|
||||
let tensor = Tensor::from_raw_buffer(
|
||||
&bytes,
|
||||
view_dtype,
|
||||
&[per_rank_conv_dim, hidden_size],
|
||||
device,
|
||||
)
|
||||
.with_context(|| format!("Tensor::from_raw_buffer for per-rank fused qkv '{tensor_name}'"))?;
|
||||
tensor
|
||||
.to_dtype(target_dtype)
|
||||
.with_context(|| format!("cast '{tensor_name}' to {target_dtype:?}"))
|
||||
}
|
||||
|
||||
/// Read a 3D fused-QKV tensor `[conv_dim, 1, kernel_size]` (the
|
||||
/// depthwise conv1d weight) and return this rank's per-region slice
|
||||
/// as a `[per_rank_conv_dim, 1, kernel_size]` device tensor.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn load_fused_qkv_3d(
|
||||
mmap: &MmapedSafetensors,
|
||||
tensor_name: &str,
|
||||
mid: usize,
|
||||
kernel_size: usize,
|
||||
key_dim: usize,
|
||||
value_dim: usize,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
target_dtype: DType,
|
||||
device: &Device,
|
||||
) -> Result<Tensor> {
|
||||
let ws = world_size as usize;
|
||||
let r = rank as usize;
|
||||
if !key_dim.is_multiple_of(ws) || !value_dim.is_multiple_of(ws) {
|
||||
bail!(
|
||||
"fused conv shard: key_dim ({key_dim}) and value_dim ({value_dim}) \
|
||||
must each be divisible by world_size ({ws})"
|
||||
);
|
||||
}
|
||||
let per_rank_key = key_dim / ws;
|
||||
let per_rank_value = value_dim / ws;
|
||||
let per_rank_conv_dim = per_rank_key * 2 + per_rank_value;
|
||||
|
||||
let view = mmap
|
||||
.get(tensor_name)
|
||||
.with_context(|| format!("mmap.get('{tensor_name}') for fused qkv 3D"))?;
|
||||
let view_dtype: DType = view
|
||||
.dtype()
|
||||
.try_into()
|
||||
.with_context(|| format!("safetensors dtype unsupported for '{tensor_name}'"))?;
|
||||
|
||||
let shape = view.shape();
|
||||
if shape.len() != 3 {
|
||||
bail!(
|
||||
"fused conv tensor '{tensor_name}' has shape {shape:?}, expected 3D \
|
||||
[conv_dim, mid, kernel_size]"
|
||||
);
|
||||
}
|
||||
let conv_dim = key_dim * 2 + value_dim;
|
||||
if shape[0] != conv_dim || shape[1] != mid || shape[2] != kernel_size {
|
||||
bail!(
|
||||
"fused conv tensor '{tensor_name}' shape {shape:?} \
|
||||
doesn't match expected [{conv_dim}, {mid}, {kernel_size}]"
|
||||
);
|
||||
}
|
||||
|
||||
let q_bytes = slice_dim0_bytes(&view, r * per_rank_key, per_rank_key, tensor_name, "q")?;
|
||||
let k_bytes = slice_dim0_bytes(
|
||||
&view,
|
||||
key_dim + r * per_rank_key,
|
||||
per_rank_key,
|
||||
tensor_name,
|
||||
"k",
|
||||
)?;
|
||||
let v_bytes = slice_dim0_bytes(
|
||||
&view,
|
||||
2 * key_dim + r * per_rank_value,
|
||||
per_rank_value,
|
||||
tensor_name,
|
||||
"v",
|
||||
)?;
|
||||
|
||||
let mut bytes = Vec::with_capacity(q_bytes.len() + k_bytes.len() + v_bytes.len());
|
||||
bytes.extend_from_slice(&q_bytes);
|
||||
bytes.extend_from_slice(&k_bytes);
|
||||
bytes.extend_from_slice(&v_bytes);
|
||||
|
||||
let tensor = Tensor::from_raw_buffer(
|
||||
&bytes,
|
||||
view_dtype,
|
||||
&[per_rank_conv_dim, mid, kernel_size],
|
||||
device,
|
||||
)
|
||||
.with_context(|| format!("Tensor::from_raw_buffer for per-rank fused conv '{tensor_name}'"))?;
|
||||
tensor
|
||||
.to_dtype(target_dtype)
|
||||
.with_context(|| format!("cast '{tensor_name}' to {target_dtype:?}"))
|
||||
}
|
||||
|
||||
/// Read `len` consecutive rows along dim 0 starting at `start` from
|
||||
/// the safetensors view, returning the raw bytes. Wraps the same
|
||||
/// `view.slice(start..stop)` machinery that candle's
|
||||
/// `ShardedSafeTensors::get` uses internally.
|
||||
fn slice_dim0_bytes(
|
||||
view: &safetensors::tensor::TensorView<'_>,
|
||||
start: usize,
|
||||
len: usize,
|
||||
tensor_name: &str,
|
||||
region: &str,
|
||||
) -> Result<Vec<u8>> {
|
||||
use safetensors::slice::IndexOp;
|
||||
let stop = start + len;
|
||||
let iter = view.slice(start..stop).map_err(|e| {
|
||||
anyhow::anyhow!("slice '{tensor_name}' region {region} ({start}..{stop}): {e:?}")
|
||||
})?;
|
||||
Ok(iter.into_iter().flatten().copied().collect())
|
||||
}
|
||||
795
crates/neuron/src/harness/tp/mod.rs
Normal file
795
crates/neuron/src/harness/tp/mod.rs
Normal file
@@ -0,0 +1,795 @@
|
||||
//! Tensor-parallel inference plumbing.
|
||||
//!
|
||||
//! The leader process (the neuron daemon proper) drives one
|
||||
//! subprocess per non-zero NCCL rank — `tokio::process::Command` on
|
||||
//! `/proc/self/exe --worker --rank N --tp-size N --cuda-device N` —
|
||||
//! and talks to each over a newline-delimited JSON RPC channel on
|
||||
//! the worker's stdin/stdout (see `rpc.rs`).
|
||||
//!
|
||||
//! Sub-staging:
|
||||
//!
|
||||
//! - **7a-i (this commit):** process lifecycle. `WorkerPool::spawn`
|
||||
//! forks N workers; `ping` round-trips every worker to confirm
|
||||
//! they're alive; `shutdown` cleanly drains and reaps. `Init` /
|
||||
//! `NcclSanityCheck` are stubbed.
|
||||
//! - **7a-ii:** real NCCL `Comm` setup via `Init`, sanity check via
|
||||
//! `NcclSanityCheck`. CUDA-gated.
|
||||
//! - **7b:** TP-aware Qwen3 inference dispatched through the pool.
|
||||
//! - **7c:** crash detection, streaming SSE, graceful unload.
|
||||
|
||||
pub mod all_reduce;
|
||||
pub mod fused_load;
|
||||
pub mod nccl_state;
|
||||
pub mod rpc;
|
||||
pub mod tp_linear;
|
||||
pub mod tp_qwen3;
|
||||
pub mod tp_qwen3_5;
|
||||
pub mod worker;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::Stdio;
|
||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, Lines};
|
||||
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
|
||||
|
||||
use rpc::{WorkerRequest, WorkerResponse};
|
||||
|
||||
/// Leader-side handle for any TP-loaded model. The pool's
|
||||
/// `load_dense_shard` dispatches on `config.json#/model_type` to build
|
||||
/// the right variant; downstream callers (the harness's
|
||||
/// `chat_completion_tp` path, `generate_step`, `clear_kv_cache`,
|
||||
/// `unload_model`) all hold this enum and let the variant dispatch
|
||||
/// determine the concrete forward.
|
||||
///
|
||||
/// Variants gated on `cuda` because the underlying TP models hold
|
||||
/// `Arc<cudarc::nccl::Comm>` references — irrelevant on CPU builds.
|
||||
#[cfg(feature = "cuda")]
|
||||
pub enum TpLeaderModel {
|
||||
Qwen3(tp_qwen3::TpQwen3ForCausalLM),
|
||||
Qwen3_5(tp_qwen3_5::TpQwen3_5ForCausalLM),
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
impl TpLeaderModel {
|
||||
pub fn forward(
|
||||
&mut self,
|
||||
input: &candle_core::Tensor,
|
||||
offset: usize,
|
||||
) -> candle_core::Result<candle_core::Tensor> {
|
||||
match self {
|
||||
TpLeaderModel::Qwen3(m) => m.forward(input, offset),
|
||||
TpLeaderModel::Qwen3_5(m) => m.forward(input, offset),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn clear_kv_cache(&mut self) {
|
||||
match self {
|
||||
TpLeaderModel::Qwen3(m) => m.clear_kv_cache(),
|
||||
TpLeaderModel::Qwen3_5(m) => m.clear_kv_cache(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn device(&self) -> &candle_core::Device {
|
||||
match self {
|
||||
TpLeaderModel::Qwen3(m) => m.device(),
|
||||
TpLeaderModel::Qwen3_5(m) => m.device(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// One worker subprocess plus its bidirectional stdio handles.
|
||||
struct Worker {
|
||||
rank: u32,
|
||||
/// Captured so the leader can log "spawned rank N on device M" and
|
||||
/// future stages can re-issue Init after a CUDA reset. Unused in
|
||||
/// the Stage 7a-i RPC paths themselves.
|
||||
#[allow(dead_code)]
|
||||
cuda_device: u32,
|
||||
child: Child,
|
||||
stdin: ChildStdin,
|
||||
stdout: Lines<BufReader<ChildStdout>>,
|
||||
}
|
||||
|
||||
impl Worker {
|
||||
/// Send a request and wait for the response. Used for sequenced
|
||||
/// ops like `Ping` / `Shutdown` where the caller doesn't need to
|
||||
/// overlap the worker's execution with the leader's.
|
||||
async fn request(&mut self, req: &WorkerRequest) -> Result<WorkerResponse> {
|
||||
self.send_only(req).await?;
|
||||
self.recv_only().await
|
||||
}
|
||||
|
||||
/// Write a request without awaiting its response. Pair with
|
||||
/// `recv_only` from the caller when leader and worker need to do
|
||||
/// work concurrently — e.g. during `Init`, where the leader
|
||||
/// itself calls `Comm::from_rank` on rank 0 in parallel with the
|
||||
/// workers, then collects `InitOk` after NCCL completes.
|
||||
async fn send_only(&mut self, req: &WorkerRequest) -> Result<()> {
|
||||
let mut line = serde_json::to_string(req).context("serialise WorkerRequest")?;
|
||||
line.push('\n');
|
||||
self.stdin
|
||||
.write_all(line.as_bytes())
|
||||
.await
|
||||
.with_context(|| format!("write request to rank {}", self.rank))?;
|
||||
self.stdin
|
||||
.flush()
|
||||
.await
|
||||
.with_context(|| format!("flush stdin to rank {}", self.rank))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recv_only(&mut self) -> Result<WorkerResponse> {
|
||||
let reply = self
|
||||
.stdout
|
||||
.next_line()
|
||||
.await
|
||||
.with_context(|| format!("read reply from rank {}", self.rank))?
|
||||
.ok_or_else(|| anyhow::anyhow!("rank {} stdout closed before reply", self.rank))?;
|
||||
serde_json::from_str(&reply)
|
||||
.with_context(|| format!("parse reply from rank {}: {reply:?}", self.rank))
|
||||
}
|
||||
}
|
||||
|
||||
/// Drain one response from every worker, classifying each via the
|
||||
/// supplied checker. Always reads from every worker — even if some
|
||||
/// fail — so the next call's recv doesn't pick up stale responses
|
||||
/// from this one (pipe-poisoning was the cause of the
|
||||
/// "ClearKvCache: expected KvCacheCleared, got GenerateStepOk" class
|
||||
/// of bugs).
|
||||
///
|
||||
/// Returns a vector of `rank N: detail` strings for any worker that
|
||||
/// errored, expected-mismatched, or failed to respond. Caller decides
|
||||
/// how to combine these with the leader's outcome.
|
||||
async fn drain_workers(
|
||||
workers: &mut [Worker],
|
||||
mut check: impl FnMut(WorkerResponse) -> std::result::Result<(), String>,
|
||||
) -> Vec<String> {
|
||||
let mut errs = Vec::new();
|
||||
for w in workers {
|
||||
match w.recv_only().await {
|
||||
Ok(resp) => {
|
||||
if let Err(detail) = check(resp) {
|
||||
errs.push(format!("rank {} {detail}", w.rank));
|
||||
}
|
||||
}
|
||||
Err(e) => errs.push(format!("rank {} recv: {e:#}", w.rank)),
|
||||
}
|
||||
}
|
||||
errs
|
||||
}
|
||||
|
||||
/// Combine a leader's `Result<Result<T>>` (the typical
|
||||
/// `spawn_blocking → JoinHandle<Result<T>>` shape) with the worker
|
||||
/// drain results into a single `Result<T>`. Leader failures take
|
||||
/// precedence in the error message but worker errors get appended so
|
||||
/// the operator sees both halves.
|
||||
#[cfg(feature = "cuda")]
|
||||
fn combine_leader_workers<T>(
|
||||
leader: Result<Result<T>>,
|
||||
worker_errors: Vec<String>,
|
||||
op: &str,
|
||||
) -> Result<T> {
|
||||
match leader {
|
||||
Ok(Ok(value)) => {
|
||||
if worker_errors.is_empty() {
|
||||
Ok(value)
|
||||
} else {
|
||||
anyhow::bail!(
|
||||
"{op}: leader succeeded but workers failed: {}",
|
||||
worker_errors.join("; ")
|
||||
)
|
||||
}
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
if worker_errors.is_empty() {
|
||||
Err(e.context(format!("{op}: leader forward failed")))
|
||||
} else {
|
||||
Err(e.context(format!(
|
||||
"{op}: leader forward failed and workers also failed: {}",
|
||||
worker_errors.join("; ")
|
||||
)))
|
||||
}
|
||||
}
|
||||
Err(panic_err) => {
|
||||
if worker_errors.is_empty() {
|
||||
Err(panic_err)
|
||||
} else {
|
||||
Err(panic_err.context(format!(
|
||||
"{op}: leader task panicked and workers failed: {}",
|
||||
worker_errors.join("; ")
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A live pool of worker subprocesses. Owns the `Child` handles so
|
||||
/// dropping the pool kills the children; explicit `shutdown()` is
|
||||
/// the graceful path.
|
||||
pub struct WorkerPool {
|
||||
world_size: u32,
|
||||
workers: Vec<Worker>,
|
||||
/// Path to the neuron binary used to launch workers.
|
||||
#[allow(dead_code)]
|
||||
exe: PathBuf,
|
||||
/// The leader's per-device CUDA worker thread. Phase 3 moved the
|
||||
/// leader's `NcclState` (rank-0 NCCL Comm) into this thread, so
|
||||
/// every NCCL op (init, sanity, all_reduce inside forward) issues
|
||||
/// from one OS thread for the daemon's lifetime. The handle is
|
||||
/// also used by `load_dense_shard` to clone the leader's
|
||||
/// `Arc<Comm>` for the row-parallel layers' AllReduce ops; in
|
||||
/// Phase 4 the load itself moves onto the worker and that bridge
|
||||
/// goes away.
|
||||
pub(crate) leader_worker: std::sync::Arc<super::device_worker::DeviceWorkerHandle>,
|
||||
}
|
||||
|
||||
impl WorkerPool {
|
||||
/// Spawn `world_size - 1` worker subprocesses. Rank 0 is the
|
||||
/// leader (in-process) and is *not* spawned here — the leader
|
||||
/// holds rank 0's NCCL Comm and shard in its own address space.
|
||||
///
|
||||
/// `binary` is the path to the neuron executable to run for each
|
||||
/// worker (production passes `/proc/self/exe`; tests pass the
|
||||
/// sibling-binary path from `env!("CARGO_BIN_EXE_neuron")`).
|
||||
/// `cuda_devices` is one entry per rank including rank 0. Worker
|
||||
/// `i` (rank `i`) gets `cuda_devices[i]` as its `--cuda-device`.
|
||||
pub async fn spawn(
|
||||
binary: &Path,
|
||||
world_size: u32,
|
||||
cuda_devices: &[u32],
|
||||
leader_worker: std::sync::Arc<super::device_worker::DeviceWorkerHandle>,
|
||||
) -> Result<Self> {
|
||||
if world_size < 2 {
|
||||
anyhow::bail!(
|
||||
"WorkerPool::spawn called with world_size={world_size}; \
|
||||
use the single-process path for world_size < 2"
|
||||
);
|
||||
}
|
||||
if cuda_devices.len() as u32 != world_size {
|
||||
anyhow::bail!(
|
||||
"expected {world_size} cuda_devices entries, got {}",
|
||||
cuda_devices.len()
|
||||
);
|
||||
}
|
||||
let exe = binary.to_path_buf();
|
||||
|
||||
let mut workers = Vec::with_capacity(world_size as usize - 1);
|
||||
// Rank 0 stays in-process. Spawn ranks 1..world_size.
|
||||
for rank in 1..world_size {
|
||||
let cuda_device = cuda_devices[rank as usize];
|
||||
let mut cmd = Command::new(&exe);
|
||||
cmd.arg("--worker")
|
||||
.arg("--rank")
|
||||
.arg(rank.to_string())
|
||||
.arg("--tp-size")
|
||||
.arg(world_size.to_string())
|
||||
.arg("--cuda-device")
|
||||
.arg(cuda_device.to_string())
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
// Inherit stderr so worker tracing surfaces alongside
|
||||
// the leader's journalctl stream.
|
||||
.stderr(Stdio::inherit())
|
||||
.kill_on_drop(true);
|
||||
|
||||
let mut child = cmd
|
||||
.spawn()
|
||||
.with_context(|| format!("spawn worker rank {rank}"))?;
|
||||
let stdin = child
|
||||
.stdin
|
||||
.take()
|
||||
.ok_or_else(|| anyhow::anyhow!("rank {rank}: no stdin handle"))?;
|
||||
let stdout = child
|
||||
.stdout
|
||||
.take()
|
||||
.ok_or_else(|| anyhow::anyhow!("rank {rank}: no stdout handle"))?;
|
||||
let stdout = BufReader::new(stdout).lines();
|
||||
|
||||
workers.push(Worker {
|
||||
rank,
|
||||
cuda_device,
|
||||
child,
|
||||
stdin,
|
||||
stdout,
|
||||
});
|
||||
tracing::info!(rank, cuda_device, "spawned tp worker");
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
world_size,
|
||||
workers,
|
||||
exe,
|
||||
leader_worker,
|
||||
})
|
||||
}
|
||||
|
||||
/// Establish the NCCL communicator across the leader (rank 0) and
|
||||
/// every worker subprocess. Rendezvous is via a freshly-generated
|
||||
/// `Id` broadcast over the RPC stream; the actual handshake blocks
|
||||
/// inside `Comm::from_rank` until all `world_size` ranks check in.
|
||||
///
|
||||
/// `leader_cuda_device` is the CUDA device the leader binds rank 0
|
||||
/// to — typically the first entry of the `cuda_devices` slice
|
||||
/// originally passed to `spawn()`.
|
||||
///
|
||||
/// On the non-cuda build this immediately fails because the leader
|
||||
/// can't generate an `Id` without libnccl. The same call works in
|
||||
/// the worker path (returning a no-cuda error response) so the
|
||||
/// failure surface is uniform.
|
||||
pub async fn init_nccl(&mut self, leader_cuda_device: u32) -> Result<()> {
|
||||
let comm_id = nccl_state::generate_comm_id_hex()
|
||||
.map_err(|m| anyhow::anyhow!("generate NCCL id: {m}"))?;
|
||||
|
||||
// 1. Write Init to every worker's stdin without awaiting the
|
||||
// response. Workers will parse and call Comm::from_rank
|
||||
// concurrently with the leader below.
|
||||
for w in &mut self.workers {
|
||||
let req = WorkerRequest::Init {
|
||||
comm_id: comm_id.clone(),
|
||||
};
|
||||
w.send_only(&req).await?;
|
||||
}
|
||||
|
||||
// 2. Leader rank 0 calls Comm::from_rank on its own device.
|
||||
// Phase 3 moved this from spawn_blocking onto the leader's
|
||||
// device worker thread (`Job::NcclInit`); the underlying
|
||||
// `Comm` now lives on the same OS thread for its entire
|
||||
// lifetime, including every later `Comm::all_reduce` issued
|
||||
// by the row-parallel layers during forward.
|
||||
//
|
||||
// NCCL's init blocks until every rank has called in — the
|
||||
// subprocess workers above and the leader's device worker
|
||||
// here. The Job's reply unblocks when the leader's
|
||||
// Comm::from_rank returns.
|
||||
let leader_cfg = worker::WorkerConfig {
|
||||
rank: 0,
|
||||
world_size: self.world_size,
|
||||
cuda_device: leader_cuda_device,
|
||||
};
|
||||
let leader_resp = self
|
||||
.leader_worker
|
||||
.nccl_init(leader_cfg, comm_id.clone())
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("leader NCCL init via device worker: {e}"))?;
|
||||
match leader_resp {
|
||||
rpc::WorkerResponse::InitOk => {}
|
||||
rpc::WorkerResponse::Error { kind, message } => {
|
||||
anyhow::bail!("leader rank 0 init failed [{kind}]: {message}");
|
||||
}
|
||||
other => anyhow::bail!("leader rank 0 init: unexpected {other:?}"),
|
||||
}
|
||||
|
||||
// 3. Read InitOk from each worker. By now every worker has
|
||||
// completed its Comm::from_rank call (NCCL released them
|
||||
// when the leader joined the handshake) and is writing its
|
||||
// response.
|
||||
for w in &mut self.workers {
|
||||
let resp = w.recv_only().await?;
|
||||
match &resp {
|
||||
rpc::WorkerResponse::InitOk => {}
|
||||
rpc::WorkerResponse::Error { kind, message } => {
|
||||
anyhow::bail!("worker rank {} init failed [{kind}]: {message}", w.rank);
|
||||
}
|
||||
other => anyhow::bail!(
|
||||
"worker rank {} init: expected InitOk, got {other:?}",
|
||||
w.rank
|
||||
),
|
||||
}
|
||||
}
|
||||
tracing::info!(
|
||||
world_size = self.world_size,
|
||||
"NCCL communicator established across all ranks"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate the NCCL communicator: every rank `all_reduce`s a
|
||||
/// sentinel `1u32` with `ReduceOp::Sum`; the expected total is
|
||||
/// `world_size`. Confirms the handshake is live, not just
|
||||
/// configured.
|
||||
///
|
||||
/// Must be called after `init_nccl()`; before that the leader has
|
||||
/// no Comm and the workers reply with `nccl_not_initialised`.
|
||||
pub async fn nccl_sanity_check(&mut self) -> Result<()> {
|
||||
// 1. Trigger the all_reduce on every worker (write-only).
|
||||
for w in &mut self.workers {
|
||||
w.send_only(&WorkerRequest::NcclSanityCheck).await?;
|
||||
}
|
||||
|
||||
// 2. Leader's own all_reduce, on its device worker thread.
|
||||
// NCCL operations block until every rank participates;
|
||||
// Job::NcclSanity returns once the leader's side completes
|
||||
// (which happens when every subprocess worker reaches its
|
||||
// all_reduce call too).
|
||||
let leader_resp = self
|
||||
.leader_worker
|
||||
.nccl_sanity()
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("leader NCCL sanity via device worker: {e}"))?;
|
||||
|
||||
let expected = self.world_size;
|
||||
let leader_sum = match leader_resp {
|
||||
rpc::WorkerResponse::NcclSanityResult { observed_sum } => observed_sum,
|
||||
rpc::WorkerResponse::Error { kind, message } => {
|
||||
anyhow::bail!("leader rank 0 sanity failed [{kind}]: {message}");
|
||||
}
|
||||
other => anyhow::bail!("leader rank 0 sanity: unexpected {other:?}"),
|
||||
};
|
||||
if leader_sum != expected {
|
||||
anyhow::bail!("leader observed_sum={leader_sum}, expected {expected}");
|
||||
}
|
||||
|
||||
// 3. Read sanity result from each worker. All must match
|
||||
// world_size — anything else means the collective didn't
|
||||
// complete consistently across ranks.
|
||||
for w in &mut self.workers {
|
||||
let resp = w.recv_only().await?;
|
||||
match resp {
|
||||
rpc::WorkerResponse::NcclSanityResult { observed_sum }
|
||||
if observed_sum == expected => {}
|
||||
rpc::WorkerResponse::NcclSanityResult { observed_sum } => {
|
||||
anyhow::bail!(
|
||||
"worker rank {} observed_sum={observed_sum}, expected {expected}",
|
||||
w.rank
|
||||
);
|
||||
}
|
||||
rpc::WorkerResponse::Error { kind, message } => {
|
||||
anyhow::bail!("worker rank {} sanity failed [{kind}]: {message}", w.rank);
|
||||
}
|
||||
other => anyhow::bail!("worker rank {} sanity: unexpected {other:?}", w.rank),
|
||||
}
|
||||
}
|
||||
tracing::info!(
|
||||
world_size = expected,
|
||||
"NCCL sanity check OK across all ranks"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Ping every worker and return their Pong payloads in rank order.
|
||||
/// Useful right after `spawn` to confirm the lifecycle plumbing is
|
||||
/// intact before kicking off any heavier work.
|
||||
pub async fn ping_all(&mut self) -> Result<Vec<WorkerResponse>> {
|
||||
let mut out = Vec::with_capacity(self.workers.len());
|
||||
for w in &mut self.workers {
|
||||
let resp = w.request(&WorkerRequest::Ping).await?;
|
||||
match &resp {
|
||||
WorkerResponse::Pong { rank, .. } if *rank == w.rank => {}
|
||||
WorkerResponse::Pong { rank, .. } => {
|
||||
anyhow::bail!("rank mismatch: expected {}, got {rank}", w.rank);
|
||||
}
|
||||
other => anyhow::bail!("expected Pong from rank {}, got {other:?}", w.rank),
|
||||
}
|
||||
out.push(resp);
|
||||
}
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
/// Load this rank's shard of a dense Qwen3 model on every rank.
|
||||
///
|
||||
/// The leader builds rank 0's `TpQwen3ForCausalLM` directly into
|
||||
/// the returned `Arc<Mutex<_>>` — workers build their rank-local
|
||||
/// shards in their own address spaces and confirm via
|
||||
/// `LoadDenseShardOk`. All ranks see the same `safetensors_paths`;
|
||||
/// `ShardedVarBuilder` slices each tensor by rank at materialisation
|
||||
/// time, so the per-rank VRAM footprint is roughly `1/world_size`
|
||||
/// of the full model (plus the replicated embedding/norm/lm_head).
|
||||
///
|
||||
/// `leader_device` is the candle `Device` the leader's shard lives
|
||||
/// on — typically `Device::new_cuda(leader_cuda_device)` matching
|
||||
/// the same index passed to `init_nccl`. `dtype` is the on-device
|
||||
/// element type; bf16 is the canonical Qwen3 distribution dtype.
|
||||
///
|
||||
/// `init_nccl` must have completed first. Bails if the leader's
|
||||
/// NCCL comm isn't set up yet.
|
||||
#[cfg(feature = "cuda")]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn load_dense_shard(
|
||||
&mut self,
|
||||
model_id: &str,
|
||||
config_json: &str,
|
||||
safetensors_paths: &[std::path::PathBuf],
|
||||
_leader_device: &candle_core::Device,
|
||||
dtype: candle_core::DType,
|
||||
quant: Option<String>,
|
||||
) -> Result<super::device_worker::TpHandle> {
|
||||
let world_size = self.world_size;
|
||||
let safetensors_str: Vec<String> = safetensors_paths
|
||||
.iter()
|
||||
.map(|p| p.to_string_lossy().into_owned())
|
||||
.collect();
|
||||
|
||||
// 1. Fan out the LoadDenseShard request to every subprocess
|
||||
// worker without awaiting their replies — they'll build
|
||||
// their shards in parallel with the leader below.
|
||||
for w in &mut self.workers {
|
||||
w.send_only(&WorkerRequest::LoadDenseShard {
|
||||
model_id: model_id.to_string(),
|
||||
config_json: config_json.to_string(),
|
||||
safetensors_paths: safetensors_str.clone(),
|
||||
quant: quant.clone(),
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
|
||||
// 2. Build rank 0's shard on the leader's device worker
|
||||
// thread. Phase 4 moved the load itself onto the worker —
|
||||
// the dispatch handler reads `state.nccl.comm()` directly
|
||||
// so the leader's `Arc<Comm>` clones embedded in the
|
||||
// row-parallel layers are constructed and used on the same
|
||||
// OS thread for the model's entire lifetime. No
|
||||
// spawn_blocking, no SendComm bridge.
|
||||
let handle = self
|
||||
.leader_worker
|
||||
.tp_load_shard(
|
||||
model_id.to_string(),
|
||||
config_json.to_string(),
|
||||
safetensors_paths.to_vec(),
|
||||
dtype,
|
||||
quant.clone(),
|
||||
world_size,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("leader TP shard load via device worker: {e}"))?;
|
||||
|
||||
// 3. Collect worker confirmations. Anything other than
|
||||
// LoadDenseShardOk aborts the whole load — the leader's
|
||||
// already-inserted shard would leak in the worker slab
|
||||
// until the daemon restarts; an explicit DropTp would be
|
||||
// cleaner but the failure here is rare and the operator's
|
||||
// next step is to restart anyway.
|
||||
for w in &mut self.workers {
|
||||
let resp = w.recv_only().await?;
|
||||
match resp {
|
||||
WorkerResponse::LoadDenseShardOk => {}
|
||||
WorkerResponse::Error { kind, message } => {
|
||||
anyhow::bail!("worker rank {} LoadDenseShard [{kind}]: {message}", w.rank)
|
||||
}
|
||||
other => anyhow::bail!(
|
||||
"worker rank {} LoadDenseShard: expected LoadDenseShardOk, got {other:?}",
|
||||
w.rank
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(handle)
|
||||
}
|
||||
|
||||
/// Run one forward step across every rank. The leader's forward
|
||||
/// runs on the device worker thread via `Job::TpForwardLogits` and
|
||||
/// returns CPU-side `[vocab]` logits as `Vec<f32>`; the async
|
||||
/// caller wraps them in a CPU tensor for `apply_repeat_penalty` +
|
||||
/// sampling without holding a device-resident tensor on a tokio
|
||||
/// thread.
|
||||
///
|
||||
/// Subprocess workers run their own forwards in parallel (the
|
||||
/// AllReduce CustomOps inside row-parallel layers are what let
|
||||
/// the leader's collective complete) and reply with
|
||||
/// `GenerateStepOk` over the RPC stream — they do not ship logits.
|
||||
///
|
||||
/// `tokens` is the input for this step (prompt for prefill, the
|
||||
/// previously-sampled token for decode). `offset` is the KV-cache
|
||||
/// position before this step.
|
||||
#[cfg(feature = "cuda")]
|
||||
pub async fn generate_step(
|
||||
&mut self,
|
||||
model_id: &str,
|
||||
leader_handle: super::device_worker::TpHandle,
|
||||
tokens: Vec<u32>,
|
||||
offset: usize,
|
||||
) -> Result<Vec<f32>> {
|
||||
let step_start = std::time::Instant::now();
|
||||
let tokens_len = tokens.len();
|
||||
tracing::debug!(
|
||||
model = %model_id,
|
||||
tokens = tokens_len,
|
||||
offset,
|
||||
"WorkerPool::generate_step: fan-out"
|
||||
);
|
||||
// 1. Fan-out to subprocess workers.
|
||||
for w in &mut self.workers {
|
||||
w.send_only(&WorkerRequest::GenerateStep {
|
||||
model_id: model_id.to_string(),
|
||||
tokens: tokens.clone(),
|
||||
offset,
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
|
||||
// 2. Leader's forward on its device worker thread. The
|
||||
// AllReduce CustomOps inside the row-parallel layers block
|
||||
// until every subprocess worker's forward issues the
|
||||
// matching collective. Returning CPU-side `Vec<f32>` keeps
|
||||
// the device tensor from escaping the worker thread —
|
||||
// that's the invariant the whole refactor exists to
|
||||
// preserve.
|
||||
let leader_start = std::time::Instant::now();
|
||||
let leader_result = self
|
||||
.leader_worker
|
||||
.tp_forward_logits(leader_handle, tokens, offset)
|
||||
.await;
|
||||
let leader_ok = leader_result.is_ok();
|
||||
let leader_ms = leader_start.elapsed().as_millis();
|
||||
// Surface the leader's own error at WARN before draining
|
||||
// workers so the operator can correlate it with whatever the
|
||||
// subprocess workers logged. Previously this was silently
|
||||
// coerced to a bool.
|
||||
if !leader_ok {
|
||||
let detail = leader_result
|
||||
.as_ref()
|
||||
.err()
|
||||
.map(|e| format!("{e:#}"))
|
||||
.unwrap_or_default();
|
||||
tracing::warn!(
|
||||
model = %model_id,
|
||||
tokens = tokens_len,
|
||||
offset,
|
||||
leader_ms,
|
||||
error = %detail,
|
||||
"WorkerPool::generate_step: leader forward failed"
|
||||
);
|
||||
}
|
||||
tracing::debug!(
|
||||
model = %model_id,
|
||||
tokens = tokens_len,
|
||||
leader_ms,
|
||||
leader_ok,
|
||||
"WorkerPool::generate_step: leader forward returned"
|
||||
);
|
||||
|
||||
// 3. ALWAYS drain worker responses, regardless of whether the
|
||||
// leader succeeded. Skipping this on the leader's error
|
||||
// path leaves stale GenerateStepOk replies in the worker
|
||||
// pipes that poison the NEXT request's recv (was seeing
|
||||
// "ClearKvCache: expected KvCacheCleared, got
|
||||
// GenerateStepOk" the call after any forward-time failure).
|
||||
let drain_start = std::time::Instant::now();
|
||||
let worker_errors = drain_workers(&mut self.workers, |r| match r {
|
||||
WorkerResponse::GenerateStepOk => Ok(()),
|
||||
WorkerResponse::Error { kind, message } => Err(format!("[{kind}]: {message}")),
|
||||
other => Err(format!("expected GenerateStepOk, got {other:?}")),
|
||||
})
|
||||
.await;
|
||||
tracing::debug!(
|
||||
model = %model_id,
|
||||
drain_ms = drain_start.elapsed().as_millis(),
|
||||
errors = worker_errors.len(),
|
||||
total_ms = step_start.elapsed().as_millis(),
|
||||
"WorkerPool::generate_step: workers drained"
|
||||
);
|
||||
|
||||
// Combine the leader's Result + the workers' string-error
|
||||
// list. Phase 3 inlines this because the upstream
|
||||
// `combine_leader_workers` expects the spawn_blocking-shaped
|
||||
// `Result<Result<T>>`; the new device-worker path produces a
|
||||
// single `Result<T, WorkerError>` instead.
|
||||
match leader_result {
|
||||
Ok(values) => {
|
||||
if worker_errors.is_empty() {
|
||||
Ok(values)
|
||||
} else {
|
||||
anyhow::bail!(
|
||||
"GenerateStep: leader succeeded but workers failed: {}",
|
||||
worker_errors.join("; ")
|
||||
)
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
if worker_errors.is_empty() {
|
||||
Err(anyhow::Error::new(e).context("GenerateStep: leader forward failed"))
|
||||
} else {
|
||||
Err(anyhow::Error::new(e).context(format!(
|
||||
"GenerateStep: leader forward failed and workers also failed: {}",
|
||||
worker_errors.join("; ")
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Reset the KV cache for `model_id` on every rank. Called at the
|
||||
/// start of every inference so a fresh request doesn't attend over
|
||||
/// the previous one's tokens.
|
||||
pub async fn clear_kv_cache(
|
||||
&mut self,
|
||||
model_id: &str,
|
||||
#[cfg(feature = "cuda")] leader_handle: super::device_worker::TpHandle,
|
||||
) -> Result<()> {
|
||||
let start = std::time::Instant::now();
|
||||
tracing::debug!(model = %model_id, "WorkerPool::clear_kv_cache: fan-out");
|
||||
for w in &mut self.workers {
|
||||
w.send_only(&WorkerRequest::ClearKvCache {
|
||||
model_id: model_id.to_string(),
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
// Leader-side clear on the device worker thread —
|
||||
// `TpLeaderModel::clear_kv_cache` is infallible but still
|
||||
// routes through Job::TpClearKv so the cache reset runs
|
||||
// on the same thread that owns the model's CUDA tensors.
|
||||
if let Err(e) = self.leader_worker.tp_clear_kv(leader_handle).await {
|
||||
anyhow::bail!("leader TP clear_kv_cache via device worker: {e}");
|
||||
}
|
||||
}
|
||||
// Drain workers — same rationale as `generate_step`. The
|
||||
// leader's clear_kv_cache is now async-via-channel but still
|
||||
// returns before the drain so the workers' KvCacheCleared
|
||||
// replies are processed in order.
|
||||
let worker_errors = drain_workers(&mut self.workers, |r| match r {
|
||||
WorkerResponse::KvCacheCleared => Ok(()),
|
||||
WorkerResponse::Error { kind, message } => Err(format!("[{kind}]: {message}")),
|
||||
other => Err(format!("expected KvCacheCleared, got {other:?}")),
|
||||
})
|
||||
.await;
|
||||
tracing::debug!(
|
||||
model = %model_id,
|
||||
elapsed_ms = start.elapsed().as_millis(),
|
||||
errors = worker_errors.len(),
|
||||
"WorkerPool::clear_kv_cache: workers drained"
|
||||
);
|
||||
if !worker_errors.is_empty() {
|
||||
anyhow::bail!("ClearKvCache: {}", worker_errors.join("; "));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Drop this model's shards on every rank. The leader's shard is
|
||||
/// expected to have been dropped by the caller (its `Arc` was held
|
||||
/// in the TpLoadedModel and goes away when that's removed).
|
||||
pub async fn unload_model(&mut self, model_id: &str) -> Result<()> {
|
||||
for w in &mut self.workers {
|
||||
w.send_only(&WorkerRequest::UnloadModel {
|
||||
model_id: model_id.to_string(),
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
for w in &mut self.workers {
|
||||
let resp = w.recv_only().await?;
|
||||
match resp {
|
||||
WorkerResponse::Unloaded => {}
|
||||
WorkerResponse::Error { kind, message } => {
|
||||
anyhow::bail!("worker rank {} UnloadModel [{kind}]: {message}", w.rank)
|
||||
}
|
||||
other => anyhow::bail!(
|
||||
"worker rank {} UnloadModel: expected Unloaded, got {other:?}",
|
||||
w.rank
|
||||
),
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Send `Shutdown` to every worker, await each `Bye`, and reap the
|
||||
/// children. Best-effort — individual worker failures are logged
|
||||
/// but don't abort the rest of the sweep.
|
||||
pub async fn shutdown(mut self) -> Result<()> {
|
||||
for w in &mut self.workers {
|
||||
match w.request(&WorkerRequest::Shutdown).await {
|
||||
Ok(WorkerResponse::Bye) => {}
|
||||
Ok(other) => tracing::warn!(
|
||||
rank = w.rank,
|
||||
response = ?other,
|
||||
"expected Bye on shutdown"
|
||||
),
|
||||
Err(e) => tracing::warn!(rank = w.rank, error = %e, "shutdown request failed"),
|
||||
}
|
||||
}
|
||||
for w in &mut self.workers {
|
||||
match w.child.wait().await {
|
||||
Ok(status) => tracing::info!(rank = w.rank, %status, "worker exited"),
|
||||
Err(e) => tracing::warn!(rank = w.rank, error = %e, "wait on worker failed"),
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn world_size(&self) -> u32 {
|
||||
self.world_size
|
||||
}
|
||||
|
||||
pub fn binary_path(&self) -> &PathBuf {
|
||||
&self.exe
|
||||
}
|
||||
}
|
||||
293
crates/neuron/src/harness/tp/nccl_state.rs
Normal file
293
crates/neuron/src/harness/tp/nccl_state.rs
Normal file
@@ -0,0 +1,293 @@
|
||||
//! NCCL state held by both the worker process and the leader's pool.
|
||||
//!
|
||||
//! Split into its own module so the worker (`tp/worker.rs`) and the
|
||||
//! leader (`tp/mod.rs`) share the same hex-encoding/decoding code and
|
||||
//! the same shape of `Option<Comm>` state machine.
|
||||
//!
|
||||
//! When the `cuda` feature is off, `NcclState` is a zero-sized
|
||||
//! placeholder that returns `Error{kind="cuda_feature_not_enabled"}`
|
||||
//! from every operation. When it's on, the same struct holds the
|
||||
//! actual `cudarc::nccl::Comm`.
|
||||
|
||||
use super::rpc::WorkerResponse;
|
||||
use super::worker::WorkerConfig;
|
||||
|
||||
/// Encode bytes as lowercase hex. Used for ferrying NCCL `Id::internal()`
|
||||
/// across the leader→worker RPC boundary inside a JSON string.
|
||||
pub fn encode_hex(bytes: &[u8]) -> String {
|
||||
let mut out = String::with_capacity(bytes.len() * 2);
|
||||
for b in bytes {
|
||||
use std::fmt::Write;
|
||||
let _ = write!(out, "{b:02x}");
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Decode lowercase-or-uppercase hex into bytes. Errors on odd length
|
||||
/// or non-hex characters; the caller bubbles those up via the RPC's
|
||||
/// `Error{kind="bad_request"}` variant.
|
||||
pub fn decode_hex(s: &str) -> Result<Vec<u8>, String> {
|
||||
if !s.len().is_multiple_of(2) {
|
||||
return Err(format!("hex string has odd length {}", s.len()));
|
||||
}
|
||||
(0..s.len())
|
||||
.step_by(2)
|
||||
.map(|i| {
|
||||
u8::from_str_radix(&s[i..i + 2], 16).map_err(|e| format!("bad hex byte at {i}: {e}"))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
pub struct NcclState;
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
impl Default for NcclState {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
impl NcclState {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
|
||||
pub fn init(&mut self, _cfg: WorkerConfig, _comm_id_hex: &str) -> WorkerResponse {
|
||||
WorkerResponse::Error {
|
||||
kind: "cuda_feature_not_enabled".into(),
|
||||
message: "this neuron binary was built without --features cuda; \
|
||||
NCCL Init requires CUDA"
|
||||
.into(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn sanity_check(&mut self) -> WorkerResponse {
|
||||
WorkerResponse::Error {
|
||||
kind: "cuda_feature_not_enabled".into(),
|
||||
message: "NCCL sanity check requires --features cuda".into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
mod cuda_impl {
|
||||
use super::*;
|
||||
use cudarc::driver::CudaContext;
|
||||
use cudarc::nccl::{Comm, Id, ReduceOp};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Number of bytes in NCCL's unique-id type; matches `Id::internal()`'s
|
||||
/// `[c_char; 128]`. Wire-encoded as 256 lowercase hex chars.
|
||||
const NCCL_ID_BYTES: usize = 128;
|
||||
|
||||
pub struct NcclState {
|
||||
/// Wrapped in `Arc` so we can hand a clone to `TpQwen3ForCausalLM`
|
||||
/// at load time (every row-parallel layer needs a reference to
|
||||
/// run its trailing `AllReduce`). The `Arc` is the single source
|
||||
/// of truth for the comm's lifetime — when the pool drops and
|
||||
/// every layer that captured a clone drops, NCCL releases the
|
||||
/// underlying `ncclComm_t`.
|
||||
comm: Option<Arc<Comm>>,
|
||||
/// Held alongside the Comm so the device isn't dropped
|
||||
/// underneath the NCCL handle.
|
||||
#[allow(dead_code)]
|
||||
ctx: Option<Arc<CudaContext>>,
|
||||
}
|
||||
|
||||
impl Default for NcclState {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl NcclState {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
comm: None,
|
||||
ctx: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Clone the comm out as an `Arc` so callers (the leader-side
|
||||
/// `TpQwen3ForCausalLM::load`, or the worker's own model load)
|
||||
/// can hold a reference for the lifetime of the model. Returns
|
||||
/// `None` before `init` has run.
|
||||
pub fn comm(&self) -> Option<Arc<Comm>> {
|
||||
self.comm.clone()
|
||||
}
|
||||
}
|
||||
|
||||
/// `Arc<Comm>` doesn't impl `Send` because `Comm` wraps a raw
|
||||
/// `ncclComm_t` pointer. The NCCL contract is "operations against a
|
||||
/// given comm must be serialised", not "the handle must stay on the
|
||||
/// thread that created it" — so it's safe to move an `Arc<Comm>`
|
||||
/// across threads as long as no concurrent ops are issued. The
|
||||
/// pool's outer Mutex serialises us into `spawn_blocking`, so this
|
||||
/// wrapper at the move boundary is the only thing missing.
|
||||
///
|
||||
/// `Sync` is also marked safe because the `Arc<Comm>` clones held
|
||||
/// by the row-parallel layers are only used from the
|
||||
/// `spawn_blocking` thread driving the forward pass; concurrent
|
||||
/// access from another thread would still be a bug.
|
||||
pub struct SendComm(pub Arc<Comm>);
|
||||
|
||||
// SAFETY: see the doc-comment above; the invariant is enforced at
|
||||
// the call site (pool Mutex + single spawn_blocking thread), not at
|
||||
// the type level.
|
||||
unsafe impl Send for SendComm {}
|
||||
unsafe impl Sync for SendComm {}
|
||||
|
||||
impl SendComm {
|
||||
pub fn into_inner(self) -> Arc<Comm> {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
// SAFETY: `cudarc::nccl::Comm` contains a raw `ncclComm_t` pointer
|
||||
// (libnccl-allocated state). NCCL requires that operations against
|
||||
// one Comm be issued one at a time; we serialise access by storing
|
||||
// NcclState behind a Mutex in `WorkerPool`. The Comm itself is
|
||||
// move-safe — NCCL doesn't track the calling OS thread, only the
|
||||
// stream the operations are dispatched against.
|
||||
unsafe impl Send for NcclState {}
|
||||
unsafe impl Sync for NcclState {}
|
||||
|
||||
/// Generate a fresh NCCL `Id` and return it hex-encoded. Used by
|
||||
/// the leader to mint the shared communicator id which is then
|
||||
/// broadcast to every worker via the RPC `Init` message.
|
||||
pub fn generate_comm_id_hex() -> Result<String, String> {
|
||||
// NcclError lacks a Display impl in cudarc 0.19.x — surface
|
||||
// via Debug throughout this module.
|
||||
let id = Id::new().map_err(|e| format!("Id::new(): {e:?}"))?;
|
||||
let bytes_u8: [u8; NCCL_ID_BYTES] = std::array::from_fn(|i| id.internal()[i] as u8);
|
||||
Ok(encode_hex(&bytes_u8))
|
||||
}
|
||||
|
||||
impl NcclState {
|
||||
pub fn init(&mut self, cfg: WorkerConfig, comm_id_hex: &str) -> WorkerResponse {
|
||||
match try_init(self, cfg, comm_id_hex) {
|
||||
Ok(()) => WorkerResponse::InitOk,
|
||||
Err(msg) => WorkerResponse::Error {
|
||||
kind: "nccl_init_failed".into(),
|
||||
message: msg,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn sanity_check(&mut self) -> WorkerResponse {
|
||||
let Some(comm) = self.comm.as_ref() else {
|
||||
return WorkerResponse::Error {
|
||||
kind: "nccl_not_initialised".into(),
|
||||
message: "sanity_check requires Init to have completed first".into(),
|
||||
};
|
||||
};
|
||||
match try_sanity_check(comm.as_ref()) {
|
||||
Ok(sum) => WorkerResponse::NcclSanityResult { observed_sum: sum },
|
||||
Err(msg) => WorkerResponse::Error {
|
||||
kind: "nccl_sanity_failed".into(),
|
||||
message: msg,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn try_init(state: &mut NcclState, cfg: WorkerConfig, comm_id_hex: &str) -> Result<(), String> {
|
||||
let bytes = decode_hex(comm_id_hex)?;
|
||||
if bytes.len() != NCCL_ID_BYTES {
|
||||
return Err(format!(
|
||||
"comm_id is {} bytes, expected {NCCL_ID_BYTES}",
|
||||
bytes.len()
|
||||
));
|
||||
}
|
||||
let id_bytes: [std::ffi::c_char; NCCL_ID_BYTES] =
|
||||
std::array::from_fn(|i| bytes[i] as std::ffi::c_char);
|
||||
let id = Id::uninit(id_bytes);
|
||||
|
||||
let ctx = CudaContext::new(cfg.cuda_device as usize)
|
||||
.map_err(|e| format!("CudaContext::new({}) failed: {e}", cfg.cuda_device))?;
|
||||
let stream = ctx.default_stream();
|
||||
let comm = Comm::from_rank(stream, cfg.rank as usize, cfg.world_size as usize, id)
|
||||
.map_err(|e| {
|
||||
format!(
|
||||
"Comm::from_rank(rank={}, world={}) failed: {e:?}",
|
||||
cfg.rank, cfg.world_size
|
||||
)
|
||||
})?;
|
||||
|
||||
state.ctx = Some(ctx);
|
||||
// `Comm` is !Send + !Sync at the type level because it wraps a
|
||||
// raw `ncclComm_t`. The `Arc` is fine in practice — we
|
||||
// serialise operations through the pool's outer Mutex and the
|
||||
// SendComm wrapper at thread-crossing boundaries enforces this
|
||||
// at every move site. clippy's `arc_with_non_send_sync` lint
|
||||
// can't see that invariant; allow once at the canonical
|
||||
// construction site.
|
||||
#[allow(clippy::arc_with_non_send_sync)]
|
||||
{
|
||||
state.comm = Some(Arc::new(comm));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn try_sanity_check(comm: &Comm) -> Result<u32, String> {
|
||||
let stream = comm.stream().clone();
|
||||
let input = stream
|
||||
.clone_htod(&[1u32])
|
||||
.map_err(|e| format!("htod sentinel: {e}"))?;
|
||||
let mut output = stream
|
||||
.alloc_zeros::<u32>(1)
|
||||
.map_err(|e| format!("alloc output: {e}"))?;
|
||||
// cudarc::nccl::NcclError doesn't impl Display in 0.19.x —
|
||||
// surface via Debug so we still see the variant + ncclResult
|
||||
// code instead of a generic "{e}" failure.
|
||||
comm.all_reduce(&input, &mut output, &ReduceOp::Sum)
|
||||
.map_err(|e| format!("all_reduce: {e:?}"))?;
|
||||
let result = stream
|
||||
.clone_dtoh(&output)
|
||||
.map_err(|e| format!("dtoh result: {e}"))?;
|
||||
Ok(result[0])
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
pub use cuda_impl::{NcclState, SendComm, generate_comm_id_hex};
|
||||
|
||||
/// Non-cuda stub for the leader: returns a clear marker error rather
|
||||
/// than letting `init_nccl` succeed vacuously.
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
pub fn generate_comm_id_hex() -> Result<String, String> {
|
||||
Err("cuda_feature_not_enabled: build with --features cuda".into())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn hex_roundtrip() {
|
||||
let original: Vec<u8> = (0u8..=255).collect();
|
||||
let encoded = encode_hex(&original);
|
||||
assert_eq!(encoded.len(), 512);
|
||||
let decoded = decode_hex(&encoded).expect("decode");
|
||||
assert_eq!(decoded, original);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hex_decode_rejects_odd_length() {
|
||||
assert!(decode_hex("a").is_err());
|
||||
assert!(decode_hex("abc").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hex_decode_rejects_non_hex() {
|
||||
assert!(decode_hex("zz").is_err());
|
||||
assert!(decode_hex("ab_d").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hex_encode_is_lowercase_padded() {
|
||||
assert_eq!(encode_hex(&[0x0a, 0xff]), "0aff");
|
||||
}
|
||||
}
|
||||
257
crates/neuron/src/harness/tp/rpc.rs
Normal file
257
crates/neuron/src/harness/tp/rpc.rs
Normal file
@@ -0,0 +1,257 @@
|
||||
//! Wire protocol between the neuron leader process and its
|
||||
//! `--worker` subprocesses.
|
||||
//!
|
||||
//! Every frame is one newline-delimited JSON object on the worker's
|
||||
//! stdin (request) or stdout (response). Both directions are tagged
|
||||
//! sum types from the start so new ops in Stage 7b/7c slot in without
|
||||
//! breaking compatibility — no "14 message types and a version field"
|
||||
//! drift later. Adding a new variant is the canonical way to evolve
|
||||
//! the protocol; existing peers that don't recognise an op return
|
||||
//! `WorkerResponse::Error { kind: "unknown_op", .. }`.
|
||||
//!
|
||||
//! The serialised shape uses `tag = "op"` so a request looks like:
|
||||
//! {"op":"ping"}
|
||||
//! {"op":"init","comm_id":"a1b2..."}
|
||||
//! and a response:
|
||||
//! {"op":"pong","rank":0,"world_size":2,"cuda_device":0}
|
||||
//! {"op":"error","kind":"nccl_init_failed","message":"..."}
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Leader → worker. Worker handles one at a time; replies with exactly
|
||||
/// one `WorkerResponse` per request.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "op", rename_all = "snake_case")]
|
||||
pub enum WorkerRequest {
|
||||
/// Liveness probe. Worker replies with `Pong` containing its own
|
||||
/// identity. Used by the leader to confirm the subprocess is up
|
||||
/// and ready before kicking off any heavier work.
|
||||
Ping,
|
||||
|
||||
/// One-shot NCCL communicator setup. The leader generates the
|
||||
/// `comm_id` once (rank 0 of NCCL), broadcasts it to every worker
|
||||
/// via this message, then every rank (leader included) calls
|
||||
/// `Comm::from_rank` with the same id — NCCL blocks until all
|
||||
/// `world_size` ranks check in. The hex-encoded bytes are the
|
||||
/// canonical `cudarc::nccl::Id::internal()` content.
|
||||
Init {
|
||||
/// Hex-encoded NCCL id bytes (128 bytes → 256 hex chars).
|
||||
comm_id: String,
|
||||
},
|
||||
|
||||
/// Sanity check: after Init, every rank runs an `all_reduce` over
|
||||
/// a sentinel value (`1u32`). The expected sum is `world_size`.
|
||||
/// Worker replies with the observed value so the leader can verify
|
||||
/// the NCCL handshake is genuinely live, not just configured.
|
||||
NcclSanityCheck,
|
||||
|
||||
/// Load this rank's shard of a dense Qwen3 model from mmaped
|
||||
/// safetensors. The same `safetensors_paths` list is sent to every
|
||||
/// rank — the ShardedVarBuilder reads only the rank-local slice of
|
||||
/// each tensor at materialisation time, so the worker's VRAM
|
||||
/// footprint is `1 / world_size` of the full model (plus replicated
|
||||
/// embedding/norm/lm_head).
|
||||
LoadDenseShard {
|
||||
/// Caller-supplied id for later `GenerateStep` / `UnloadModel`
|
||||
/// lookups. Typically the HF model id verbatim.
|
||||
model_id: String,
|
||||
/// JSON-serialised `candle_transformers::models::qwen3::Config`
|
||||
/// — the same blob the leader parsed from the HF cache's
|
||||
/// `config.json`. Threaded through verbatim so the worker uses
|
||||
/// identical hyperparameters.
|
||||
config_json: String,
|
||||
/// Absolute paths the worker should mmap. The same set on every
|
||||
/// rank; ShardedVarBuilder slices into them per rank.
|
||||
safetensors_paths: Vec<String>,
|
||||
/// Optional in-situ quantization dtype (e.g. "q5k", "q8_0",
|
||||
/// "q6k"). When set, each linear-layer weight is quantized
|
||||
/// at load time to the named ggml format — saves ~3-5x vs
|
||||
/// bf16/f16 at the cost of some accuracy. `None` keeps the
|
||||
/// weights in the on-disk dtype (typically bf16).
|
||||
#[serde(default)]
|
||||
quant: Option<String>,
|
||||
},
|
||||
|
||||
/// Run one forward step on this rank's loaded model. The worker
|
||||
/// reaches into its NCCL Comm for the row-parallel `AllReduce`s
|
||||
/// inside the model — and so blocks on every other rank issuing the
|
||||
/// same op. The leader does *not* receive logits back over RPC; it
|
||||
/// runs its own rank-0 forward in parallel and uses its own logits
|
||||
/// for sampling.
|
||||
GenerateStep {
|
||||
model_id: String,
|
||||
/// Input token ids for this step. For prefill, the whole prompt;
|
||||
/// for decode, a single token. Identical on every rank.
|
||||
tokens: Vec<u32>,
|
||||
/// KV cache offset (count of tokens already in the cache before
|
||||
/// this step).
|
||||
offset: usize,
|
||||
},
|
||||
|
||||
/// Reset the KV cache for this model on this rank. Sent at the
|
||||
/// start of every inference so a fresh request doesn't accidentally
|
||||
/// attend over the previous one's tokens.
|
||||
ClearKvCache { model_id: String },
|
||||
|
||||
/// Drop this rank's shard for the given model. Releases the VRAM
|
||||
/// the shard's weights occupied; subsequent `GenerateStep` calls
|
||||
/// against the same `model_id` return an `Error`.
|
||||
UnloadModel { model_id: String },
|
||||
|
||||
/// Worker should release resources and exit. Worker replies `Bye`
|
||||
/// and then closes stdout / exits zero. The leader reaps the
|
||||
/// child via the `tokio::process::Child` it kept.
|
||||
Shutdown,
|
||||
}
|
||||
|
||||
/// Worker → leader. Always exactly one of these per `WorkerRequest`.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "op", rename_all = "snake_case")]
|
||||
pub enum WorkerResponse {
|
||||
/// Reply to `Ping`. Carries enough identity for the leader to log
|
||||
/// what it actually got back.
|
||||
Pong {
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
cuda_device: u32,
|
||||
},
|
||||
|
||||
/// Reply to `Init`. Empty payload — success is the absence of
|
||||
/// `Error`. NCCL's internal blocking handshake means by the time
|
||||
/// this comes back, every other rank has also reached
|
||||
/// `Comm::from_rank`.
|
||||
InitOk,
|
||||
|
||||
/// Reply to `NcclSanityCheck`. The observed sum after a single
|
||||
/// `all_reduce(SUM, 1u32)` across all ranks. The leader checks
|
||||
/// this matches `world_size`.
|
||||
NcclSanityResult { observed_sum: u32 },
|
||||
|
||||
/// Reply to `LoadDenseShard`. Empty payload — success is the
|
||||
/// absence of `Error`. By the time this comes back, the rank's
|
||||
/// `TpQwen3ForCausalLM` is constructed in memory and ready for
|
||||
/// `GenerateStep`.
|
||||
LoadDenseShardOk,
|
||||
|
||||
/// Reply to `GenerateStep`. Empty payload — workers don't ship
|
||||
/// logits over the wire. The leader uses its own rank-0 logits;
|
||||
/// workers only need to confirm the collective completed.
|
||||
GenerateStepOk,
|
||||
|
||||
/// Reply to `ClearKvCache`. Empty payload.
|
||||
KvCacheCleared,
|
||||
|
||||
/// Reply to `UnloadModel`. Empty payload. The named model is no
|
||||
/// longer present on this rank.
|
||||
Unloaded,
|
||||
|
||||
/// Reply to `Shutdown`. Worker exits immediately after writing this.
|
||||
Bye,
|
||||
|
||||
/// Any request can produce this instead of its dedicated success
|
||||
/// variant. `kind` is a machine-readable category so the leader
|
||||
/// can branch on failure mode without string-matching `message`.
|
||||
Error {
|
||||
/// Short tag — `nccl_init_failed`, `unknown_op`, etc.
|
||||
kind: String,
|
||||
/// Human-readable detail for logs.
|
||||
message: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn roundtrip<T>(value: &T) -> T
|
||||
where
|
||||
T: Serialize + for<'de> Deserialize<'de>,
|
||||
{
|
||||
serde_json::from_str(&serde_json::to_string(value).expect("serialise"))
|
||||
.expect("deserialise")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_ping_round_trip() {
|
||||
let req = WorkerRequest::Ping;
|
||||
let wire = serde_json::to_string(&req).unwrap();
|
||||
assert_eq!(wire, r#"{"op":"ping"}"#);
|
||||
match roundtrip(&req) {
|
||||
WorkerRequest::Ping => {}
|
||||
other => panic!("expected Ping, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_init_carries_hex_id() {
|
||||
let req = WorkerRequest::Init {
|
||||
comm_id: "deadbeef".into(),
|
||||
};
|
||||
let wire = serde_json::to_string(&req).unwrap();
|
||||
assert_eq!(wire, r#"{"op":"init","comm_id":"deadbeef"}"#);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_shutdown_round_trip() {
|
||||
assert_eq!(
|
||||
serde_json::to_string(&WorkerRequest::Shutdown).unwrap(),
|
||||
r#"{"op":"shutdown"}"#
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_pong_round_trip() {
|
||||
let resp = WorkerResponse::Pong {
|
||||
rank: 1,
|
||||
world_size: 4,
|
||||
cuda_device: 1,
|
||||
};
|
||||
let wire = serde_json::to_string(&resp).unwrap();
|
||||
assert!(wire.contains(r#""op":"pong""#));
|
||||
assert!(wire.contains(r#""rank":1"#));
|
||||
assert!(wire.contains(r#""world_size":4"#));
|
||||
match roundtrip(&resp) {
|
||||
WorkerResponse::Pong {
|
||||
rank,
|
||||
world_size,
|
||||
cuda_device,
|
||||
} => {
|
||||
assert_eq!(rank, 1);
|
||||
assert_eq!(world_size, 4);
|
||||
assert_eq!(cuda_device, 1);
|
||||
}
|
||||
other => panic!("expected Pong, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_error_carries_kind_and_message() {
|
||||
let resp = WorkerResponse::Error {
|
||||
kind: "nccl_init_failed".into(),
|
||||
message: "could not bind device".into(),
|
||||
};
|
||||
let wire = serde_json::to_string(&resp).unwrap();
|
||||
assert!(wire.contains(r#""op":"error""#));
|
||||
assert!(wire.contains(r#""kind":"nccl_init_failed""#));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_sanity_result_round_trip() {
|
||||
let resp = WorkerResponse::NcclSanityResult { observed_sum: 4 };
|
||||
match roundtrip(&resp) {
|
||||
WorkerResponse::NcclSanityResult { observed_sum } => {
|
||||
assert_eq!(observed_sum, 4);
|
||||
}
|
||||
other => panic!("expected NcclSanityResult, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Unknown ops on the wire deserialise to an error rather than
|
||||
/// silently mis-matching — confirms our `serde(tag = "op")`
|
||||
/// configuration rejects unknowns instead of doing fuzzy matching.
|
||||
#[test]
|
||||
fn unknown_op_fails_to_parse() {
|
||||
let result: Result<WorkerRequest, _> = serde_json::from_str(r#"{"op":"explode"}"#);
|
||||
assert!(result.is_err(), "should reject unknown op, got {result:?}");
|
||||
}
|
||||
}
|
||||
283
crates/neuron/src/harness/tp/tp_linear.rs
Normal file
283
crates/neuron/src/harness/tp/tp_linear.rs
Normal file
@@ -0,0 +1,283 @@
|
||||
//! Tensor-parallel linear layers built on candle's `ShardedVarBuilder`
|
||||
//! and `Shard` sharding hints.
|
||||
//!
|
||||
//! candle reads only the rank's slice of each weight tensor from
|
||||
//! safetensors via `view.slice(start..stop)` — no full-tensor host
|
||||
//! materialisation. That's a memory-efficiency win over hand-rolled
|
||||
//! "load full + narrow" sharding (which the earlier
|
||||
//! `sharded_linear.rs` exploration demonstrated but didn't pay for).
|
||||
//!
|
||||
//! Two layer types:
|
||||
//!
|
||||
//! - [`ColumnParallelLinear`] — output-sharded; forward is a plain
|
||||
//! local matmul. The downstream consumer either accepts a sharded
|
||||
//! activation (next layer is also column-parallel) or all-gathers.
|
||||
//! - [`RowParallelLinear`] — input-sharded; forward = local matmul
|
||||
//! then `AllReduce` `CustomOp1` to sum partials across ranks.
|
||||
//!
|
||||
//! Both assume **no bias** — every Qwen3-family weight layout we
|
||||
//! actually target (Qwen3, Qwen3-Coder, Qwen3.6 base, etc.) sets
|
||||
//! `attention_bias=false` and the MLP layers are no-bias. Adding bias
|
||||
//! support is mechanical when a future model needs it; the design
|
||||
//! choice would be: column-parallel shards the bias along dim 0;
|
||||
//! row-parallel holds the bias only on rank 0 so the post-`AllReduce`
|
||||
//! sum carries it exactly once.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use candle_core::quantized::{GgmlDType, QMatMul, QTensor};
|
||||
use candle_core::{Module, Tensor};
|
||||
use candle_nn::Linear;
|
||||
use candle_nn::var_builder::{Shard, ShardedVarBuilder};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
use super::all_reduce::AllReduce;
|
||||
|
||||
/// Linear primitive that holds either a plain `Linear` (bf16/f16/f32)
|
||||
/// or a quantized `QMatMul` (Q4K/Q5K/Q6K/Q8_0/etc.).
|
||||
///
|
||||
/// Constructed via [`MaybeQuantLinear::from_weight`] — pass `None` to
|
||||
/// keep the weight in its loaded dtype (no quantization), or
|
||||
/// `Some(dtype)` to quantize at load time.
|
||||
///
|
||||
/// On the forward path the two arms dispatch identically: `Module::forward`
|
||||
/// returns an output in the caller's input dtype (f32 fallback for the
|
||||
/// quantized matmul). Subsequent ops don't need to know whether the
|
||||
/// layer was quantized.
|
||||
pub enum MaybeQuantLinear {
|
||||
Plain(Linear),
|
||||
Quant(QMatMul),
|
||||
}
|
||||
|
||||
impl MaybeQuantLinear {
|
||||
/// Build a linear from a loaded weight tensor. If `quant` is set,
|
||||
/// the weight is quantized in-situ and stored as a `QMatMul`;
|
||||
/// otherwise it's wrapped in a plain `Linear`.
|
||||
pub fn from_weight(weight: Tensor, quant: Option<GgmlDType>) -> Result<Self> {
|
||||
match quant {
|
||||
Some(dtype) => {
|
||||
let qt = QTensor::quantize(&weight, dtype).with_context(|| {
|
||||
format!(
|
||||
"QTensor::quantize to {dtype:?} for shape {:?}",
|
||||
weight.shape()
|
||||
)
|
||||
})?;
|
||||
let qmm = QMatMul::from_arc(Arc::new(qt))
|
||||
.context("QMatMul::from_arc on freshly quantized weight")?;
|
||||
Ok(Self::Quant(qmm))
|
||||
}
|
||||
None => Ok(Self::Plain(Linear::new(weight, None))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Above this M (the product of all input dims except the last)
|
||||
/// dispatch the quantized matmul through `QMatMul::forward_via_f16`,
|
||||
/// which dequantizes the weight to f16 once and runs cuBLAS GEMM.
|
||||
/// At or below this M the GGUF GEMV kernel inside
|
||||
/// `QMatMul::forward` wins (it operates on quantized blocks directly
|
||||
/// and accumulates in registers).
|
||||
///
|
||||
/// Empirical: at M=30 on Qwen3.6-27B / RTX 5090, forward_via_f16 was
|
||||
/// slightly *slower* than the GGUF GEMV kernel — the per-call dequant
|
||||
/// cost (~30 MB f16 written to global memory per linear × ~480 calls
|
||||
/// per prefill) eats the cuBLAS GEMM speedup at small M. The
|
||||
/// crossover where the GEMM scaling actually beats the fixed dequant
|
||||
/// tax sits well above M=8.
|
||||
///
|
||||
/// 64 is a conservative crossover that keeps short-prompt prefills
|
||||
/// on the GGUF kernel (where the per-call cost is comparable to the
|
||||
/// f16 path but the dequant tax is zero) and only activates the
|
||||
/// dequant-then-GEMM path for long prefills where the GEMM size
|
||||
/// makes amortising worth it. A proper fix is either a dequant
|
||||
/// cache or a fused dequant+gemm cuda kernel — both larger projects.
|
||||
const QUANT_PREFILL_M_THRESHOLD: usize = 64;
|
||||
|
||||
impl Module for MaybeQuantLinear {
|
||||
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
|
||||
match self {
|
||||
Self::Plain(l) => l.forward(x),
|
||||
Self::Quant(qm) => {
|
||||
// Decode vs prefill split. `M` is the "rows of x" the
|
||||
// matmul will iterate over — every dim except the last
|
||||
// (which is in_features). For decode (`seq_len == 1`
|
||||
// with batch 1) M is 1; for prefill with L>>1 it's L
|
||||
// (or B*L).
|
||||
let dims = x.dims();
|
||||
let m: usize = dims.iter().take(dims.len() - 1).product();
|
||||
|
||||
if m > QUANT_PREFILL_M_THRESHOLD {
|
||||
// Prefill: dequantize the weight once into f16,
|
||||
// then run a real cuBLAS-backed GEMM. The cost of
|
||||
// the dequant is amortised across all M tokens.
|
||||
// `forward_via_f16` handles the dtype round-trip
|
||||
// internally (output matches input dtype).
|
||||
return qm.forward_via_f16(x);
|
||||
}
|
||||
|
||||
// Decode (M <= threshold): use the on-the-fly GGUF
|
||||
// GEMV kernel via `QMatMul::forward`. That kernel
|
||||
// requires f32 inputs (it accumulates in f32 from the
|
||||
// dequantized quant blocks); cast in/out at the
|
||||
// boundary.
|
||||
let in_dtype = x.dtype();
|
||||
let x_f32 = if in_dtype == candle_core::DType::F32 {
|
||||
x.clone()
|
||||
} else {
|
||||
x.to_dtype(candle_core::DType::F32)?
|
||||
};
|
||||
let y = qm.forward(&x_f32)?;
|
||||
if y.dtype() == in_dtype {
|
||||
Ok(y)
|
||||
} else {
|
||||
y.to_dtype(in_dtype)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to build a [`Shard`] hint for a given dimension.
|
||||
pub(crate) fn shard(dim: usize, rank: u32, world_size: u32) -> Shard {
|
||||
Shard {
|
||||
dim,
|
||||
rank: rank as usize,
|
||||
world_size: world_size as usize,
|
||||
}
|
||||
}
|
||||
|
||||
/// Output-dim sharded linear (column-parallel). Holds a
|
||||
/// [`MaybeQuantLinear`] whose underlying weight is this rank's slice
|
||||
/// of the full `[out_features, in_features]` tensor along dim 0.
|
||||
pub struct ColumnParallelLinear {
|
||||
inner: MaybeQuantLinear,
|
||||
}
|
||||
|
||||
impl ColumnParallelLinear {
|
||||
/// Load this rank's column-parallel slice from a
|
||||
/// `ShardedVarBuilder`. The provided `vb` must already be `pp`-ed
|
||||
/// to the layer's path (e.g. `vb.pp("model.layers.0.self_attn.q_proj")`).
|
||||
///
|
||||
/// Backward-compatible variant — no in-situ quantization. For
|
||||
/// quantized loads, use [`Self::load_with_quant`].
|
||||
pub fn load(vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
|
||||
Self::load_with_quant(vb, rank, world_size, None)
|
||||
}
|
||||
|
||||
/// Like [`Self::load`] but quantizes the per-rank weight in-situ
|
||||
/// when `quant` is `Some(dtype)`. Saves ~3-5x vs bf16/f16.
|
||||
pub fn load_with_quant(
|
||||
vb: &ShardedVarBuilder,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
quant: Option<GgmlDType>,
|
||||
) -> Result<Self> {
|
||||
let weight = vb
|
||||
.get_with_hints((), "weight", shard(0, rank, world_size))
|
||||
.with_context(|| format!("load column-parallel '{}' weight", vb.prefix()))?;
|
||||
let inner = MaybeQuantLinear::from_weight(weight, quant)
|
||||
.with_context(|| format!("wrap column-parallel '{}'", vb.prefix()))?;
|
||||
Ok(Self { inner })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for ColumnParallelLinear {
|
||||
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
|
||||
self.inner.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
/// Input-dim sharded linear (row-parallel).
|
||||
///
|
||||
/// Holds a sharded [`MaybeQuantLinear`] plus an `AllReduce` op the
|
||||
/// forward chains after the local matmul to recover the full activation.
|
||||
pub struct RowParallelLinear {
|
||||
inner: MaybeQuantLinear,
|
||||
#[cfg(feature = "cuda")]
|
||||
all_reduce: AllReduce,
|
||||
/// Whether the AllReduce should run. Column-parallel ↔ row-parallel
|
||||
/// is a pair: the column output is sharded, the row input is
|
||||
/// sharded, and the AllReduce gives back the full output. For
|
||||
/// `world_size = 1` the AllReduce is a no-op so we skip it.
|
||||
needs_reduce: bool,
|
||||
}
|
||||
|
||||
impl RowParallelLinear {
|
||||
/// Load this rank's row-parallel slice from a `ShardedVarBuilder`.
|
||||
///
|
||||
/// Under `cuda`, `comm` is the NCCL communicator the row-parallel
|
||||
/// `AllReduce` runs against. On CPU builds the parameter is
|
||||
/// elided — forward returns the partial sum, which is the *wrong*
|
||||
/// answer for inference but lets us compile-check the model.
|
||||
///
|
||||
/// Backward-compatible variant — no in-situ quantization. For
|
||||
/// quantized loads, use [`Self::load_with_quant`].
|
||||
#[cfg(feature = "cuda")]
|
||||
pub fn load(
|
||||
vb: &ShardedVarBuilder,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
comm: std::sync::Arc<cudarc::nccl::Comm>,
|
||||
) -> Result<Self> {
|
||||
Self::load_with_quant(vb, rank, world_size, comm, None)
|
||||
}
|
||||
|
||||
/// Like [`Self::load`] but quantizes the per-rank weight in-situ.
|
||||
#[cfg(feature = "cuda")]
|
||||
pub fn load_with_quant(
|
||||
vb: &ShardedVarBuilder,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
comm: std::sync::Arc<cudarc::nccl::Comm>,
|
||||
quant: Option<GgmlDType>,
|
||||
) -> Result<Self> {
|
||||
let weight = vb
|
||||
.get_with_hints((), "weight", shard(1, rank, world_size))
|
||||
.with_context(|| format!("load row-parallel '{}' weight", vb.prefix()))?;
|
||||
let inner = MaybeQuantLinear::from_weight(weight, quant)
|
||||
.with_context(|| format!("wrap row-parallel '{}'", vb.prefix()))?;
|
||||
Ok(Self {
|
||||
inner,
|
||||
all_reduce: AllReduce::new(comm),
|
||||
needs_reduce: world_size > 1,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
pub fn load(vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
|
||||
Self::load_with_quant(vb, rank, world_size, None)
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
pub fn load_with_quant(
|
||||
vb: &ShardedVarBuilder,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
quant: Option<GgmlDType>,
|
||||
) -> Result<Self> {
|
||||
let weight = vb
|
||||
.get_with_hints((), "weight", shard(1, rank, world_size))
|
||||
.with_context(|| format!("load row-parallel '{}' weight", vb.prefix()))?;
|
||||
let inner = MaybeQuantLinear::from_weight(weight, quant)
|
||||
.with_context(|| format!("wrap row-parallel '{}'", vb.prefix()))?;
|
||||
Ok(Self {
|
||||
inner,
|
||||
needs_reduce: world_size > 1,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for RowParallelLinear {
|
||||
/// Local matmul followed by an `AllReduce` (when `cuda` and
|
||||
/// `world_size > 1`). On CPU or single-rank, returns the partial
|
||||
/// output directly — which is *only* correct for `world_size == 1`.
|
||||
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
|
||||
let local = self.inner.forward(x)?;
|
||||
#[cfg(feature = "cuda")]
|
||||
if self.needs_reduce {
|
||||
return local.apply_op1_no_bwd(&self.all_reduce);
|
||||
}
|
||||
let _ = self.needs_reduce;
|
||||
Ok(local)
|
||||
}
|
||||
}
|
||||
678
crates/neuron/src/harness/tp/tp_qwen3.rs
Normal file
678
crates/neuron/src/harness/tp/tp_qwen3.rs
Normal file
@@ -0,0 +1,678 @@
|
||||
//! Tensor-parallel Qwen3 dense model.
|
||||
//!
|
||||
//! Mirrors `candle_transformers::models::qwen3` structurally, but with:
|
||||
//!
|
||||
//! - Attention's `q_proj` / `k_proj` / `v_proj` as
|
||||
//! [`ColumnParallelLinear`] (output sharded along the head dimension —
|
||||
//! per-rank `num_heads = total/world_size`, ditto for kv heads).
|
||||
//! - Attention's `o_proj` as [`RowParallelLinear`] (input sharded; the
|
||||
//! trailing `AllReduce` recovers the full activation).
|
||||
//! - MLP's `gate_proj` / `up_proj` as [`ColumnParallelLinear`] (sharded
|
||||
//! along `intermediate_size`).
|
||||
//! - MLP's `down_proj` as [`RowParallelLinear`].
|
||||
//! - `embed_tokens`, all `RmsNorm`s, and `lm_head` **replicated** on
|
||||
//! every rank. The per-rank duplicate weight is bounded and lets us
|
||||
//! skip the embedding all-gather and the lm-head column-shard +
|
||||
//! all-gather; both are pure latency optimisations that don't change
|
||||
//! correctness.
|
||||
//! - `kv_cache` holds the per-rank slice of K/V already (because they
|
||||
//! came out of a column-parallel projection). No cache resharding
|
||||
//! needed across ranks.
|
||||
//!
|
||||
//! Divisibility requirement, checked at load time:
|
||||
//!
|
||||
//! - `num_attention_heads % world_size == 0`
|
||||
//! - `num_key_value_heads % world_size == 0`
|
||||
//! - `intermediate_size % world_size == 0`
|
||||
//!
|
||||
//! Anything else bails — the safetensors slice would lose data otherwise.
|
||||
//! This is the same divisibility-bail pattern that landed in
|
||||
//! `EricLBuehler/mistral.rs` PR #2054.
|
||||
//!
|
||||
//! Replicated tensors (norms, embedding, lm_head) are loaded by asking
|
||||
//! the `ShardedVarBuilder` for the full tensor via `vb.get(shape, name)`
|
||||
//! — which defaults to `Shard { world_size: 1 }` and falls through to
|
||||
//! the unsharded backend path.
|
||||
|
||||
use anyhow::{Context, Result, bail};
|
||||
use candle_core::{DType, Device, IndexOp, Module, Tensor};
|
||||
use candle_nn::var_builder::ShardedVarBuilder;
|
||||
use candle_nn::{Activation, Embedding, Linear, RmsNorm, kv_cache::ConcatKvCache};
|
||||
use candle_transformers::utils::repeat_kv;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
use cudarc::nccl::Comm;
|
||||
|
||||
use super::tp_linear::{ColumnParallelLinear, RowParallelLinear};
|
||||
|
||||
pub use candle_transformers::models::qwen3::Config;
|
||||
|
||||
/// Replicated rotary-embedding lookup. Re-implementation of the
|
||||
/// `pub(crate)` candle equivalent — we can't reach into the upstream
|
||||
/// type, so the inv-freq / sin / cos construction lives here.
|
||||
pub(crate) struct Qwen3RotaryEmbedding {
|
||||
sin: Tensor,
|
||||
cos: Tensor,
|
||||
}
|
||||
|
||||
impl Qwen3RotaryEmbedding {
|
||||
pub(crate) fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
|
||||
let dim = cfg.head_dim;
|
||||
let max_seq_len = cfg.max_position_embeddings;
|
||||
let inv_freq: Vec<_> = (0..dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
|
||||
.collect();
|
||||
let inv_freq_len = inv_freq.len();
|
||||
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
|
||||
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
||||
.to_dtype(DType::F32)?
|
||||
.reshape((max_seq_len, 1))?;
|
||||
let freqs = t.matmul(&inv_freq)?;
|
||||
Ok(Self {
|
||||
sin: freqs.sin()?.to_dtype(dtype)?,
|
||||
cos: freqs.cos()?.to_dtype(dtype)?,
|
||||
})
|
||||
}
|
||||
|
||||
fn apply(
|
||||
&self,
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
offset: usize,
|
||||
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||
let (_, _, seq_len, _) = q.dims4()?;
|
||||
let cos = self.cos.narrow(0, offset, seq_len)?;
|
||||
let sin = self.sin.narrow(0, offset, seq_len)?;
|
||||
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
|
||||
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
|
||||
Ok((q_embed, k_embed))
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper: load a replicated tensor by asking the ShardedVarBuilder for
|
||||
/// the full tensor (world_size=1 hint falls through to SimpleBackend).
|
||||
fn load_replicated<S: Into<candle_core::Shape>>(
|
||||
vb: &ShardedVarBuilder,
|
||||
shape: S,
|
||||
name: &str,
|
||||
) -> Result<Tensor> {
|
||||
vb.get(shape, name)
|
||||
.with_context(|| format!("load replicated '{}/{name}'", vb.prefix()))
|
||||
}
|
||||
|
||||
fn load_rms_norm(vb: &ShardedVarBuilder, size: usize, eps: f64) -> Result<RmsNorm> {
|
||||
let weight = load_replicated(vb, size, "weight")?;
|
||||
Ok(RmsNorm::new(weight, eps))
|
||||
}
|
||||
|
||||
/// TP MLP. SwiGLU = `down(silu(gate(x)) * up(x))`.
|
||||
pub(crate) struct TpQwen3MLP {
|
||||
gate_proj: ColumnParallelLinear,
|
||||
up_proj: ColumnParallelLinear,
|
||||
down_proj: RowParallelLinear,
|
||||
act_fn: Activation,
|
||||
}
|
||||
|
||||
impl TpQwen3MLP {
|
||||
#[cfg(feature = "cuda")]
|
||||
pub fn load(
|
||||
cfg: &Config,
|
||||
vb: &ShardedVarBuilder,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
comm: Arc<Comm>,
|
||||
) -> Result<Self> {
|
||||
if !cfg.intermediate_size.is_multiple_of(world_size as usize) {
|
||||
bail!(
|
||||
"intermediate_size {} not divisible by world_size {}",
|
||||
cfg.intermediate_size,
|
||||
world_size
|
||||
);
|
||||
}
|
||||
Ok(Self {
|
||||
gate_proj: ColumnParallelLinear::load(&vb.pp("gate_proj"), rank, world_size)?,
|
||||
up_proj: ColumnParallelLinear::load(&vb.pp("up_proj"), rank, world_size)?,
|
||||
down_proj: RowParallelLinear::load(&vb.pp("down_proj"), rank, world_size, comm)?,
|
||||
act_fn: cfg.hidden_act,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
pub fn load(cfg: &Config, vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
|
||||
if !cfg.intermediate_size.is_multiple_of(world_size as usize) {
|
||||
bail!(
|
||||
"intermediate_size {} not divisible by world_size {}",
|
||||
cfg.intermediate_size,
|
||||
world_size
|
||||
);
|
||||
}
|
||||
Ok(Self {
|
||||
gate_proj: ColumnParallelLinear::load(&vb.pp("gate_proj"), rank, world_size)?,
|
||||
up_proj: ColumnParallelLinear::load(&vb.pp("up_proj"), rank, world_size)?,
|
||||
down_proj: RowParallelLinear::load(&vb.pp("down_proj"), rank, world_size)?,
|
||||
act_fn: cfg.hidden_act,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for TpQwen3MLP {
|
||||
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
|
||||
let lhs = x.apply(&self.gate_proj)?.apply(&self.act_fn)?;
|
||||
let rhs = x.apply(&self.up_proj)?;
|
||||
(lhs * rhs)?.apply(&self.down_proj)
|
||||
}
|
||||
}
|
||||
|
||||
/// TP attention. Carries per-rank head counts and the q/k per-head
|
||||
/// RmsNorms (which are replicated and operate on a flattened B*H*L
|
||||
/// axis, so the same code path works irrespective of how H was split).
|
||||
pub(crate) struct TpQwen3Attention {
|
||||
q_proj: ColumnParallelLinear,
|
||||
k_proj: ColumnParallelLinear,
|
||||
v_proj: ColumnParallelLinear,
|
||||
o_proj: RowParallelLinear,
|
||||
q_norm: RmsNorm,
|
||||
k_norm: RmsNorm,
|
||||
local_num_heads: usize,
|
||||
local_num_kv_heads: usize,
|
||||
num_kv_groups: usize,
|
||||
head_dim: usize,
|
||||
local_hidden_size: usize,
|
||||
rotary_emb: Arc<Qwen3RotaryEmbedding>,
|
||||
kv_cache: ConcatKvCache,
|
||||
}
|
||||
|
||||
impl TpQwen3Attention {
|
||||
#[cfg(feature = "cuda")]
|
||||
pub fn load(
|
||||
cfg: &Config,
|
||||
rotary_emb: Arc<Qwen3RotaryEmbedding>,
|
||||
vb: &ShardedVarBuilder,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
comm: Arc<Comm>,
|
||||
) -> Result<Self> {
|
||||
Self::load_inner(
|
||||
cfg,
|
||||
rotary_emb,
|
||||
vb,
|
||||
rank,
|
||||
world_size,
|
||||
#[cfg(feature = "cuda")]
|
||||
comm,
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
pub fn load(
|
||||
cfg: &Config,
|
||||
rotary_emb: Arc<Qwen3RotaryEmbedding>,
|
||||
vb: &ShardedVarBuilder,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
) -> Result<Self> {
|
||||
Self::load_inner(cfg, rotary_emb, vb, rank, world_size)
|
||||
}
|
||||
|
||||
fn load_inner(
|
||||
cfg: &Config,
|
||||
rotary_emb: Arc<Qwen3RotaryEmbedding>,
|
||||
vb: &ShardedVarBuilder,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
#[cfg(feature = "cuda")] comm: Arc<Comm>,
|
||||
) -> Result<Self> {
|
||||
if cfg.use_sliding_window {
|
||||
bail!("sliding window is not yet supported in the TP path");
|
||||
}
|
||||
if cfg.attention_bias {
|
||||
bail!("attention_bias=true is not supported by ColumnParallel/RowParallelLinear yet");
|
||||
}
|
||||
let ws = world_size as usize;
|
||||
if !cfg.num_attention_heads.is_multiple_of(ws) {
|
||||
bail!(
|
||||
"num_attention_heads {} not divisible by world_size {}",
|
||||
cfg.num_attention_heads,
|
||||
world_size
|
||||
);
|
||||
}
|
||||
if !cfg.num_key_value_heads.is_multiple_of(ws) {
|
||||
bail!(
|
||||
"num_key_value_heads {} not divisible by world_size {}",
|
||||
cfg.num_key_value_heads,
|
||||
world_size
|
||||
);
|
||||
}
|
||||
let head_dim = cfg.head_dim;
|
||||
let local_num_heads = cfg.num_attention_heads / ws;
|
||||
let local_num_kv_heads = cfg.num_key_value_heads / ws;
|
||||
let num_kv_groups = local_num_heads / local_num_kv_heads;
|
||||
let local_hidden_size = head_dim * local_num_heads;
|
||||
|
||||
let q_proj = ColumnParallelLinear::load(&vb.pp("q_proj"), rank, world_size)?;
|
||||
let k_proj = ColumnParallelLinear::load(&vb.pp("k_proj"), rank, world_size)?;
|
||||
let v_proj = ColumnParallelLinear::load(&vb.pp("v_proj"), rank, world_size)?;
|
||||
#[cfg(feature = "cuda")]
|
||||
let o_proj = RowParallelLinear::load(&vb.pp("o_proj"), rank, world_size, comm)?;
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
let o_proj = RowParallelLinear::load(&vb.pp("o_proj"), rank, world_size)?;
|
||||
|
||||
let q_norm = load_rms_norm(&vb.pp("q_norm"), head_dim, cfg.rms_norm_eps)?;
|
||||
let k_norm = load_rms_norm(&vb.pp("k_norm"), head_dim, cfg.rms_norm_eps)?;
|
||||
|
||||
// dim=2 because we cat along the seq axis of (B, H, L, D) tensors.
|
||||
let kv_cache = ConcatKvCache::new(2);
|
||||
|
||||
Ok(Self {
|
||||
q_proj,
|
||||
k_proj,
|
||||
v_proj,
|
||||
o_proj,
|
||||
q_norm,
|
||||
k_norm,
|
||||
local_num_heads,
|
||||
local_num_kv_heads,
|
||||
num_kv_groups,
|
||||
head_dim,
|
||||
local_hidden_size,
|
||||
rotary_emb,
|
||||
kv_cache,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&mut self,
|
||||
x: &Tensor,
|
||||
attn_mask: Option<&Tensor>,
|
||||
offset: usize,
|
||||
) -> candle_core::Result<Tensor> {
|
||||
let (b, l, _) = x.dims3()?;
|
||||
|
||||
// 1. Projections (column-parallel → output is sharded).
|
||||
let q = self.q_proj.forward(x)?;
|
||||
let k = self.k_proj.forward(x)?;
|
||||
let v = self.v_proj.forward(x)?;
|
||||
|
||||
// 2. Reshape: (B, L, H, D) → (B, H, L, D).
|
||||
let q = q
|
||||
.reshape((b, l, self.local_num_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let k = k
|
||||
.reshape((b, l, self.local_num_kv_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let v = v
|
||||
.reshape((b, l, self.local_num_kv_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
|
||||
// 3. Per-head RmsNorm (replicated weight, flat input).
|
||||
let q_flat = q.flatten(0, 2)?;
|
||||
let k_flat = k.flatten(0, 2)?;
|
||||
let q_flat = self.q_norm.forward(&q_flat)?;
|
||||
let k_flat = self.k_norm.forward(&k_flat)?;
|
||||
let q = q_flat.reshape((b, self.local_num_heads, l, self.head_dim))?;
|
||||
let k = k_flat.reshape((b, self.local_num_kv_heads, l, self.head_dim))?;
|
||||
|
||||
// 4. Rotary.
|
||||
let (q, k) = self.rotary_emb.apply(&q, &k, offset)?;
|
||||
|
||||
// 5. Accumulate KV.
|
||||
let (k, v) = self.kv_cache.append(&k, &v)?;
|
||||
|
||||
// 6. GQA repeat_kv on the rank-local K/V.
|
||||
let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
|
||||
let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;
|
||||
|
||||
// 7. Attention scores.
|
||||
let scale = 1.0 / (self.head_dim as f64).sqrt();
|
||||
let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
|
||||
if let Some(m) = attn_mask {
|
||||
scores = scores.broadcast_add(m)?;
|
||||
}
|
||||
let probs = candle_nn::ops::softmax_last_dim(&scores)?;
|
||||
let ctx = probs.matmul(&v)?;
|
||||
|
||||
// 8. Output projection (row-parallel → AllReduce inside).
|
||||
ctx.transpose(1, 2)?
|
||||
.reshape((b, l, self.local_hidden_size))?
|
||||
.apply(&self.o_proj)
|
||||
}
|
||||
|
||||
pub fn clear_kv_cache(&mut self) {
|
||||
self.kv_cache.reset();
|
||||
}
|
||||
}
|
||||
|
||||
struct TpDecoderLayer {
|
||||
self_attn: TpQwen3Attention,
|
||||
mlp: TpQwen3MLP,
|
||||
ln1: RmsNorm,
|
||||
ln2: RmsNorm,
|
||||
}
|
||||
|
||||
impl TpDecoderLayer {
|
||||
#[cfg(feature = "cuda")]
|
||||
fn load(
|
||||
cfg: &Config,
|
||||
rotary_emb: Arc<Qwen3RotaryEmbedding>,
|
||||
vb: &ShardedVarBuilder,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
comm: Arc<Comm>,
|
||||
) -> Result<Self> {
|
||||
let self_attn = TpQwen3Attention::load(
|
||||
cfg,
|
||||
rotary_emb,
|
||||
&vb.pp("self_attn"),
|
||||
rank,
|
||||
world_size,
|
||||
comm.clone(),
|
||||
)?;
|
||||
let mlp = TpQwen3MLP::load(cfg, &vb.pp("mlp"), rank, world_size, comm)?;
|
||||
let ln1 = load_rms_norm(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?;
|
||||
let ln2 = load_rms_norm(
|
||||
&vb.pp("post_attention_layernorm"),
|
||||
cfg.hidden_size,
|
||||
cfg.rms_norm_eps,
|
||||
)?;
|
||||
Ok(Self {
|
||||
self_attn,
|
||||
mlp,
|
||||
ln1,
|
||||
ln2,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
fn load(
|
||||
cfg: &Config,
|
||||
rotary_emb: Arc<Qwen3RotaryEmbedding>,
|
||||
vb: &ShardedVarBuilder,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
) -> Result<Self> {
|
||||
let self_attn =
|
||||
TpQwen3Attention::load(cfg, rotary_emb, &vb.pp("self_attn"), rank, world_size)?;
|
||||
let mlp = TpQwen3MLP::load(cfg, &vb.pp("mlp"), rank, world_size)?;
|
||||
let ln1 = load_rms_norm(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?;
|
||||
let ln2 = load_rms_norm(
|
||||
&vb.pp("post_attention_layernorm"),
|
||||
cfg.hidden_size,
|
||||
cfg.rms_norm_eps,
|
||||
)?;
|
||||
Ok(Self {
|
||||
self_attn,
|
||||
mlp,
|
||||
ln1,
|
||||
ln2,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&mut self,
|
||||
x: &Tensor,
|
||||
mask: Option<&Tensor>,
|
||||
offset: usize,
|
||||
) -> candle_core::Result<Tensor> {
|
||||
let h = self.ln1.forward(x)?;
|
||||
let h = self.self_attn.forward(&h, mask, offset)?;
|
||||
let x = (x + h)?;
|
||||
let h2 = self.ln2.forward(&x)?;
|
||||
let h2 = h2.apply(&self.mlp)?;
|
||||
x + h2
|
||||
}
|
||||
|
||||
fn clear_kv_cache(&mut self) {
|
||||
self.self_attn.clear_kv_cache();
|
||||
}
|
||||
}
|
||||
|
||||
/// Base TP Qwen3 transformer — embedding, decoder stack, final norm.
|
||||
/// The lm_head sits on top in [`TpQwen3ForCausalLM`].
|
||||
pub struct TpQwen3Model {
|
||||
embed_tokens: Embedding,
|
||||
layers: Vec<TpDecoderLayer>,
|
||||
norm: RmsNorm,
|
||||
device: Device,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl TpQwen3Model {
|
||||
#[cfg(feature = "cuda")]
|
||||
pub fn load(
|
||||
cfg: &Config,
|
||||
vb: &ShardedVarBuilder,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
comm: Arc<Comm>,
|
||||
) -> Result<Self> {
|
||||
let dtype = vb.dtype();
|
||||
let device = vb.device().clone();
|
||||
let rotary = Arc::new(Qwen3RotaryEmbedding::new(dtype, cfg, &device)?);
|
||||
|
||||
let embed_vb = vb.pp("model.embed_tokens");
|
||||
let embed_weight = load_replicated(&embed_vb, (cfg.vocab_size, cfg.hidden_size), "weight")?;
|
||||
let embed_tokens = Embedding::new(embed_weight, cfg.hidden_size);
|
||||
|
||||
let vb_l = vb.pp("model.layers");
|
||||
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||
for i in 0..cfg.num_hidden_layers {
|
||||
layers.push(TpDecoderLayer::load(
|
||||
cfg,
|
||||
rotary.clone(),
|
||||
&vb_l.pp(i),
|
||||
rank,
|
||||
world_size,
|
||||
comm.clone(),
|
||||
)?);
|
||||
}
|
||||
let norm = load_rms_norm(&vb.pp("model.norm"), cfg.hidden_size, cfg.rms_norm_eps)?;
|
||||
|
||||
Ok(Self {
|
||||
embed_tokens,
|
||||
layers,
|
||||
norm,
|
||||
device,
|
||||
dtype,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
pub fn load(cfg: &Config, vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
|
||||
let dtype = vb.dtype();
|
||||
let device = vb.device().clone();
|
||||
let rotary = Arc::new(Qwen3RotaryEmbedding::new(dtype, cfg, &device)?);
|
||||
|
||||
let embed_vb = vb.pp("model.embed_tokens");
|
||||
let embed_weight = load_replicated(&embed_vb, (cfg.vocab_size, cfg.hidden_size), "weight")?;
|
||||
let embed_tokens = Embedding::new(embed_weight, cfg.hidden_size);
|
||||
|
||||
let vb_l = vb.pp("model.layers");
|
||||
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||
for i in 0..cfg.num_hidden_layers {
|
||||
layers.push(TpDecoderLayer::load(
|
||||
cfg,
|
||||
rotary.clone(),
|
||||
&vb_l.pp(i),
|
||||
rank,
|
||||
world_size,
|
||||
)?);
|
||||
}
|
||||
let norm = load_rms_norm(&vb.pp("model.norm"), cfg.hidden_size, cfg.rms_norm_eps)?;
|
||||
|
||||
Ok(Self {
|
||||
embed_tokens,
|
||||
layers,
|
||||
norm,
|
||||
device,
|
||||
dtype,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn embed_weight(&self) -> &Tensor {
|
||||
self.embed_tokens.embeddings()
|
||||
}
|
||||
|
||||
pub fn clear_kv_cache(&mut self) {
|
||||
for l in &mut self.layers {
|
||||
l.clear_kv_cache();
|
||||
}
|
||||
}
|
||||
|
||||
fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> candle_core::Result<Tensor> {
|
||||
let minf = f32::NEG_INFINITY;
|
||||
let mask: Vec<_> = (0..tgt)
|
||||
.flat_map(|i| (0..(tgt + offset)).map(move |j| if j <= i + offset { 0. } else { minf }))
|
||||
.collect();
|
||||
Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype)
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
|
||||
let (b, l) = input.dims2()?;
|
||||
let mut h = self.embed_tokens.forward(input)?;
|
||||
|
||||
let causal = if l == 1 {
|
||||
None
|
||||
} else {
|
||||
Some(self.causal_mask(b, l, offset)?)
|
||||
};
|
||||
|
||||
for layer in &mut self.layers {
|
||||
h = layer.forward(&h, causal.as_ref(), offset)?;
|
||||
}
|
||||
self.norm.forward(&h)
|
||||
}
|
||||
}
|
||||
|
||||
/// TP Qwen3 with a (replicated) language-model head on top.
|
||||
pub struct TpQwen3ForCausalLM {
|
||||
base: TpQwen3Model,
|
||||
lm_head: Linear,
|
||||
}
|
||||
|
||||
impl TpQwen3ForCausalLM {
|
||||
#[cfg(feature = "cuda")]
|
||||
pub fn load(
|
||||
cfg: &Config,
|
||||
vb: &ShardedVarBuilder,
|
||||
rank: u32,
|
||||
world_size: u32,
|
||||
comm: Arc<Comm>,
|
||||
) -> Result<Self> {
|
||||
let base = TpQwen3Model::load(cfg, vb, rank, world_size, comm)?;
|
||||
let lm_head = build_lm_head(cfg, vb, &base)?;
|
||||
let model = Self { base, lm_head };
|
||||
log_construction_complete(cfg, rank, world_size, model.device());
|
||||
Ok(model)
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
pub fn load(cfg: &Config, vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
|
||||
let base = TpQwen3Model::load(cfg, vb, rank, world_size)?;
|
||||
let lm_head = build_lm_head(cfg, vb, &base)?;
|
||||
let model = Self { base, lm_head };
|
||||
log_construction_complete(cfg, rank, world_size, model.device());
|
||||
Ok(model)
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
|
||||
let (_, l) = input.dims2()?;
|
||||
let hidden = self.base.forward(input, offset)?;
|
||||
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
|
||||
}
|
||||
|
||||
pub fn clear_kv_cache(&mut self) {
|
||||
self.base.clear_kv_cache();
|
||||
}
|
||||
|
||||
pub fn device(&self) -> &Device {
|
||||
&self.base.device
|
||||
}
|
||||
|
||||
pub fn dtype(&self) -> DType {
|
||||
self.base.dtype
|
||||
}
|
||||
}
|
||||
|
||||
fn build_lm_head(cfg: &Config, vb: &ShardedVarBuilder, base: &TpQwen3Model) -> Result<Linear> {
|
||||
if cfg.tie_word_embeddings {
|
||||
Ok(Linear::new(base.embed_weight().clone(), None))
|
||||
} else {
|
||||
let weight = load_replicated(
|
||||
&vb.pp("lm_head"),
|
||||
(cfg.vocab_size, cfg.hidden_size),
|
||||
"weight",
|
||||
)?;
|
||||
Ok(Linear::new(weight, None))
|
||||
}
|
||||
}
|
||||
|
||||
/// VRAM accounting + config dump emitted at the end of
|
||||
/// `TpQwen3ForCausalLM::load`. Same intent as the Qwen3-Next variant
|
||||
/// in tp_qwen3_5.rs — surface the resolved hyperparameters and
|
||||
/// per-rank free VRAM in one line so an operator chasing an OOM or a
|
||||
/// numerical issue doesn't have to grep the per-layer load logs.
|
||||
#[cfg(feature = "cuda")]
|
||||
fn log_construction_complete(cfg: &Config, rank: u32, world_size: u32, device: &Device) {
|
||||
use candle_core::cuda::cudarc::driver::result;
|
||||
use candle_core::cuda_backend::WrapErr;
|
||||
let (free_mb, total_mb) = if let Device::Cuda(dev) = device {
|
||||
if dev.cuda_stream().context().bind_to_thread().w().is_ok() {
|
||||
match result::mem_get_info() {
|
||||
Ok((free, total)) => (free / (1024 * 1024), total / (1024 * 1024)),
|
||||
Err(_) => (0, 0),
|
||||
}
|
||||
} else {
|
||||
(0, 0)
|
||||
}
|
||||
} else {
|
||||
(0, 0)
|
||||
};
|
||||
// Per-rank KV cache cost at one token: K + V × bf16. Vanilla
|
||||
// Qwen3 is dense attention end-to-end, so every layer
|
||||
// contributes. Knowing per-token bytes lets the operator estimate
|
||||
// headroom for a given prompt length before hitting an edge.
|
||||
let per_rank_num_kv_heads = (cfg.num_key_value_heads / world_size as usize).max(1);
|
||||
let kv_bytes_per_token_per_layer = per_rank_num_kv_heads * cfg.head_dim * 2 * 2;
|
||||
let kv_bytes_per_token = kv_bytes_per_token_per_layer * cfg.num_hidden_layers;
|
||||
tracing::info!(
|
||||
target: "neuron::tp::load",
|
||||
rank,
|
||||
world_size,
|
||||
free_mb,
|
||||
total_mb,
|
||||
vocab_size = cfg.vocab_size,
|
||||
hidden_size = cfg.hidden_size,
|
||||
num_hidden_layers = cfg.num_hidden_layers,
|
||||
num_attention_heads = cfg.num_attention_heads,
|
||||
num_key_value_heads = cfg.num_key_value_heads,
|
||||
head_dim = cfg.head_dim,
|
||||
max_position_embeddings = cfg.max_position_embeddings,
|
||||
per_rank_num_kv_heads,
|
||||
kv_bytes_per_token,
|
||||
"Qwen3 model construction complete"
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
fn log_construction_complete(cfg: &Config, rank: u32, world_size: u32, _device: &Device) {
|
||||
let per_rank_num_kv_heads = (cfg.num_key_value_heads / world_size as usize).max(1);
|
||||
let kv_bytes_per_token_per_layer = per_rank_num_kv_heads * cfg.head_dim * 2 * 2;
|
||||
let kv_bytes_per_token = kv_bytes_per_token_per_layer * cfg.num_hidden_layers;
|
||||
tracing::info!(
|
||||
target: "neuron::tp::load",
|
||||
rank,
|
||||
world_size,
|
||||
vocab_size = cfg.vocab_size,
|
||||
hidden_size = cfg.hidden_size,
|
||||
num_hidden_layers = cfg.num_hidden_layers,
|
||||
num_attention_heads = cfg.num_attention_heads,
|
||||
num_key_value_heads = cfg.num_key_value_heads,
|
||||
head_dim = cfg.head_dim,
|
||||
max_position_embeddings = cfg.max_position_embeddings,
|
||||
per_rank_num_kv_heads,
|
||||
kv_bytes_per_token,
|
||||
"Qwen3 model construction complete"
|
||||
);
|
||||
}
|
||||
1207
crates/neuron/src/harness/tp/tp_qwen3_5.rs
Normal file
1207
crates/neuron/src/harness/tp/tp_qwen3_5.rs
Normal file
File diff suppressed because it is too large
Load Diff
502
crates/neuron/src/harness/tp/worker.rs
Normal file
502
crates/neuron/src/harness/tp/worker.rs
Normal file
@@ -0,0 +1,502 @@
|
||||
//! Entry point for `neuron --worker`.
|
||||
//!
|
||||
//! The worker reads one newline-delimited JSON `WorkerRequest` from
|
||||
//! stdin per loop iteration, dispatches synchronously, and writes
|
||||
//! exactly one `WorkerResponse` JSON line to stdout. tracing goes to
|
||||
//! stderr so it doesn't collide with the RPC stream.
|
||||
//!
|
||||
//! NCCL operations (`Init`, `NcclSanityCheck`) and model lifecycle ops
|
||||
//! (`LoadDenseShard`, `GenerateStep`, `ClearKvCache`, `UnloadModel`)
|
||||
//! are real when built with the `cuda` feature; without it they reply
|
||||
//! with `Error{kind="cuda_feature_not_enabled"}` so the leader can tell
|
||||
//! the difference between a misconfigured build and a genuine NCCL or
|
||||
//! model failure.
|
||||
|
||||
use anyhow::Result;
|
||||
use std::collections::HashMap;
|
||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||
|
||||
use super::nccl_state::NcclState;
|
||||
use super::rpc::{WorkerRequest, WorkerResponse};
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
use super::tp_qwen3::TpQwen3ForCausalLM;
|
||||
#[cfg(feature = "cuda")]
|
||||
use super::tp_qwen3_5::TpQwen3_5ForCausalLM;
|
||||
|
||||
/// Worker-side discriminator over the architectures we can load via
|
||||
/// `LoadDenseShard`. Mirrors `super::TpLeaderModel` on the leader
|
||||
/// side — the dispatch happens on the `model_type` extracted from the
|
||||
/// config JSON.
|
||||
#[cfg(feature = "cuda")]
|
||||
enum WorkerModel {
|
||||
Qwen3(TpQwen3ForCausalLM),
|
||||
Qwen3_5(TpQwen3_5ForCausalLM),
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
impl WorkerModel {
|
||||
fn forward(
|
||||
&mut self,
|
||||
input: &candle_core::Tensor,
|
||||
offset: usize,
|
||||
) -> candle_core::Result<candle_core::Tensor> {
|
||||
match self {
|
||||
WorkerModel::Qwen3(m) => m.forward(input, offset),
|
||||
WorkerModel::Qwen3_5(m) => m.forward(input, offset),
|
||||
}
|
||||
}
|
||||
|
||||
fn clear_kv_cache(&mut self) {
|
||||
match self {
|
||||
WorkerModel::Qwen3(m) => m.clear_kv_cache(),
|
||||
WorkerModel::Qwen3_5(m) => m.clear_kv_cache(),
|
||||
}
|
||||
}
|
||||
|
||||
fn device(&self) -> &candle_core::Device {
|
||||
match self {
|
||||
WorkerModel::Qwen3(m) => m.device(),
|
||||
WorkerModel::Qwen3_5(m) => m.device(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct WorkerConfig {
|
||||
pub rank: u32,
|
||||
pub world_size: u32,
|
||||
pub cuda_device: u32,
|
||||
}
|
||||
|
||||
/// Drive the worker RPC loop until `Shutdown` or EOF on stdin.
|
||||
pub async fn run(config: WorkerConfig) -> Result<()> {
|
||||
tracing::info!(
|
||||
rank = config.rank,
|
||||
world_size = config.world_size,
|
||||
cuda_device = config.cuda_device,
|
||||
"tp worker starting"
|
||||
);
|
||||
|
||||
let mut state = WorkerState::new(config);
|
||||
let stdin = tokio::io::stdin();
|
||||
let mut reader = BufReader::new(stdin).lines();
|
||||
let mut stdout = tokio::io::stdout();
|
||||
|
||||
while let Some(line) = reader.next_line().await? {
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
let req: WorkerRequest = match serde_json::from_str(&line) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
let resp = WorkerResponse::Error {
|
||||
kind: "bad_request".into(),
|
||||
message: format!("parse {line:?}: {e}"),
|
||||
};
|
||||
write_response(&mut stdout, &resp).await?;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let resp = state.handle(req).await;
|
||||
let is_bye = matches!(resp, WorkerResponse::Bye);
|
||||
write_response(&mut stdout, &resp).await?;
|
||||
if is_bye {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!(rank = config.rank, "tp worker exiting");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn write_response(stdout: &mut tokio::io::Stdout, resp: &WorkerResponse) -> Result<()> {
|
||||
let mut line = serde_json::to_string(resp)?;
|
||||
line.push('\n');
|
||||
stdout.write_all(line.as_bytes()).await?;
|
||||
stdout.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// One rank's local state. Owns the rank's NCCL communicator (via
|
||||
/// `NcclState`) and the rank's shard of every loaded model.
|
||||
struct WorkerState {
|
||||
config: WorkerConfig,
|
||||
nccl: NcclState,
|
||||
/// Loaded model shards keyed by `model_id`. Each entry wraps the
|
||||
/// rank's TP architecture handle (Qwen3 or Qwen3-Next) — the
|
||||
/// column/row-parallel layers hold an `Arc<Comm>` cloned from
|
||||
/// `nccl`. Cuda-only: the underlying types reference cudarc types
|
||||
/// that don't exist without the cuda feature.
|
||||
#[cfg(feature = "cuda")]
|
||||
models: HashMap<String, WorkerModel>,
|
||||
/// Placeholder so the non-cuda build keeps the same field name set
|
||||
/// and `WorkerState::new` reads the same on both.
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
#[allow(dead_code)]
|
||||
models: HashMap<String, ()>,
|
||||
}
|
||||
|
||||
impl WorkerState {
|
||||
fn new(config: WorkerConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
nccl: NcclState::new(),
|
||||
models: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle(&mut self, req: WorkerRequest) -> WorkerResponse {
|
||||
match req {
|
||||
WorkerRequest::Ping => WorkerResponse::Pong {
|
||||
rank: self.config.rank,
|
||||
world_size: self.config.world_size,
|
||||
cuda_device: self.config.cuda_device,
|
||||
},
|
||||
WorkerRequest::Init { comm_id } => self.nccl.init(self.config, &comm_id),
|
||||
WorkerRequest::NcclSanityCheck => self.nccl.sanity_check(),
|
||||
WorkerRequest::LoadDenseShard {
|
||||
model_id,
|
||||
config_json,
|
||||
safetensors_paths,
|
||||
quant,
|
||||
} => self.handle_load_dense_shard(model_id, config_json, safetensors_paths, quant),
|
||||
WorkerRequest::GenerateStep {
|
||||
model_id,
|
||||
tokens,
|
||||
offset,
|
||||
} => self.handle_generate_step(&model_id, tokens, offset),
|
||||
WorkerRequest::ClearKvCache { model_id } => self.handle_clear_kv_cache(&model_id),
|
||||
WorkerRequest::UnloadModel { model_id } => self.handle_unload_model(&model_id),
|
||||
WorkerRequest::Shutdown => WorkerResponse::Bye,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn handle_load_dense_shard(
|
||||
&mut self,
|
||||
model_id: String,
|
||||
config_json: String,
|
||||
safetensors_paths: Vec<String>,
|
||||
quant: Option<String>,
|
||||
) -> WorkerResponse {
|
||||
use crate::harness::arch::qwen3_5 as qwen3_5_arch;
|
||||
use candle_core::{DType, Device};
|
||||
use candle_nn::var_builder::ShardedSafeTensors;
|
||||
use candle_transformers::models::qwen3 as qwen3_dense;
|
||||
use std::path::PathBuf;
|
||||
|
||||
let quant_dtype = match parse_quant_string(quant.as_deref()) {
|
||||
Ok(q) => q,
|
||||
Err(e) => {
|
||||
return WorkerResponse::Error {
|
||||
kind: "bad_request".into(),
|
||||
message: format!("parse quant: {e}"),
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
if self.models.contains_key(&model_id) {
|
||||
return WorkerResponse::Error {
|
||||
kind: "already_loaded".into(),
|
||||
message: format!("model '{model_id}' already loaded on this rank"),
|
||||
};
|
||||
}
|
||||
let comm = match self.nccl.comm() {
|
||||
Some(c) => c,
|
||||
None => {
|
||||
return WorkerResponse::Error {
|
||||
kind: "nccl_not_initialised".into(),
|
||||
message: "LoadDenseShard requires Init to have completed first".into(),
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
// Peek at model_type so we know which architecture to build.
|
||||
let model_type = serde_json::from_str::<serde_json::Value>(&config_json)
|
||||
.ok()
|
||||
.as_ref()
|
||||
.and_then(|v| v.get("model_type"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
let device = match Device::new_cuda(self.config.cuda_device as usize) {
|
||||
Ok(d) => d,
|
||||
Err(e) => {
|
||||
return WorkerResponse::Error {
|
||||
kind: "cuda_unavailable".into(),
|
||||
message: format!("Device::new_cuda({}) failed: {e}", self.config.cuda_device),
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
let paths: Vec<PathBuf> = safetensors_paths.into_iter().map(PathBuf::from).collect();
|
||||
// SAFETY: same invariant as the single-GPU dense path — the HF
|
||||
// cache files are treated as immutable while the mmap is held.
|
||||
let vb = match unsafe { ShardedSafeTensors::var_builder(&paths, DType::BF16, &device) } {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
return WorkerResponse::Error {
|
||||
kind: "load_failed".into(),
|
||||
message: format!("ShardedSafeTensors::var_builder: {e}"),
|
||||
};
|
||||
}
|
||||
};
|
||||
// Separate mmap of the same paths for the direct fused-region
|
||||
// loader in `fused_load`. Linux's page cache shares the
|
||||
// underlying pages between the two mmaps; the cost is one
|
||||
// extra set of safetensors-header parses.
|
||||
let mmap = match unsafe { candle_core::safetensors::MmapedSafetensors::multi(&paths) } {
|
||||
Ok(m) => m,
|
||||
Err(e) => {
|
||||
return WorkerResponse::Error {
|
||||
kind: "load_failed".into(),
|
||||
message: format!("MmapedSafetensors::multi: {e}"),
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
let loaded = match model_type.as_str() {
|
||||
"qwen3" => {
|
||||
let cfg: qwen3_dense::Config = match serde_json::from_str(&config_json) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return WorkerResponse::Error {
|
||||
kind: "bad_request".into(),
|
||||
message: format!("parse Qwen3 Config JSON: {e}"),
|
||||
};
|
||||
}
|
||||
};
|
||||
match TpQwen3ForCausalLM::load(
|
||||
&cfg,
|
||||
&vb,
|
||||
self.config.rank,
|
||||
self.config.world_size,
|
||||
comm,
|
||||
) {
|
||||
Ok(m) => WorkerModel::Qwen3(m),
|
||||
Err(e) => {
|
||||
return WorkerResponse::Error {
|
||||
kind: "load_failed".into(),
|
||||
message: format!("TpQwen3ForCausalLM::load: {e:#}"),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
"qwen3_5" => {
|
||||
let cfg: qwen3_5_arch::Config = match serde_json::from_str(&config_json) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return WorkerResponse::Error {
|
||||
kind: "bad_request".into(),
|
||||
message: format!("parse Qwen3-Next Config JSON: {e}"),
|
||||
};
|
||||
}
|
||||
};
|
||||
match TpQwen3_5ForCausalLM::load(
|
||||
cfg,
|
||||
&vb,
|
||||
&mmap,
|
||||
self.config.rank,
|
||||
self.config.world_size,
|
||||
comm,
|
||||
quant_dtype,
|
||||
) {
|
||||
Ok(m) => WorkerModel::Qwen3_5(m),
|
||||
Err(e) => {
|
||||
return WorkerResponse::Error {
|
||||
kind: "load_failed".into(),
|
||||
message: format!("TpQwen3_5ForCausalLM::load: {e:#}"),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
other => {
|
||||
return WorkerResponse::Error {
|
||||
kind: "unsupported_arch".into(),
|
||||
message: format!(
|
||||
"worker: unsupported model_type '{other}' (supported: qwen3, qwen3_5)"
|
||||
),
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
self.models.insert(model_id.clone(), loaded);
|
||||
tracing::info!(
|
||||
rank = self.config.rank,
|
||||
model = %model_id,
|
||||
model_type = %model_type,
|
||||
"loaded TP shard"
|
||||
);
|
||||
WorkerResponse::LoadDenseShardOk
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
fn handle_load_dense_shard(
|
||||
&mut self,
|
||||
_model_id: String,
|
||||
_config_json: String,
|
||||
_safetensors_paths: Vec<String>,
|
||||
_quant: Option<String>,
|
||||
) -> WorkerResponse {
|
||||
WorkerResponse::Error {
|
||||
kind: "cuda_feature_not_enabled".into(),
|
||||
message: "LoadDenseShard requires --features cuda".into(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn handle_generate_step(
|
||||
&mut self,
|
||||
model_id: &str,
|
||||
tokens: Vec<u32>,
|
||||
offset: usize,
|
||||
) -> WorkerResponse {
|
||||
use candle_core::Tensor;
|
||||
|
||||
let Some(model) = self.models.get_mut(model_id) else {
|
||||
return WorkerResponse::Error {
|
||||
kind: "model_not_loaded".into(),
|
||||
message: format!("model '{model_id}' not loaded on rank {}", self.config.rank),
|
||||
};
|
||||
};
|
||||
let device = model.device().clone();
|
||||
let input = match Tensor::new(tokens.as_slice(), &device).and_then(|t| t.unsqueeze(0)) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
return WorkerResponse::Error {
|
||||
kind: "forward_failed".into(),
|
||||
message: format!("build input tensor: {e}"),
|
||||
};
|
||||
}
|
||||
};
|
||||
let start = std::time::Instant::now();
|
||||
tracing::debug!(
|
||||
rank = self.config.rank,
|
||||
model = %model_id,
|
||||
tokens = tokens.len(),
|
||||
offset,
|
||||
"worker GenerateStep: forward starting"
|
||||
);
|
||||
// Drop the resulting logits — the leader uses its own copy from
|
||||
// rank 0. The forward's value here is the NCCL collectives it
|
||||
// issues, which let the leader's rank-0 forward make progress.
|
||||
if let Err(e) = model.forward(&input, offset) {
|
||||
tracing::warn!(
|
||||
rank = self.config.rank,
|
||||
model = %model_id,
|
||||
elapsed_ms = start.elapsed().as_millis(),
|
||||
error = %e,
|
||||
"worker GenerateStep: forward failed"
|
||||
);
|
||||
return WorkerResponse::Error {
|
||||
kind: "forward_failed".into(),
|
||||
message: format!("TP forward: {e}"),
|
||||
};
|
||||
}
|
||||
tracing::debug!(
|
||||
rank = self.config.rank,
|
||||
model = %model_id,
|
||||
elapsed_ms = start.elapsed().as_millis(),
|
||||
"worker GenerateStep: forward done"
|
||||
);
|
||||
WorkerResponse::GenerateStepOk
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
fn handle_generate_step(
|
||||
&mut self,
|
||||
_model_id: &str,
|
||||
_tokens: Vec<u32>,
|
||||
_offset: usize,
|
||||
) -> WorkerResponse {
|
||||
WorkerResponse::Error {
|
||||
kind: "cuda_feature_not_enabled".into(),
|
||||
message: "GenerateStep requires --features cuda".into(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn handle_clear_kv_cache(&mut self, model_id: &str) -> WorkerResponse {
|
||||
let Some(model) = self.models.get_mut(model_id) else {
|
||||
return WorkerResponse::Error {
|
||||
kind: "model_not_loaded".into(),
|
||||
message: format!("model '{model_id}' not loaded on rank {}", self.config.rank),
|
||||
};
|
||||
};
|
||||
model.clear_kv_cache();
|
||||
WorkerResponse::KvCacheCleared
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
fn handle_clear_kv_cache(&mut self, _model_id: &str) -> WorkerResponse {
|
||||
WorkerResponse::Error {
|
||||
kind: "cuda_feature_not_enabled".into(),
|
||||
message: "ClearKvCache requires --features cuda".into(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn handle_unload_model(&mut self, model_id: &str) -> WorkerResponse {
|
||||
if self.models.remove(model_id).is_none() {
|
||||
return WorkerResponse::Error {
|
||||
kind: "model_not_loaded".into(),
|
||||
message: format!("model '{model_id}' not loaded on rank {}", self.config.rank),
|
||||
};
|
||||
}
|
||||
tracing::info!(rank = self.config.rank, model = %model_id, "unloaded TP shard");
|
||||
WorkerResponse::Unloaded
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
fn handle_unload_model(&mut self, _model_id: &str) -> WorkerResponse {
|
||||
WorkerResponse::Error {
|
||||
kind: "cuda_feature_not_enabled".into(),
|
||||
message: "UnloadModel requires --features cuda".into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a `ModelSpec.quant` string into a `GgmlDType`. Accepts the
|
||||
/// common ggml format names (case-insensitive). `None` and `Some("")`
|
||||
/// both map to "no quantization".
|
||||
///
|
||||
/// Supported: `q4_0`, `q4_1`, `q5_0`, `q5_1`, `q8_0`, `q8_1`,
|
||||
/// `q2k`/`q2_k`, `q3k`/`q3_k`, `q4k`/`q4_k`, `q5k`/`q5_k`,
|
||||
/// `q6k`/`q6_k`, `q8k`/`q8_k`, `f16`, `bf16`, `f32`. The underscore
|
||||
/// is optional and the prefix is case-insensitive.
|
||||
#[cfg(feature = "cuda")]
|
||||
pub(crate) fn parse_quant_string(
|
||||
s: Option<&str>,
|
||||
) -> anyhow::Result<Option<candle_core::quantized::GgmlDType>> {
|
||||
use candle_core::quantized::GgmlDType;
|
||||
let s = match s {
|
||||
Some(s) if !s.is_empty() => s,
|
||||
_ => return Ok(None),
|
||||
};
|
||||
let normalised = s.to_ascii_lowercase().replace('_', "");
|
||||
let dtype = match normalised.as_str() {
|
||||
"q40" => GgmlDType::Q4_0,
|
||||
"q41" => GgmlDType::Q4_1,
|
||||
"q50" => GgmlDType::Q5_0,
|
||||
"q51" => GgmlDType::Q5_1,
|
||||
"q80" => GgmlDType::Q8_0,
|
||||
"q81" => GgmlDType::Q8_1,
|
||||
"q2k" => GgmlDType::Q2K,
|
||||
"q3k" => GgmlDType::Q3K,
|
||||
"q4k" | "q4km" => GgmlDType::Q4K,
|
||||
"q5k" | "q5km" => GgmlDType::Q5K,
|
||||
"q6k" => GgmlDType::Q6K,
|
||||
"q8k" => GgmlDType::Q8K,
|
||||
"f16" => GgmlDType::F16,
|
||||
"bf16" => GgmlDType::BF16,
|
||||
"f32" => GgmlDType::F32,
|
||||
other => anyhow::bail!(
|
||||
"unknown quant '{other}' (expected one of: q4_0, q4_1, q5_0, q5_1, q8_0, \
|
||||
q8_1, q2k, q3k, q4k, q5k, q6k, q8k, f16, bf16, f32)"
|
||||
),
|
||||
};
|
||||
Ok(Some(dtype))
|
||||
}
|
||||
@@ -24,6 +24,12 @@ impl HealthCache {
|
||||
inner: RwLock::new(HealthResponse {
|
||||
uptime_secs: 0,
|
||||
devices: vec![],
|
||||
// The cache only owns the device-state half of /health;
|
||||
// the api handler overlays activation from the tracker.
|
||||
// Initialise with the default (Ready, empty lists) so a
|
||||
// direct read from the cache stays a well-typed
|
||||
// HealthResponse on the wire.
|
||||
activation: Default::default(),
|
||||
}),
|
||||
has_gpus: RwLock::new(false),
|
||||
}
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
pub mod activation;
|
||||
pub mod api;
|
||||
pub mod config;
|
||||
pub mod cuda;
|
||||
pub mod discovery;
|
||||
pub mod harness;
|
||||
pub mod health;
|
||||
pub mod startup;
|
||||
pub mod wire;
|
||||
|
||||
@@ -1,21 +1,66 @@
|
||||
use anyhow::Result;
|
||||
use anyhow::{Context, Result};
|
||||
use clap::Parser;
|
||||
use neuron::{api, config::NeuronConfig, discovery, harness::HarnessRegistry, health};
|
||||
use neuron::{
|
||||
activation, api,
|
||||
config::NeuronConfig,
|
||||
discovery,
|
||||
harness::{HarnessRegistry, tp},
|
||||
health, startup,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
/// Top-level CLI. The same binary runs as either the public neuron
|
||||
/// daemon (default), a tensor-parallel worker subprocess (when
|
||||
/// `--worker` is set, spawned by the leader on the same host), or a
|
||||
/// one-shot TP NCCL handshake check (when `--tp-smoke` is set).
|
||||
#[derive(Parser)]
|
||||
#[command(name = "neuron")]
|
||||
#[command(about = "Per-node daemon for cortex inference clusters")]
|
||||
#[command(version)]
|
||||
struct Args {
|
||||
/// Port to listen on (overrides config file).
|
||||
/// Run in tensor-parallel worker mode. The leader process spawns
|
||||
/// one of these per non-zero NCCL rank and drives it over
|
||||
/// newline-delimited JSON on stdin/stdout. Worker mode skips
|
||||
/// discovery, the HTTP listener, and the health poller — it's a
|
||||
/// pure RPC loop.
|
||||
#[arg(long, default_value_t = false)]
|
||||
worker: bool,
|
||||
|
||||
/// Run a one-shot TP smoke test: spawn `--tp-size - 1` worker
|
||||
/// subprocesses on `--cuda-devices`, build the NCCL communicator,
|
||||
/// run an `AllReduce` sanity check across every rank, and exit.
|
||||
/// Used to validate the TP plumbing in isolation from model load
|
||||
/// and inference. Diagnostic-only — not exposed through the daemon
|
||||
/// HTTP API.
|
||||
#[arg(long, default_value_t = false)]
|
||||
tp_smoke: bool,
|
||||
|
||||
/// NCCL rank for worker mode. Ignored when `--worker` is not set.
|
||||
#[arg(long, default_value_t = 0)]
|
||||
rank: u32,
|
||||
|
||||
/// Total NCCL world size for worker mode or TP smoke mode.
|
||||
#[arg(long, default_value_t = 1)]
|
||||
tp_size: u32,
|
||||
|
||||
/// CUDA device index for worker mode. Ignored when `--worker` is
|
||||
/// not set.
|
||||
#[arg(long, default_value_t = 0)]
|
||||
cuda_device: u32,
|
||||
|
||||
/// Comma-separated CUDA device indices for TP smoke mode (one per
|
||||
/// rank, starting with rank 0). Must have `tp_size` entries.
|
||||
#[arg(long, value_delimiter = ',')]
|
||||
cuda_devices: Vec<u32>,
|
||||
|
||||
/// Port to listen on (overrides config file). Daemon mode only.
|
||||
#[arg(short, long)]
|
||||
port: Option<u16>,
|
||||
|
||||
/// Path to the neuron config file.
|
||||
/// Path to the neuron config file. Daemon mode only.
|
||||
#[arg(short, long, default_value = "neuron.toml")]
|
||||
config: String,
|
||||
}
|
||||
@@ -23,20 +68,99 @@ struct Args {
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_writer(std::io::stderr)
|
||||
.with_env_filter(
|
||||
EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| EnvFilter::new("info,neuron=debug")),
|
||||
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")),
|
||||
)
|
||||
.init();
|
||||
|
||||
let args = Args::parse();
|
||||
|
||||
if args.worker {
|
||||
return tp::worker::run(tp::worker::WorkerConfig {
|
||||
rank: args.rank,
|
||||
world_size: args.tp_size,
|
||||
cuda_device: args.cuda_device,
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
if args.tp_smoke {
|
||||
return tp_smoke(args.tp_size, args.cuda_devices).await;
|
||||
}
|
||||
|
||||
daemon(args).await
|
||||
}
|
||||
|
||||
/// One-shot tensor-parallel handshake. Spawns N-1 worker subprocesses
|
||||
/// (rank 0 stays in this process), builds the NCCL communicator across
|
||||
/// the full world, runs an AllReduce sanity check, and shuts everyone
|
||||
/// down. Output is plain log lines on stderr + a final summary on
|
||||
/// stdout in `key=value` form so an outer script can parse it.
|
||||
async fn tp_smoke(tp_size: u32, cuda_devices: Vec<u32>) -> Result<()> {
|
||||
if tp_size < 2 {
|
||||
anyhow::bail!("--tp-size must be at least 2 (got {tp_size})");
|
||||
}
|
||||
if cuda_devices.len() as u32 != tp_size {
|
||||
anyhow::bail!(
|
||||
"--cuda-devices must list exactly {tp_size} entries (got {})",
|
||||
cuda_devices.len()
|
||||
);
|
||||
}
|
||||
|
||||
let exe = std::env::current_exe().context("resolve current_exe for worker spawn")?;
|
||||
let leader_device = cuda_devices[0];
|
||||
|
||||
tracing::info!(
|
||||
tp_size,
|
||||
?cuda_devices,
|
||||
binary = %exe.display(),
|
||||
"tp-smoke: spawning worker pool"
|
||||
);
|
||||
// tp_smoke is a diagnostic tool; spawn the leader's device worker
|
||||
// directly. (In the daemon path, CandleHarness::ensure_device_worker
|
||||
// caches one per device.)
|
||||
let leader_worker = neuron::harness::device_worker::DeviceWorkerHandle::spawn(leader_device)
|
||||
.context("spawn leader device worker for tp-smoke")?;
|
||||
let mut pool =
|
||||
tp::WorkerPool::spawn(&exe, tp_size, &cuda_devices, leader_worker.clone()).await?;
|
||||
|
||||
tracing::info!("tp-smoke: pinging every worker");
|
||||
let pongs = pool.ping_all().await?;
|
||||
for p in &pongs {
|
||||
tracing::info!(?p, "tp-smoke: pong");
|
||||
}
|
||||
|
||||
tracing::info!(leader_device, "tp-smoke: initialising NCCL");
|
||||
pool.init_nccl(leader_device).await?;
|
||||
|
||||
tracing::info!("tp-smoke: running AllReduce sanity check");
|
||||
pool.nccl_sanity_check().await?;
|
||||
|
||||
tracing::info!("tp-smoke: shutting down pool");
|
||||
pool.shutdown().await?;
|
||||
|
||||
println!("status=ok");
|
||||
println!("tp_size={tp_size}");
|
||||
println!(
|
||||
"cuda_devices={}",
|
||||
cuda_devices
|
||||
.iter()
|
||||
.map(|d| d.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn daemon(args: Args) -> Result<()> {
|
||||
let cfg = NeuronConfig::load(&args.config).unwrap_or_else(|e| {
|
||||
tracing::warn!(path = %args.config, error = %e, "config not found, using defaults");
|
||||
NeuronConfig::default()
|
||||
});
|
||||
|
||||
let port = args.port.unwrap_or(cfg.port);
|
||||
let bind_url = format!("http://localhost:{port}");
|
||||
let start_time = Instant::now();
|
||||
|
||||
tracing::info!("running hardware discovery");
|
||||
@@ -47,9 +171,12 @@ async fn main() -> Result<()> {
|
||||
"discovery complete"
|
||||
);
|
||||
|
||||
// Build harness registry from config.
|
||||
let registry = HarnessRegistry::from_configs(&cfg.harnesses);
|
||||
// Build harness registry from config. In-process harnesses (candle)
|
||||
// need to know neuron's own bind URL so they can return it from
|
||||
// inference_endpoint.
|
||||
let registry = HarnessRegistry::from_configs(&cfg.harnesses, &bind_url, &cfg.harness);
|
||||
discovery_result.harnesses = registry.names();
|
||||
let candle = registry.candle();
|
||||
|
||||
let health_cache = Arc::new(health::HealthCache::new());
|
||||
health_cache
|
||||
@@ -61,17 +188,64 @@ async fn main() -> Result<()> {
|
||||
poller_cache.poll_loop(start_time).await;
|
||||
});
|
||||
|
||||
// Track pre-warm progress so `/health` can tell callers whether
|
||||
// configured default_models are still loading. Primed with the
|
||||
// pending list now; the spawned task below flips entries through
|
||||
// in_progress → completed/failed and finally toggles state=ready.
|
||||
let activation = Arc::new(activation::ActivationTracker::new(&cfg.default_models));
|
||||
|
||||
let state = Arc::new(api::NeuronState {
|
||||
discovery: discovery_result,
|
||||
health_cache,
|
||||
registry: RwLock::new(registry),
|
||||
candle,
|
||||
activation: Arc::clone(&activation),
|
||||
});
|
||||
|
||||
let app = api::neuron_routes().with_state(state);
|
||||
// Bind the HTTP listener BEFORE kicking off default_models loading.
|
||||
// Previously load_default_models ran synchronously on this task,
|
||||
// which delayed the bind by minutes for big TP models and made the
|
||||
// host look down to anything probing `/health` during pre-warm.
|
||||
// The pre-warm task runs in the background instead — `/health`
|
||||
// surfaces its progress via the activation field.
|
||||
let app = api::neuron_routes().with_state(Arc::clone(&state));
|
||||
let addr: std::net::SocketAddr = format!("0.0.0.0:{port}").parse()?;
|
||||
tracing::info!("neuron listening on {addr}");
|
||||
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||
axum::serve(listener, app).await?;
|
||||
tracing::info!("neuron listening on {addr}");
|
||||
|
||||
Ok(())
|
||||
if !cfg.default_models.is_empty() {
|
||||
let state_for_prewarm = Arc::clone(&state);
|
||||
let default_models = cfg.default_models.clone();
|
||||
tokio::spawn(async move {
|
||||
// Read lock held for the whole pre-warm run. The unload
|
||||
// path takes the same read lock per call (no writers) and
|
||||
// serialises through the candle harness's own internal
|
||||
// mutex, so concurrent on-demand loads and pre-warm loads
|
||||
// do not race on the same model.
|
||||
let registry = state_for_prewarm.registry.read().await;
|
||||
startup::load_default_models(®istry, &default_models, &state_for_prewarm.activation)
|
||||
.await;
|
||||
});
|
||||
}
|
||||
|
||||
axum::serve(listener, app)
|
||||
.with_graceful_shutdown(startup::shutdown_signal())
|
||||
.await?;
|
||||
|
||||
// Deactivation: serve has returned (graceful shutdown signal
|
||||
// received and connections drained). Release CUDA contexts / VRAM
|
||||
// by unloading every model before exiting; systemd's TimeoutStopSec
|
||||
// bounds how long this phase may take.
|
||||
let registry = state.registry.read().await;
|
||||
startup::unload_all_models(®istry).await;
|
||||
tracing::info!("shutdown complete");
|
||||
// Fast-exit instead of returning. Returning lets `#[tokio::main]`
|
||||
// drop the runtime, which in turn waits on the blocking thread
|
||||
// pool to drain. After a CUDA driver error (OOM → illegal address)
|
||||
// a spawn_blocking thread can be wedged inside `cuCtxGetCurrent`,
|
||||
// and tokio's drain has no timeout. systemd then SIGABRTs us and
|
||||
// dumps core. Skipping the drain hands the OS a clean exit code;
|
||||
// the OS reaps the stuck threads. See the 2026-05-26 incident
|
||||
// captured under "Stack trace of thread 2951308" in the journal.
|
||||
std::process::exit(0);
|
||||
}
|
||||
|
||||
176
crates/neuron/src/startup.rs
Normal file
176
crates/neuron/src/startup.rs
Normal file
@@ -0,0 +1,176 @@
|
||||
//! Activation- and deactivation-time orchestration.
|
||||
//!
|
||||
//! Wired from `main.rs` around the HTTP listener — activation runs
|
||||
//! before bind, deactivation runs after axum returns from its
|
||||
//! graceful-shutdown future. Kept in its own module so the logic is
|
||||
//! unit-testable without spinning up a full neuron process.
|
||||
|
||||
use crate::activation::ActivationTracker;
|
||||
use crate::harness::HarnessRegistry;
|
||||
use crate::harness::preflight::PreflightError;
|
||||
use cortex_core::harness::ModelSpec;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::signal;
|
||||
|
||||
/// Maximum time we wait on a single `unload_model` call during
|
||||
/// shutdown. The TP unload path tries `Arc::try_unwrap`, which fails
|
||||
/// fast when an inference is in flight, so a healthy unload returns
|
||||
/// in milliseconds. The timeout exists to bound a *future* unload
|
||||
/// path that might genuinely block on a stuck worker, so a single
|
||||
/// wedged model can't burn the whole systemd TimeoutStopSec window.
|
||||
const UNLOAD_TIMEOUT: Duration = Duration::from_secs(20);
|
||||
|
||||
/// Load each spec sequentially against the registry, treating
|
||||
/// individual failures as warnings rather than fatal errors.
|
||||
///
|
||||
/// VRAM contention makes parallel loads risky; the sequential path is
|
||||
/// boring but correct. The function logs elapsed time per load and
|
||||
/// updates `activation` so the `/health` endpoint can tell callers
|
||||
/// which models are still pre-warming. Caller is expected to run this
|
||||
/// in a background `tokio::spawn` task — the HTTP listener binds
|
||||
/// independently so the host is reachable during the pre-warm window.
|
||||
pub async fn load_default_models(
|
||||
registry: &HarnessRegistry,
|
||||
specs: &[ModelSpec],
|
||||
activation: &ActivationTracker,
|
||||
) {
|
||||
if specs.is_empty() {
|
||||
activation.mark_ready().await;
|
||||
return;
|
||||
}
|
||||
tracing::info!(count = specs.len(), "loading default models");
|
||||
for spec in specs {
|
||||
let start = Instant::now();
|
||||
activation.start_loading(&spec.model_id).await;
|
||||
match registry.load_model(spec).await {
|
||||
Ok(()) => {
|
||||
activation.complete_loading(&spec.model_id).await;
|
||||
tracing::info!(
|
||||
model = %spec.model_id,
|
||||
elapsed_ms = start.elapsed().as_millis() as u64,
|
||||
"loaded default model"
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
let rendered = format!("{e:#}");
|
||||
activation.fail_loading(&spec.model_id, &rendered).await;
|
||||
// When the underlying failure is a preflight rejection,
|
||||
// pull the structured fields out so journalctl shows
|
||||
// `reason=tp_requires_safetensors detail="..."` instead
|
||||
// of an opaque "fetch config.json … 404". The operator
|
||||
// can act on the structured form directly.
|
||||
if let Some(pf) = e.downcast_ref::<PreflightError>() {
|
||||
tracing::warn!(
|
||||
model = %spec.model_id,
|
||||
reason = preflight_kind(pf),
|
||||
detail = %pf,
|
||||
elapsed_ms = start.elapsed().as_millis() as u64,
|
||||
"failed to load default model, continuing"
|
||||
);
|
||||
} else {
|
||||
tracing::warn!(
|
||||
model = %spec.model_id,
|
||||
error = %rendered,
|
||||
elapsed_ms = start.elapsed().as_millis() as u64,
|
||||
"failed to load default model, continuing"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
activation.mark_ready().await;
|
||||
}
|
||||
|
||||
/// Short kebab-case tag for a preflight failure. Used as a structured
|
||||
/// log field so journalctl filtering can match on the failure class
|
||||
/// (`reason=tp_requires_safetensors`, `reason=quant_not_found`, etc.).
|
||||
fn preflight_kind(err: &PreflightError) -> &'static str {
|
||||
match err {
|
||||
PreflightError::RepoFetchFailed { .. } => "repo_fetch_failed",
|
||||
PreflightError::EmptyRepo { .. } => "empty_repo",
|
||||
PreflightError::TpRequiresSafetensors { .. } => "tp_requires_safetensors",
|
||||
PreflightError::QuantNotFound { .. } => "quant_not_found",
|
||||
}
|
||||
}
|
||||
|
||||
/// Future that resolves on SIGINT (Ctrl-C) or SIGTERM (systemd stop).
|
||||
///
|
||||
/// Wired into `axum::serve(...).with_graceful_shutdown(shutdown_signal())`
|
||||
/// so the HTTP listener stops accepting new connections, lets in-flight
|
||||
/// requests drain, and then yields control back to main for cleanup.
|
||||
pub async fn shutdown_signal() {
|
||||
let ctrl_c = async {
|
||||
signal::ctrl_c().await.ok();
|
||||
};
|
||||
let terminate = async {
|
||||
signal::unix::signal(signal::unix::SignalKind::terminate())
|
||||
.expect("install SIGTERM handler")
|
||||
.recv()
|
||||
.await;
|
||||
};
|
||||
tokio::select! {
|
||||
_ = ctrl_c => tracing::info!("received SIGINT, shutting down"),
|
||||
_ = terminate => tracing::info!("received SIGTERM, shutting down"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Unload every model currently registered. Called from `main.rs` after
|
||||
/// axum's graceful shutdown future resolves, so CUDA contexts and VRAM
|
||||
/// are released before the process exits rather than left to the OS to
|
||||
/// reclaim. Per-model failures are logged and skipped — keep cleanup
|
||||
/// going even when one harness is unhealthy.
|
||||
pub async fn unload_all_models(registry: &HarnessRegistry) {
|
||||
let listed = match registry.list_all_models().await {
|
||||
Ok(m) => m,
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "failed to list models during shutdown");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if listed.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
tracing::info!(count = listed.len(), "unloading models for shutdown");
|
||||
let mut stuck = 0;
|
||||
for model in listed {
|
||||
let start = Instant::now();
|
||||
match tokio::time::timeout(UNLOAD_TIMEOUT, registry.unload_model(&model.id)).await {
|
||||
Ok(Ok(())) => tracing::info!(
|
||||
model = %model.id,
|
||||
elapsed_ms = start.elapsed().as_millis() as u64,
|
||||
"unloaded"
|
||||
),
|
||||
// Most common shape today: TP unload bails because an
|
||||
// inference is still mid-flight (the spawned task holds
|
||||
// an `Arc<TpLoadedModel>` clone). Promoted from warn to
|
||||
// error and tagged with the request-state so the operator
|
||||
// can correlate with the chat_completion logs above.
|
||||
Ok(Err(e)) => {
|
||||
stuck += 1;
|
||||
tracing::error!(
|
||||
model = %model.id,
|
||||
error = %e,
|
||||
elapsed_ms = start.elapsed().as_millis() as u64,
|
||||
"unload failed during shutdown"
|
||||
);
|
||||
}
|
||||
Err(_) => {
|
||||
stuck += 1;
|
||||
tracing::error!(
|
||||
model = %model.id,
|
||||
timeout_secs = UNLOAD_TIMEOUT.as_secs(),
|
||||
"unload timed out during shutdown, continuing"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
if stuck > 0 {
|
||||
tracing::error!(
|
||||
stuck,
|
||||
"shutdown leaving {stuck} model(s) loaded; VRAM will be \
|
||||
reclaimed by the OS on process exit"
|
||||
);
|
||||
}
|
||||
}
|
||||
306
crates/neuron/src/wire/event.rs
Normal file
306
crates/neuron/src/wire/event.rs
Normal file
@@ -0,0 +1,306 @@
|
||||
//! Format-agnostic inference event stream.
|
||||
//!
|
||||
//! The candle harness emits a sequence of these for every streaming
|
||||
//! request. Wire-format projections in sibling modules
|
||||
//! ([`super::openai_chat`], the eventual `openai_responses` /
|
||||
//! `anthropic_messages` projections) read this stream and produce
|
||||
//! the chunks / events their HTTP clients expect.
|
||||
//!
|
||||
//! Design notes:
|
||||
//!
|
||||
//! - [`Start`] carries no token of its own. It only signals "the
|
||||
//! model has accepted the prompt and is about to begin emitting
|
||||
//! text". OpenAI chat materialises this as a `role: assistant`
|
||||
//! chunk; OpenAI Responses as the `response.created` +
|
||||
//! `response.output_item.added` pair; Anthropic as
|
||||
//! `message_start`. All three of those would otherwise have to
|
||||
//! peek at the *first* token to know when to emit, which couples
|
||||
//! the wire layer to the producer's pacing.
|
||||
//! - [`TextDelta`] is *visible* output. Reasoning / `<think>`
|
||||
//! blocks go through a future [`ReasoningDelta`] variant once
|
||||
//! the harness learns to split them (today they pass through as
|
||||
//! plain text inside `TextDelta`; helexa-acp picks them apart on
|
||||
//! the consumer side).
|
||||
//! - [`Finish`] is the only place a stream is allowed to end
|
||||
//! cleanly. Projections rely on this to emit final usage
|
||||
//! bookkeeping; absence means the producer crashed and the
|
||||
//! consumer should treat the stream as truncated.
|
||||
//!
|
||||
//! [`Start`]: InferenceEvent::Start
|
||||
//! [`TextDelta`]: InferenceEvent::TextDelta
|
||||
//! [`Finish`]: InferenceEvent::Finish
|
||||
|
||||
/// One unit of output from the inference loop.
|
||||
///
|
||||
/// Producers send these on an `mpsc::Sender<InferenceEvent>`;
|
||||
/// projection layers in sibling modules consume them and emit
|
||||
/// wire-format-specific frames downstream.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum InferenceEvent {
|
||||
/// The producer has accepted the prompt and is about to emit
|
||||
/// the first token. Sent at most once per stream.
|
||||
Start,
|
||||
/// A piece of visible assistant text. Multiple deltas
|
||||
/// concatenate into the complete reply.
|
||||
TextDelta(String),
|
||||
/// Reasoning / scratchpad text the model emitted inside a
|
||||
/// `<think>` block (or equivalent). The harness routes
|
||||
/// content between marker tokens here so wire projectors can
|
||||
/// decide what to do with it (chat completions drops by
|
||||
/// default; Responses API has a dedicated event family).
|
||||
ReasoningDelta(String),
|
||||
/// A tool call has been parsed out of a `<tool_call>{json}</tool_call>`
|
||||
/// block. Carries the parsed name + arguments JSON string
|
||||
/// (Anthropic / OpenAI projectors emit their own wire shape
|
||||
/// from this).
|
||||
///
|
||||
/// `index` is the call slot — incremented per tool call in a
|
||||
/// turn so wire formats that order calls by index
|
||||
/// (OpenAI chat completions) can correlate.
|
||||
ToolCall {
|
||||
index: usize,
|
||||
id: String,
|
||||
name: String,
|
||||
/// Complete JSON arguments string. The model could in
|
||||
/// principle stream these token-by-token, but our
|
||||
/// extraction buffers the whole block until `</tool_call>`
|
||||
/// arrives and emits exactly one event per call.
|
||||
arguments: String,
|
||||
},
|
||||
/// The stream is complete. Carries the reason so wire formats
|
||||
/// that use it (OpenAI's `finish_reason`, Anthropic's
|
||||
/// `stop_reason`) can render it without re-parsing.
|
||||
Finish { reason: FinishReason },
|
||||
}
|
||||
|
||||
/// Why a stream stopped. Stays small on purpose — anything that
|
||||
/// doesn't map cleanly to one of these collapses to [`Stop`].
|
||||
///
|
||||
/// Mappings to wire formats:
|
||||
///
|
||||
/// | variant | OpenAI `finish_reason` | OpenAI Responses `status` | Anthropic `stop_reason` |
|
||||
/// |---------|------------------------|---------------------------|-------------------------|
|
||||
/// | `Stop` | `"stop"` | `"completed"` | `"end_turn"` |
|
||||
/// | `Length`| `"length"` | `"incomplete"` | `"max_tokens"` |
|
||||
/// | `ToolCalls` | `"tool_calls"` | `"completed"` | `"tool_use"` |
|
||||
///
|
||||
/// [`Stop`]: FinishReason::Stop
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum FinishReason {
|
||||
/// Model emitted EOS naturally.
|
||||
Stop,
|
||||
/// Hit `max_tokens` before EOS.
|
||||
Length,
|
||||
/// Stopped because the model called a tool and is waiting for
|
||||
/// the result. Not yet emitted by the candle harness —
|
||||
/// reserved for the day tool-call extraction lands.
|
||||
#[allow(dead_code)]
|
||||
ToolCalls,
|
||||
}
|
||||
|
||||
impl FinishReason {
|
||||
/// String form used by OpenAI chat completions and OpenAI
|
||||
/// completions. Wire modules can call this directly or do their
|
||||
/// own mapping for non-string formats.
|
||||
pub fn as_openai_str(self) -> &'static str {
|
||||
match self {
|
||||
FinishReason::Stop => "stop",
|
||||
FinishReason::Length => "length",
|
||||
FinishReason::ToolCalls => "tool_calls",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Open/close token IDs for the reasoning marker a loaded model uses
|
||||
/// (or `None` for non-reasoning models). The harness reads this once
|
||||
/// at load time from the tokenizer's added-tokens table, then the
|
||||
/// inference loop checks `next_token` against the pair to flip
|
||||
/// between [`InferenceEvent::TextDelta`] and
|
||||
/// [`InferenceEvent::ReasoningDelta`].
|
||||
///
|
||||
/// `open` and `close` text are kept alongside the IDs so wire
|
||||
/// projectors that want to re-emit the literal markers (the
|
||||
/// opt-in `include_thinking` path on chat completions) don't have
|
||||
/// to reach back into the tokenizer for the strings.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ReasoningTokenPair {
|
||||
pub open_id: u32,
|
||||
pub close_id: u32,
|
||||
pub open_text: String,
|
||||
pub close_text: String,
|
||||
}
|
||||
|
||||
/// Known reasoning-marker conventions. Each is a `(open, close)`
|
||||
/// pair of literal token strings. Each modern reasoning model
|
||||
/// declares its markers in the tokenizer's `added_tokens` table;
|
||||
/// at load time we probe for whichever pair the loaded tokenizer
|
||||
/// has and stash both IDs.
|
||||
///
|
||||
/// Ordering matters only for tie-breaking when a model declares
|
||||
/// multiple pairs (shouldn't happen in practice); the first hit
|
||||
/// wins.
|
||||
const KNOWN_REASONING_MARKERS: &[(&str, &str)] = &[
|
||||
// Qwen3, DeepSeek-R1, gpt-oss, and most other open-weight
|
||||
// reasoning models.
|
||||
("<think>", "</think>"),
|
||||
// Mistral Magistral.
|
||||
("[THINK]", "[/THINK]"),
|
||||
// Some older derivatives; harmless to probe.
|
||||
("<thought>", "</thought>"),
|
||||
("<reasoning>", "</reasoning>"),
|
||||
];
|
||||
|
||||
/// Open/close token IDs for the model's tool-call marker
|
||||
/// convention (or `None` for models that don't emit structured
|
||||
/// tool calls). Same shape as [`ReasoningTokenPair`]: probed once
|
||||
/// at load time, consumed by the inference loop to switch between
|
||||
/// "emit visible deltas" and "buffer JSON for the next tool
|
||||
/// call".
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ToolCallTokenPair {
|
||||
pub open_id: u32,
|
||||
pub close_id: u32,
|
||||
pub open_text: String,
|
||||
pub close_text: String,
|
||||
}
|
||||
|
||||
/// Tool-call marker conventions. Open-weight tool-use models
|
||||
/// converged on `<tool_call>` / `</tool_call>` (Qwen3-Coder /
|
||||
/// -Instruct, the Hermes function-call format, DeepSeek-Coder,
|
||||
/// gpt-oss). The pair lives alongside the reasoning markers in
|
||||
/// the same `added_tokens` table.
|
||||
const KNOWN_TOOL_CALL_MARKERS: &[(&str, &str)] = &[("<tool_call>", "</tool_call>")];
|
||||
|
||||
/// Probe a tokenizer for known tool-call marker pairs. Mirrors
|
||||
/// [`detect_reasoning_token_pair`] — both open AND close must
|
||||
/// resolve for the pair to be returned. `None` means the model
|
||||
/// doesn't emit structured tool calls (or its tokenizer split
|
||||
/// the markers across tokens).
|
||||
pub fn detect_tool_call_token_pair<F>(token_to_id: F) -> Option<ToolCallTokenPair>
|
||||
where
|
||||
F: Fn(&str) -> Option<u32>,
|
||||
{
|
||||
for (open_text, close_text) in KNOWN_TOOL_CALL_MARKERS {
|
||||
let open_id = token_to_id(open_text);
|
||||
let close_id = token_to_id(close_text);
|
||||
if let (Some(open_id), Some(close_id)) = (open_id, close_id) {
|
||||
return Some(ToolCallTokenPair {
|
||||
open_id,
|
||||
close_id,
|
||||
open_text: (*open_text).into(),
|
||||
close_text: (*close_text).into(),
|
||||
});
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Inspect a tokenizer for known reasoning-marker pairs and return
|
||||
/// the first match. The tokenizer types this trait is defined over
|
||||
/// just need to expose `token_to_id(&str) -> Option<u32>` so this
|
||||
/// stays decoupled from the candle crate — the production caller
|
||||
/// passes a `tokenizers::Tokenizer`, but tests can fake one.
|
||||
///
|
||||
/// Returns `None` when no known marker pair is fully declared
|
||||
/// (both open AND close token ids must resolve). That's the
|
||||
/// pass-through case — non-reasoning models, or reasoning models
|
||||
/// whose tokenizer split the markers across multiple tokens (rare
|
||||
/// in practice; modern reasoning tokenizers list them as
|
||||
/// `added_tokens`).
|
||||
pub fn detect_reasoning_token_pair<F>(token_to_id: F) -> Option<ReasoningTokenPair>
|
||||
where
|
||||
F: Fn(&str) -> Option<u32>,
|
||||
{
|
||||
for (open_text, close_text) in KNOWN_REASONING_MARKERS {
|
||||
let open_id = token_to_id(open_text);
|
||||
let close_id = token_to_id(close_text);
|
||||
if let (Some(open_id), Some(close_id)) = (open_id, close_id) {
|
||||
return Some(ReasoningTokenPair {
|
||||
open_id,
|
||||
close_id,
|
||||
open_text: (*open_text).into(),
|
||||
close_text: (*close_text).into(),
|
||||
});
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn lookup<'a>(map: &'a HashMap<&'static str, u32>) -> impl Fn(&str) -> Option<u32> + 'a {
|
||||
|s| map.get(s).copied()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_qwen3_style_think_markers() {
|
||||
let mut m = HashMap::new();
|
||||
m.insert("<think>", 151648);
|
||||
m.insert("</think>", 151649);
|
||||
let pair = detect_reasoning_token_pair(lookup(&m)).expect("pair detected");
|
||||
assert_eq!(pair.open_id, 151648);
|
||||
assert_eq!(pair.close_id, 151649);
|
||||
assert_eq!(pair.open_text, "<think>");
|
||||
assert_eq!(pair.close_text, "</think>");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_mistral_magistral_markers() {
|
||||
let mut m = HashMap::new();
|
||||
m.insert("[THINK]", 100);
|
||||
m.insert("[/THINK]", 101);
|
||||
let pair = detect_reasoning_token_pair(lookup(&m)).expect("pair detected");
|
||||
assert_eq!(pair.open_text, "[THINK]");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn returns_none_when_only_open_marker_present() {
|
||||
// A pathological tokenizer that has `<think>` but not
|
||||
// `</think>` shouldn't half-detect. Pass-through.
|
||||
let mut m = HashMap::new();
|
||||
m.insert("<think>", 1);
|
||||
assert!(detect_reasoning_token_pair(lookup(&m)).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn returns_none_for_non_reasoning_tokenizer() {
|
||||
let m: HashMap<&'static str, u32> = HashMap::new();
|
||||
assert!(detect_reasoning_token_pair(lookup(&m)).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_tool_call_markers() {
|
||||
let mut m = HashMap::new();
|
||||
m.insert("<tool_call>", 151657);
|
||||
m.insert("</tool_call>", 151658);
|
||||
let pair = detect_tool_call_token_pair(lookup(&m)).expect("pair detected");
|
||||
assert_eq!(pair.open_id, 151657);
|
||||
assert_eq!(pair.close_id, 151658);
|
||||
assert_eq!(pair.open_text, "<tool_call>");
|
||||
assert_eq!(pair.close_text, "</tool_call>");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn returns_none_for_non_tool_use_tokenizer() {
|
||||
let m: HashMap<&'static str, u32> = HashMap::new();
|
||||
assert!(detect_tool_call_token_pair(lookup(&m)).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn first_match_wins_when_multiple_pairs_declared() {
|
||||
// Hypothetical tokenizer with both Qwen-style AND Mistral-style
|
||||
// markers — the `<think>` pair is earlier in the convention
|
||||
// table so it wins.
|
||||
let mut m = HashMap::new();
|
||||
m.insert("<think>", 1);
|
||||
m.insert("</think>", 2);
|
||||
m.insert("[THINK]", 3);
|
||||
m.insert("[/THINK]", 4);
|
||||
let pair = detect_reasoning_token_pair(lookup(&m)).unwrap();
|
||||
assert_eq!(pair.open_id, 1);
|
||||
assert_eq!(pair.close_id, 2);
|
||||
}
|
||||
}
|
||||
27
crates/neuron/src/wire/mod.rs
Normal file
27
crates/neuron/src/wire/mod.rs
Normal file
@@ -0,0 +1,27 @@
|
||||
//! Wire-format projection layer.
|
||||
//!
|
||||
//! The candle harness produces a single, format-agnostic stream of
|
||||
//! [`InferenceEvent`]s. Each wire format (OpenAI chat completions,
|
||||
//! OpenAI Responses, Anthropic messages, …) lives in its own module
|
||||
//! under `wire::` and projects that event stream into the chunks /
|
||||
//! events its HTTP clients expect.
|
||||
//!
|
||||
//! The benefit over translating *between* wire shapes (OpenAI chat
|
||||
//! → Anthropic, etc.) is that we never have to reason about a
|
||||
//! wire-N → wire-M conversion: every translation is wire-N ↔ the
|
||||
//! internal event currency, and the projections are independent. A
|
||||
//! new wire format adds a new file under `wire::`; nothing else
|
||||
//! needs to know about it.
|
||||
//!
|
||||
//! Today: [`openai_chat`]. Stage 2 adds `openai_responses`. Stage 3
|
||||
//! could add a native Anthropic projection that replaces the
|
||||
//! gateway-side translation.
|
||||
|
||||
pub mod event;
|
||||
pub mod openai_chat;
|
||||
pub mod openai_responses;
|
||||
|
||||
pub use event::{
|
||||
FinishReason, InferenceEvent, ReasoningTokenPair, ToolCallTokenPair,
|
||||
detect_reasoning_token_pair, detect_tool_call_token_pair,
|
||||
};
|
||||
558
crates/neuron/src/wire/openai_chat.rs
Normal file
558
crates/neuron/src/wire/openai_chat.rs
Normal file
@@ -0,0 +1,558 @@
|
||||
//! OpenAI chat completions projection.
|
||||
//!
|
||||
//! Reads [`InferenceEvent`]s from a receiver and produces
|
||||
//! [`ChatCompletionChunk`]s in the shape `POST /v1/chat/completions`
|
||||
//! clients expect on its streaming SSE response. The HTTP handler in
|
||||
//! [`crate::api`] wraps the resulting receiver in axum's
|
||||
//! `Sse::new(...)` adapter; nothing in this module touches HTTP
|
||||
//! framing or `data:` lines.
|
||||
//!
|
||||
//! Per the OpenAI streaming spec, three chunk shapes appear:
|
||||
//!
|
||||
//! 1. **Role chunk** — `delta: { "role": "assistant" }`, no content,
|
||||
//! sent once at stream start. We emit this on [`InferenceEvent::Start`].
|
||||
//! 2. **Content chunks** — `delta: { "content": "<text>" }`, one per
|
||||
//! [`InferenceEvent::TextDelta`].
|
||||
//! 3. **Final chunk** — empty `delta`, `finish_reason` populated.
|
||||
//! Emitted on [`InferenceEvent::Finish`].
|
||||
//!
|
||||
//! `usage` stays `None` on every chunk; the legacy candle paths
|
||||
//! never surfaced usage on the streaming endpoint and we keep that
|
||||
//! behaviour bit-for-bit so existing clients see no diff.
|
||||
//!
|
||||
//! Back-pressure: the projection task awaits both `rx.recv()` and
|
||||
//! `tx.send()`. A slow consumer fills the output channel → the
|
||||
//! task blocks on send → it stops reading from the input → the
|
||||
//! producer blocks on its own send. The bounded channels
|
||||
//! propagate without us writing any logic.
|
||||
|
||||
use cortex_core::openai::{ChatCompletionChunk, ChunkChoice};
|
||||
use serde_json::json;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use super::event::{FinishReason, InferenceEvent, ReasoningTokenPair};
|
||||
|
||||
/// Output channel buffer size. Mirrors the input side's bound; one
|
||||
/// event maps to at most one chunk, so equal capacity keeps the
|
||||
/// two ends in sync without surprising memory growth.
|
||||
const CHUNK_CHANNEL_CAPACITY: usize = 32;
|
||||
|
||||
/// Per-stream config for the chat projector. Used by the
|
||||
/// production handler to thread per-request choices (currently:
|
||||
/// whether to surface reasoning content) into the projection
|
||||
/// without bloating the function signature.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct ChatProjectionConfig {
|
||||
/// When `true`, reasoning content is re-wrapped with the
|
||||
/// model's literal open/close markers and emitted as content
|
||||
/// deltas — preserving the on-the-wire shape that
|
||||
/// reasoning-aware clients like helexa-acp's `ThinkParser`
|
||||
/// expect.
|
||||
///
|
||||
/// When `false` (the default), [`InferenceEvent::ReasoningDelta`]s
|
||||
/// are dropped entirely so consumers that don't know about
|
||||
/// reasoning (Zed's commit-message generator, any vanilla
|
||||
/// OpenAI client) don't have model-internal scratchpad
|
||||
/// material leaking into their UI. The chat-completions wire
|
||||
/// format has no slot for reasoning, so the default chooses
|
||||
/// the safer-for-naïve-clients behaviour.
|
||||
pub include_thinking: bool,
|
||||
/// Open/close marker strings to re-emit when `include_thinking`
|
||||
/// is set. Sourced from the loaded model's
|
||||
/// [`ReasoningTokenPair`]; `None` for non-reasoning models or
|
||||
/// when the caller doesn't have the pair handy (in which case
|
||||
/// `include_thinking` becomes equivalent to dropping reasoning
|
||||
/// because there's nothing to wrap).
|
||||
pub reasoning_markers: Option<ReasoningTokenPair>,
|
||||
}
|
||||
|
||||
/// Project an [`InferenceEvent`] receiver into a
|
||||
/// [`ChatCompletionChunk`] receiver. Spawns one tokio task that
|
||||
/// owns the input receiver for the stream's lifetime and exits
|
||||
/// when either side closes.
|
||||
///
|
||||
/// `id`, `created`, and `model_id` are stamped into every emitted
|
||||
/// chunk so the receiver can stay generic (decoupled from
|
||||
/// per-request metadata).
|
||||
pub fn project_chat_stream(
|
||||
rx: mpsc::Receiver<InferenceEvent>,
|
||||
id: String,
|
||||
created: u64,
|
||||
model_id: String,
|
||||
) -> mpsc::Receiver<ChatCompletionChunk> {
|
||||
// Default config: include_thinking off, no marker rewrap.
|
||||
project_chat_stream_with(rx, id, created, model_id, ChatProjectionConfig::default())
|
||||
}
|
||||
|
||||
/// Same as [`project_chat_stream`] but with a per-stream config
|
||||
/// (currently controlling reasoning surfacing). Production
|
||||
/// callers that need the opt-in path call this directly; the
|
||||
/// shorter wrapper above stays as the no-config convenience.
|
||||
pub fn project_chat_stream_with(
|
||||
mut rx: mpsc::Receiver<InferenceEvent>,
|
||||
id: String,
|
||||
created: u64,
|
||||
model_id: String,
|
||||
config: ChatProjectionConfig,
|
||||
) -> mpsc::Receiver<ChatCompletionChunk> {
|
||||
let (tx, out_rx) = mpsc::channel::<ChatCompletionChunk>(CHUNK_CHANNEL_CAPACITY);
|
||||
|
||||
tokio::spawn(async move {
|
||||
// Track whether the previous event was inside a reasoning
|
||||
// block — used to decide when to emit the literal close
|
||||
// marker on the include_thinking re-wrap path. When this
|
||||
// flips from true → false (a TextDelta or Finish lands
|
||||
// after one or more ReasoningDeltas), we emit the close
|
||||
// marker exactly once.
|
||||
let mut was_in_reasoning = false;
|
||||
|
||||
while let Some(event) = rx.recv().await {
|
||||
// Close-marker insertion: if we're leaving a reasoning
|
||||
// chain, emit the literal close marker before the
|
||||
// current event.
|
||||
if was_in_reasoning && !matches!(event, InferenceEvent::ReasoningDelta(_)) {
|
||||
if let Some(marker) = config
|
||||
.include_thinking
|
||||
.then_some(())
|
||||
.and(config.reasoning_markers.as_ref())
|
||||
{
|
||||
let chunk = content_chunk(&id, created, &model_id, &marker.close_text);
|
||||
if tx.send(chunk).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
was_in_reasoning = false;
|
||||
}
|
||||
|
||||
let chunks = match event {
|
||||
InferenceEvent::Start => vec![role_chunk(&id, created, &model_id)],
|
||||
InferenceEvent::TextDelta(text) => {
|
||||
if text.is_empty() {
|
||||
// DecodeStream is buffering a multi-byte
|
||||
// codepoint; don't bother sending an empty
|
||||
// chunk downstream.
|
||||
continue;
|
||||
}
|
||||
vec![content_chunk(&id, created, &model_id, &text)]
|
||||
}
|
||||
InferenceEvent::ReasoningDelta(text) => {
|
||||
if !config.include_thinking {
|
||||
// Default path — reasoning has no slot in
|
||||
// chat completions, so it's dropped. Naïve
|
||||
// clients (Zed commit-message generator,
|
||||
// any vanilla OpenAI client) get clean
|
||||
// output.
|
||||
continue;
|
||||
}
|
||||
let Some(markers) = config.reasoning_markers.as_ref() else {
|
||||
// Caller asked to include thinking but
|
||||
// didn't supply markers — best we can do
|
||||
// is emit the content as visible text.
|
||||
// Skip the wrap entirely.
|
||||
if text.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let chunk = content_chunk(&id, created, &model_id, &text);
|
||||
if tx.send(chunk).await.is_err() {
|
||||
return;
|
||||
}
|
||||
continue;
|
||||
};
|
||||
// First chunk of a reasoning block → open
|
||||
// marker prelude. Subsequent reasoning deltas
|
||||
// in the same block reuse `was_in_reasoning`
|
||||
// to skip the prelude.
|
||||
let mut chunks = Vec::new();
|
||||
if !was_in_reasoning {
|
||||
chunks.push(content_chunk(&id, created, &model_id, &markers.open_text));
|
||||
}
|
||||
if !text.is_empty() {
|
||||
chunks.push(content_chunk(&id, created, &model_id, &text));
|
||||
}
|
||||
was_in_reasoning = true;
|
||||
chunks
|
||||
}
|
||||
InferenceEvent::ToolCall {
|
||||
index,
|
||||
id: call_id,
|
||||
name,
|
||||
arguments,
|
||||
} => {
|
||||
// OpenAI streaming shape for tool calls:
|
||||
// `delta.tool_calls[]` with id + function.name
|
||||
// on the first chunk per index, then
|
||||
// function.arguments deltas. We have the
|
||||
// complete arguments buffered already, so one
|
||||
// delta carries everything.
|
||||
vec![tool_call_chunk(
|
||||
&id, created, &model_id, index, &call_id, &name, &arguments,
|
||||
)]
|
||||
}
|
||||
InferenceEvent::Finish { reason } => {
|
||||
vec![final_chunk(&id, created, &model_id, reason)]
|
||||
}
|
||||
};
|
||||
for chunk in chunks {
|
||||
if tx.send(chunk).await.is_err() {
|
||||
// Consumer hung up; nothing more to do.
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
out_rx
|
||||
}
|
||||
|
||||
fn role_chunk(id: &str, created: u64, model_id: &str) -> ChatCompletionChunk {
|
||||
ChatCompletionChunk {
|
||||
id: id.into(),
|
||||
object: "chat.completion.chunk".into(),
|
||||
created,
|
||||
model: model_id.into(),
|
||||
choices: vec![ChunkChoice {
|
||||
index: 0,
|
||||
delta: json!({ "role": "assistant" }),
|
||||
finish_reason: None,
|
||||
extra: serde_json::Value::Object(Default::default()),
|
||||
}],
|
||||
usage: None,
|
||||
extra: serde_json::Value::Object(Default::default()),
|
||||
}
|
||||
}
|
||||
|
||||
fn content_chunk(id: &str, created: u64, model_id: &str, text: &str) -> ChatCompletionChunk {
|
||||
ChatCompletionChunk {
|
||||
id: id.into(),
|
||||
object: "chat.completion.chunk".into(),
|
||||
created,
|
||||
model: model_id.into(),
|
||||
choices: vec![ChunkChoice {
|
||||
index: 0,
|
||||
delta: json!({ "content": text }),
|
||||
finish_reason: None,
|
||||
extra: serde_json::Value::Object(Default::default()),
|
||||
}],
|
||||
usage: None,
|
||||
extra: serde_json::Value::Object(Default::default()),
|
||||
}
|
||||
}
|
||||
|
||||
/// OpenAI chat streaming shape for a tool call. One chunk per
|
||||
/// call slot, carrying id + name + the complete arguments JSON.
|
||||
/// Mirrors the format real OpenAI emits on the streaming path,
|
||||
/// minus the per-token arguments-streaming complication (we have
|
||||
/// the whole buffer already after the model finishes the
|
||||
/// `<tool_call>...</tool_call>` block).
|
||||
fn tool_call_chunk(
|
||||
id: &str,
|
||||
created: u64,
|
||||
model_id: &str,
|
||||
index: usize,
|
||||
call_id: &str,
|
||||
name: &str,
|
||||
arguments: &str,
|
||||
) -> ChatCompletionChunk {
|
||||
ChatCompletionChunk {
|
||||
id: id.into(),
|
||||
object: "chat.completion.chunk".into(),
|
||||
created,
|
||||
model: model_id.into(),
|
||||
choices: vec![ChunkChoice {
|
||||
index: 0,
|
||||
delta: json!({
|
||||
"tool_calls": [{
|
||||
"index": index,
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": name,
|
||||
"arguments": arguments,
|
||||
}
|
||||
}],
|
||||
}),
|
||||
finish_reason: None,
|
||||
extra: serde_json::Value::Object(Default::default()),
|
||||
}],
|
||||
usage: None,
|
||||
extra: serde_json::Value::Object(Default::default()),
|
||||
}
|
||||
}
|
||||
|
||||
fn final_chunk(
|
||||
id: &str,
|
||||
created: u64,
|
||||
model_id: &str,
|
||||
reason: FinishReason,
|
||||
) -> ChatCompletionChunk {
|
||||
ChatCompletionChunk {
|
||||
id: id.into(),
|
||||
object: "chat.completion.chunk".into(),
|
||||
created,
|
||||
model: model_id.into(),
|
||||
choices: vec![ChunkChoice {
|
||||
index: 0,
|
||||
delta: serde_json::Value::Object(Default::default()),
|
||||
finish_reason: Some(reason.as_openai_str().to_string()),
|
||||
extra: serde_json::Value::Object(Default::default()),
|
||||
}],
|
||||
usage: None,
|
||||
extra: serde_json::Value::Object(Default::default()),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
/// Drain the projection's output into a Vec for assertion.
|
||||
async fn collect(mut rx: mpsc::Receiver<ChatCompletionChunk>) -> Vec<ChatCompletionChunk> {
|
||||
let mut out = Vec::new();
|
||||
while let Some(chunk) = rx.recv().await {
|
||||
out.push(chunk);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn empty_event_stream_yields_no_chunks() {
|
||||
let (tx, rx) = mpsc::channel::<InferenceEvent>(4);
|
||||
drop(tx);
|
||||
let out = collect(project_chat_stream(rx, "id-1".into(), 1700, "m".into())).await;
|
||||
assert!(out.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn start_text_finish_produces_three_chunks() {
|
||||
let (tx, rx) = mpsc::channel::<InferenceEvent>(4);
|
||||
let out_rx = project_chat_stream(rx, "id-1".into(), 1700, "m".into());
|
||||
|
||||
tx.send(InferenceEvent::Start).await.unwrap();
|
||||
tx.send(InferenceEvent::TextDelta("hello".into()))
|
||||
.await
|
||||
.unwrap();
|
||||
tx.send(InferenceEvent::Finish {
|
||||
reason: FinishReason::Stop,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
drop(tx);
|
||||
|
||||
let out = collect(out_rx).await;
|
||||
assert_eq!(out.len(), 3);
|
||||
assert_eq!(out[0].choices[0].delta["role"], "assistant");
|
||||
assert_eq!(out[1].choices[0].delta["content"], "hello");
|
||||
assert_eq!(out[2].choices[0].finish_reason.as_deref(), Some("stop"));
|
||||
// Every chunk carries the stamped metadata.
|
||||
for chunk in &out {
|
||||
assert_eq!(chunk.id, "id-1");
|
||||
assert_eq!(chunk.created, 1700);
|
||||
assert_eq!(chunk.model, "m");
|
||||
assert_eq!(chunk.object, "chat.completion.chunk");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn empty_text_delta_is_dropped() {
|
||||
let (tx, rx) = mpsc::channel::<InferenceEvent>(4);
|
||||
let out_rx = project_chat_stream(rx, "id".into(), 1, "m".into());
|
||||
tx.send(InferenceEvent::TextDelta(String::new()))
|
||||
.await
|
||||
.unwrap();
|
||||
drop(tx);
|
||||
let out = collect(out_rx).await;
|
||||
assert!(out.is_empty(), "empty deltas must not produce chunks");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn finish_length_maps_to_openai_string() {
|
||||
let (tx, rx) = mpsc::channel::<InferenceEvent>(4);
|
||||
let out_rx = project_chat_stream(rx, "id".into(), 1, "m".into());
|
||||
tx.send(InferenceEvent::Finish {
|
||||
reason: FinishReason::Length,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
drop(tx);
|
||||
let out = collect(out_rx).await;
|
||||
assert_eq!(out.len(), 1);
|
||||
assert_eq!(out[0].choices[0].finish_reason.as_deref(), Some("length"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reasoning_delta_is_dropped_in_chat_projection() {
|
||||
let (tx, rx) = mpsc::channel::<InferenceEvent>(4);
|
||||
let out_rx = project_chat_stream(rx, "id".into(), 1, "m".into());
|
||||
tx.send(InferenceEvent::ReasoningDelta("<think>".into()))
|
||||
.await
|
||||
.unwrap();
|
||||
tx.send(InferenceEvent::TextDelta("real".into()))
|
||||
.await
|
||||
.unwrap();
|
||||
drop(tx);
|
||||
let out = collect(out_rx).await;
|
||||
assert_eq!(out.len(), 1);
|
||||
assert_eq!(out[0].choices[0].delta["content"], "real");
|
||||
}
|
||||
|
||||
fn pair() -> ReasoningTokenPair {
|
||||
ReasoningTokenPair {
|
||||
open_id: 0,
|
||||
close_id: 1,
|
||||
open_text: "<think>".into(),
|
||||
close_text: "</think>".into(),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn include_thinking_rewraps_reasoning_with_literal_markers() {
|
||||
let (tx, rx) = mpsc::channel::<InferenceEvent>(8);
|
||||
let out_rx = project_chat_stream_with(
|
||||
rx,
|
||||
"id".into(),
|
||||
1,
|
||||
"m".into(),
|
||||
ChatProjectionConfig {
|
||||
include_thinking: true,
|
||||
reasoning_markers: Some(pair()),
|
||||
},
|
||||
);
|
||||
tx.send(InferenceEvent::ReasoningDelta("first ".into()))
|
||||
.await
|
||||
.unwrap();
|
||||
tx.send(InferenceEvent::ReasoningDelta("second".into()))
|
||||
.await
|
||||
.unwrap();
|
||||
tx.send(InferenceEvent::TextDelta("answer".into()))
|
||||
.await
|
||||
.unwrap();
|
||||
tx.send(InferenceEvent::Finish {
|
||||
reason: FinishReason::Stop,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
drop(tx);
|
||||
let out = collect(out_rx).await;
|
||||
// Expected sequence: open marker → reasoning content (2 chunks)
|
||||
// → close marker → visible answer → final chunk.
|
||||
let contents: Vec<&str> = out
|
||||
.iter()
|
||||
.filter_map(|c| c.choices[0].delta["content"].as_str())
|
||||
.collect();
|
||||
assert_eq!(
|
||||
contents,
|
||||
vec!["<think>", "first ", "second", "</think>", "answer"]
|
||||
);
|
||||
assert_eq!(
|
||||
out.last().unwrap().choices[0].finish_reason.as_deref(),
|
||||
Some("stop")
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn include_thinking_closes_marker_at_finish_when_no_trailing_text() {
|
||||
// Edge case: stream ends inside a reasoning block (model
|
||||
// hit max_tokens mid-thought, no visible answer ever).
|
||||
// The Finish event still triggers the close marker so the
|
||||
// stream is balanced.
|
||||
let (tx, rx) = mpsc::channel::<InferenceEvent>(4);
|
||||
let out_rx = project_chat_stream_with(
|
||||
rx,
|
||||
"id".into(),
|
||||
1,
|
||||
"m".into(),
|
||||
ChatProjectionConfig {
|
||||
include_thinking: true,
|
||||
reasoning_markers: Some(pair()),
|
||||
},
|
||||
);
|
||||
tx.send(InferenceEvent::ReasoningDelta("thinking...".into()))
|
||||
.await
|
||||
.unwrap();
|
||||
tx.send(InferenceEvent::Finish {
|
||||
reason: FinishReason::Length,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
drop(tx);
|
||||
let out = collect(out_rx).await;
|
||||
let contents: Vec<&str> = out
|
||||
.iter()
|
||||
.filter_map(|c| c.choices[0].delta["content"].as_str())
|
||||
.collect();
|
||||
assert_eq!(contents, vec!["<think>", "thinking...", "</think>"]);
|
||||
assert_eq!(
|
||||
out.last().unwrap().choices[0].finish_reason.as_deref(),
|
||||
Some("length")
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn include_thinking_without_markers_emits_content_directly() {
|
||||
// Defensive: if the caller asks for thinking but the
|
||||
// model declared no markers, we still emit the content
|
||||
// rather than dropping it. Better to leak than to lose.
|
||||
let (tx, rx) = mpsc::channel::<InferenceEvent>(4);
|
||||
let out_rx = project_chat_stream_with(
|
||||
rx,
|
||||
"id".into(),
|
||||
1,
|
||||
"m".into(),
|
||||
ChatProjectionConfig {
|
||||
include_thinking: true,
|
||||
reasoning_markers: None,
|
||||
},
|
||||
);
|
||||
tx.send(InferenceEvent::ReasoningDelta("raw".into()))
|
||||
.await
|
||||
.unwrap();
|
||||
tx.send(InferenceEvent::Finish {
|
||||
reason: FinishReason::Stop,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
drop(tx);
|
||||
let out = collect(out_rx).await;
|
||||
let contents: Vec<&str> = out
|
||||
.iter()
|
||||
.filter_map(|c| c.choices[0].delta["content"].as_str())
|
||||
.collect();
|
||||
assert_eq!(contents, vec!["raw"]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn include_thinking_off_drops_reasoning_even_with_markers() {
|
||||
// Default behaviour even when markers happen to be
|
||||
// configured. The flag is the gate, not the marker
|
||||
// presence.
|
||||
let (tx, rx) = mpsc::channel::<InferenceEvent>(4);
|
||||
let out_rx = project_chat_stream_with(
|
||||
rx,
|
||||
"id".into(),
|
||||
1,
|
||||
"m".into(),
|
||||
ChatProjectionConfig {
|
||||
include_thinking: false,
|
||||
reasoning_markers: Some(pair()),
|
||||
},
|
||||
);
|
||||
tx.send(InferenceEvent::ReasoningDelta("hidden".into()))
|
||||
.await
|
||||
.unwrap();
|
||||
tx.send(InferenceEvent::TextDelta("visible".into()))
|
||||
.await
|
||||
.unwrap();
|
||||
tx.send(InferenceEvent::Finish {
|
||||
reason: FinishReason::Stop,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
drop(tx);
|
||||
let out = collect(out_rx).await;
|
||||
let contents: Vec<&str> = out
|
||||
.iter()
|
||||
.filter_map(|c| c.choices[0].delta["content"].as_str())
|
||||
.collect();
|
||||
assert_eq!(contents, vec!["visible"]);
|
||||
}
|
||||
}
|
||||
870
crates/neuron/src/wire/openai_responses.rs
Normal file
870
crates/neuron/src/wire/openai_responses.rs
Normal file
@@ -0,0 +1,870 @@
|
||||
//! OpenAI Responses API projection.
|
||||
//!
|
||||
//! Two responsibilities:
|
||||
//!
|
||||
//! 1. **Translate request shape**: [`request_to_chat`] flattens
|
||||
//! [`ResponsesRequest`]'s typed `input` items + `instructions`
|
||||
//! into the [`ChatCompletionRequest`] the candle harness already
|
||||
//! knows how to run. The Responses-specific shape stops at this
|
||||
//! function — everything downstream is the same chat path the
|
||||
//! `/v1/chat/completions` route exercises.
|
||||
//!
|
||||
//! 2. **Project event stream**: [`project_responses_stream`] reads
|
||||
//! [`InferenceEvent`]s from the harness and emits the named SSE
|
||||
//! events the Responses API client expects
|
||||
//! (`response.created`, `response.output_text.delta`,
|
||||
//! `response.completed`, …) along with their JSON payloads.
|
||||
//! The HTTP handler in [`crate::api`] reads
|
||||
//! `(event_name, data)` tuples off the receiver and stamps them
|
||||
//! onto axum SSE frames.
|
||||
//!
|
||||
//! Scope cuts (carried over from [`cortex_core::responses`]):
|
||||
//!
|
||||
//! - `previous_response_id` is rejected by [`request_to_chat`]
|
||||
//! with [`TranslateError::ChainedConversationNotSupported`].
|
||||
//! - `Reasoning` input items are dropped (no equivalent in chat).
|
||||
//! - `FunctionCall` / `FunctionCallOutput` items round-trip but the
|
||||
//! harness never emits tool calls today; the synthesis paths are
|
||||
//! in place so the surface is ready when it does.
|
||||
|
||||
use cortex_core::openai::{ChatCompletionRequest, ChatMessage, MessageContent};
|
||||
use cortex_core::responses::{
|
||||
ResponsesContentPart, ResponsesInput, ResponsesInputItem, ResponsesMessageContent,
|
||||
ResponsesOutputContent, ResponsesOutputItem, ResponsesRequest, ResponsesResponse,
|
||||
ResponsesUsage, events,
|
||||
};
|
||||
use serde_json::{Value, json};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use super::event::{FinishReason, InferenceEvent};
|
||||
|
||||
/// Per-request metadata that has to be stamped into every emitted
|
||||
/// event. The projector spawns a task that owns one of these.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ResponseMeta {
|
||||
pub response_id: String,
|
||||
pub created_at: u64,
|
||||
pub model_id: String,
|
||||
/// Item id used inside `output[0]` (the message). All
|
||||
/// `content_part.*` and `output_text.*` events reference this
|
||||
/// so the consumer knows which item the delta belongs to.
|
||||
pub message_item_id: String,
|
||||
}
|
||||
|
||||
/// Reasons [`request_to_chat`] refuses a request.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum TranslateError {
|
||||
#[error(
|
||||
"previous_response_id is not supported on this neuron; chained \
|
||||
conversations require server-side state we don't store yet"
|
||||
)]
|
||||
ChainedConversationNotSupported,
|
||||
}
|
||||
|
||||
/// Flatten a [`ResponsesRequest`] into the chat-completions shape
|
||||
/// the candle harness already knows how to drive. Keeps the
|
||||
/// Responses-specific machinery contained to a single function so
|
||||
/// the harness stays format-agnostic.
|
||||
///
|
||||
/// Semantics:
|
||||
///
|
||||
/// - `instructions` (if set) becomes a leading `system` message.
|
||||
/// - `input: "<string>"` becomes a single `user` message.
|
||||
/// - `input: [items]` flattens each item:
|
||||
/// - `Message { role, content }` → one `ChatMessage`.
|
||||
/// - `FunctionCall` → an `assistant` turn whose `extra.tool_calls`
|
||||
/// carries the call (chat-completions-shaped). The harness
|
||||
/// doesn't act on tool_calls today, but the shape stays
|
||||
/// consistent with what chat would expect.
|
||||
/// - `FunctionCallOutput` → a `tool` role message with the
|
||||
/// output text. Matches OpenAI's chat convention.
|
||||
/// - `Reasoning` items are dropped (no equivalent in chat).
|
||||
/// - Text parts within an array `content` collapse to a single
|
||||
/// string; image parts get rendered as a chat-style content
|
||||
/// array `[{type:"text"}, {type:"image_url"}]` so the chat
|
||||
/// handler's existing vision path applies.
|
||||
pub fn request_to_chat(req: ResponsesRequest) -> Result<ChatCompletionRequest, TranslateError> {
|
||||
if req.previous_response_id.is_some() {
|
||||
return Err(TranslateError::ChainedConversationNotSupported);
|
||||
}
|
||||
|
||||
let mut messages: Vec<ChatMessage> = Vec::new();
|
||||
|
||||
if let Some(instructions) = req.instructions
|
||||
&& !instructions.is_empty()
|
||||
{
|
||||
messages.push(ChatMessage {
|
||||
role: "system".into(),
|
||||
content: MessageContent::Text(instructions),
|
||||
extra: Value::Object(Default::default()),
|
||||
});
|
||||
}
|
||||
|
||||
match req.input {
|
||||
ResponsesInput::Text(text) => {
|
||||
messages.push(ChatMessage {
|
||||
role: "user".into(),
|
||||
content: MessageContent::Text(text),
|
||||
extra: Value::Object(Default::default()),
|
||||
});
|
||||
}
|
||||
ResponsesInput::Items(items) => {
|
||||
for item in items {
|
||||
if let Some(msg) = input_item_to_chat(item) {
|
||||
messages.push(msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ChatCompletionRequest {
|
||||
model: req.model,
|
||||
messages,
|
||||
temperature: req.temperature,
|
||||
top_p: req.top_p,
|
||||
max_tokens: req.max_output_tokens,
|
||||
stream: Some(req.stream),
|
||||
extra: Value::Object(Default::default()),
|
||||
})
|
||||
}
|
||||
|
||||
fn input_item_to_chat(item: ResponsesInputItem) -> Option<ChatMessage> {
|
||||
match item {
|
||||
ResponsesInputItem::Message { role, content } => Some(ChatMessage {
|
||||
role,
|
||||
content: message_content_to_chat(content),
|
||||
extra: Value::Object(Default::default()),
|
||||
}),
|
||||
ResponsesInputItem::FunctionCall {
|
||||
call_id,
|
||||
name,
|
||||
arguments,
|
||||
} => {
|
||||
// Express the call in chat-completions shape via
|
||||
// `extra.tool_calls`. The harness ignores it today but
|
||||
// the shape is consistent for the day it doesn't.
|
||||
let mut extra = serde_json::Map::new();
|
||||
extra.insert(
|
||||
"tool_calls".into(),
|
||||
json!([{
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": { "name": name, "arguments": arguments },
|
||||
}]),
|
||||
);
|
||||
Some(ChatMessage {
|
||||
role: "assistant".into(),
|
||||
content: MessageContent::Text(String::new()),
|
||||
extra: Value::Object(extra),
|
||||
})
|
||||
}
|
||||
ResponsesInputItem::FunctionCallOutput { call_id, output } => {
|
||||
let mut extra = serde_json::Map::new();
|
||||
extra.insert("tool_call_id".into(), Value::String(call_id));
|
||||
Some(ChatMessage {
|
||||
role: "tool".into(),
|
||||
content: MessageContent::Text(output),
|
||||
extra: Value::Object(extra),
|
||||
})
|
||||
}
|
||||
// Reasoning items don't have a chat-completions equivalent
|
||||
// we can faithfully forward. Silently drop — the alternative
|
||||
// is rejecting a well-formed request, which is worse UX.
|
||||
ResponsesInputItem::Reasoning { .. } => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn message_content_to_chat(content: ResponsesMessageContent) -> MessageContent {
|
||||
match content {
|
||||
ResponsesMessageContent::Text(s) => MessageContent::Text(s),
|
||||
ResponsesMessageContent::Parts(parts) => {
|
||||
// Collapse to a string when every part is text; emit
|
||||
// the chat content-array shape only when an image is
|
||||
// present (some upstreams treat the array form as a
|
||||
// vision-only signal and reject it for text-only
|
||||
// models).
|
||||
let has_image = parts
|
||||
.iter()
|
||||
.any(|p| matches!(p, ResponsesContentPart::InputImage { .. }));
|
||||
if !has_image {
|
||||
let joined = parts
|
||||
.into_iter()
|
||||
.filter_map(|p| match p {
|
||||
ResponsesContentPart::InputText { text }
|
||||
| ResponsesContentPart::OutputText { text, .. } => Some(text),
|
||||
ResponsesContentPart::InputImage { .. } => None,
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n\n");
|
||||
return MessageContent::Text(joined);
|
||||
}
|
||||
let mut out: Vec<Value> = Vec::with_capacity(parts.len());
|
||||
for p in parts {
|
||||
match p {
|
||||
ResponsesContentPart::InputText { text }
|
||||
| ResponsesContentPart::OutputText { text, .. } => {
|
||||
out.push(json!({ "type": "text", "text": text }));
|
||||
}
|
||||
ResponsesContentPart::InputImage { image_url, .. } => {
|
||||
out.push(json!({
|
||||
"type": "image_url",
|
||||
"image_url": { "url": image_url },
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
MessageContent::Parts(out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Streaming projection ─────────────────────────────────────────────
|
||||
|
||||
/// One frame the projector emits. The HTTP handler maps each into
|
||||
/// an axum `Sse::Event` with both an `event:` name and a `data:`
|
||||
/// JSON payload — Responses, unlike chat completions, uses named
|
||||
/// SSE events.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ResponseStreamFrame {
|
||||
pub event_name: &'static str,
|
||||
pub data: Value,
|
||||
}
|
||||
|
||||
/// Project an [`InferenceEvent`] receiver into a stream of
|
||||
/// [`ResponseStreamFrame`]s. The emitted sequence per stream is:
|
||||
///
|
||||
/// 1. `response.created` — shell with `status: "in_progress"`.
|
||||
/// 2. `response.output_item.added` — empty message item.
|
||||
/// 3. `response.content_part.added` — empty `output_text` part.
|
||||
/// 4. `response.output_text.delta` × N — token-by-token text.
|
||||
/// 5. `response.output_text.done` — full accumulated text.
|
||||
/// 6. `response.content_part.done` — full part payload.
|
||||
/// 7. `response.output_item.done` — full message item.
|
||||
/// 8. `response.completed` — final response with `status:"completed"`.
|
||||
///
|
||||
/// Empty TextDeltas (the harness's incomplete-UTF-8 buffering) are
|
||||
/// dropped. `ReasoningDelta`s have no representation in the
|
||||
/// Responses API spec we model yet, so they're dropped too.
|
||||
pub fn project_responses_stream(
|
||||
rx: mpsc::Receiver<InferenceEvent>,
|
||||
meta: ResponseMeta,
|
||||
) -> mpsc::Receiver<ResponseStreamFrame> {
|
||||
let (tx, out_rx) = mpsc::channel::<ResponseStreamFrame>(64);
|
||||
tokio::spawn(async move {
|
||||
run_projection(rx, meta, tx).await;
|
||||
});
|
||||
out_rx
|
||||
}
|
||||
|
||||
async fn run_projection(
|
||||
mut rx: mpsc::Receiver<InferenceEvent>,
|
||||
meta: ResponseMeta,
|
||||
tx: mpsc::Sender<ResponseStreamFrame>,
|
||||
) {
|
||||
let mut accumulated = String::new();
|
||||
let mut finish: Option<FinishReason> = None;
|
||||
let mut emitted_start = false;
|
||||
|
||||
while let Some(event) = rx.recv().await {
|
||||
match event {
|
||||
InferenceEvent::Start => {
|
||||
emitted_start = true;
|
||||
if !emit_start_frames(&tx, &meta).await {
|
||||
return;
|
||||
}
|
||||
}
|
||||
InferenceEvent::TextDelta(text) => {
|
||||
if text.is_empty() {
|
||||
continue;
|
||||
}
|
||||
accumulated.push_str(&text);
|
||||
let frame = ResponseStreamFrame {
|
||||
event_name: events::OUTPUT_TEXT_DELTA,
|
||||
data: json!({
|
||||
"item_id": meta.message_item_id,
|
||||
"output_index": 0,
|
||||
"content_index": 0,
|
||||
"delta": text,
|
||||
}),
|
||||
};
|
||||
if tx.send(frame).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
InferenceEvent::ReasoningDelta(_) => {
|
||||
// No representation in our Responses model yet.
|
||||
// Stage where it'd land: a `response.reasoning_*`
|
||||
// event family alongside `response.output_text.*`.
|
||||
}
|
||||
InferenceEvent::ToolCall { .. } => {
|
||||
// Responses-side tool-call routing not wired yet
|
||||
// (would emit response.function_call_arguments.*
|
||||
// events). Drop for now; the chat-completions
|
||||
// projector handles tool calls. Future work
|
||||
// tracked in #7 alongside the in_progress event.
|
||||
}
|
||||
InferenceEvent::Finish { reason } => {
|
||||
finish = Some(reason);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Producers can drop without ever sending Start (e.g. early
|
||||
// poisoned-model error). Synthesize the open frames so the
|
||||
// consumer at least sees a coherent shell before completed.
|
||||
if !emitted_start && !emit_start_frames(&tx, &meta).await {
|
||||
return;
|
||||
}
|
||||
|
||||
let reason = finish.unwrap_or(FinishReason::Stop);
|
||||
let _ = emit_finish_frames(&tx, &meta, &accumulated, reason).await;
|
||||
}
|
||||
|
||||
async fn emit_start_frames(tx: &mpsc::Sender<ResponseStreamFrame>, meta: &ResponseMeta) -> bool {
|
||||
let shell = response_shell(meta, "in_progress", &[], None);
|
||||
let frames = [
|
||||
ResponseStreamFrame {
|
||||
event_name: events::CREATED,
|
||||
data: json!({ "response": shell.clone() }),
|
||||
},
|
||||
// `response.in_progress` carries the same shell as
|
||||
// `response.created` — both report the "in_progress"
|
||||
// status and both are payload-light bookkeeping events.
|
||||
// The distinction is meaningful to clients that
|
||||
// differentiate "request validated" from "model is
|
||||
// generating" in their UI (loading spinner vs streaming
|
||||
// spinner). OpenAI's own Responses SSE emits them as a
|
||||
// pair; matching the wire shape avoids subtle client
|
||||
// breakage.
|
||||
ResponseStreamFrame {
|
||||
event_name: events::IN_PROGRESS,
|
||||
data: json!({ "response": shell }),
|
||||
},
|
||||
ResponseStreamFrame {
|
||||
event_name: events::OUTPUT_ITEM_ADDED,
|
||||
data: json!({
|
||||
"output_index": 0,
|
||||
"item": empty_message_item(&meta.message_item_id),
|
||||
}),
|
||||
},
|
||||
ResponseStreamFrame {
|
||||
event_name: events::CONTENT_PART_ADDED,
|
||||
data: json!({
|
||||
"item_id": meta.message_item_id,
|
||||
"output_index": 0,
|
||||
"content_index": 0,
|
||||
"part": { "type": "output_text", "text": "", "annotations": [] },
|
||||
}),
|
||||
},
|
||||
];
|
||||
for frame in frames {
|
||||
if tx.send(frame).await.is_err() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
async fn emit_finish_frames(
|
||||
tx: &mpsc::Sender<ResponseStreamFrame>,
|
||||
meta: &ResponseMeta,
|
||||
full_text: &str,
|
||||
reason: FinishReason,
|
||||
) -> bool {
|
||||
let status = finish_to_status(reason);
|
||||
let full_part = json!({
|
||||
"type": "output_text",
|
||||
"text": full_text,
|
||||
"annotations": [],
|
||||
});
|
||||
let full_item = json!({
|
||||
"type": "message",
|
||||
"id": meta.message_item_id,
|
||||
"role": "assistant",
|
||||
"content": [full_part.clone()],
|
||||
"status": status,
|
||||
});
|
||||
let frames = [
|
||||
ResponseStreamFrame {
|
||||
event_name: events::OUTPUT_TEXT_DONE,
|
||||
data: json!({
|
||||
"item_id": meta.message_item_id,
|
||||
"output_index": 0,
|
||||
"content_index": 0,
|
||||
"text": full_text,
|
||||
}),
|
||||
},
|
||||
ResponseStreamFrame {
|
||||
event_name: events::CONTENT_PART_DONE,
|
||||
data: json!({
|
||||
"item_id": meta.message_item_id,
|
||||
"output_index": 0,
|
||||
"content_index": 0,
|
||||
"part": full_part,
|
||||
}),
|
||||
},
|
||||
ResponseStreamFrame {
|
||||
event_name: events::OUTPUT_ITEM_DONE,
|
||||
data: json!({
|
||||
"output_index": 0,
|
||||
"item": full_item.clone(),
|
||||
}),
|
||||
},
|
||||
ResponseStreamFrame {
|
||||
event_name: events::COMPLETED,
|
||||
data: json!({
|
||||
"response": response_shell(meta, status, &[full_item], None)
|
||||
}),
|
||||
},
|
||||
];
|
||||
for frame in frames {
|
||||
if tx.send(frame).await.is_err() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
fn response_shell(
|
||||
meta: &ResponseMeta,
|
||||
status: &str,
|
||||
output: &[Value],
|
||||
usage: Option<&ResponsesUsage>,
|
||||
) -> Value {
|
||||
let mut obj = serde_json::Map::new();
|
||||
obj.insert("id".into(), Value::String(meta.response_id.clone()));
|
||||
obj.insert("object".into(), Value::String("response".into()));
|
||||
obj.insert("created_at".into(), json!(meta.created_at));
|
||||
obj.insert("status".into(), Value::String(status.into()));
|
||||
obj.insert("model".into(), Value::String(meta.model_id.clone()));
|
||||
obj.insert("output".into(), Value::Array(output.to_vec()));
|
||||
if let Some(u) = usage {
|
||||
obj.insert(
|
||||
"usage".into(),
|
||||
json!({
|
||||
"input_tokens": u.input_tokens,
|
||||
"output_tokens": u.output_tokens,
|
||||
"total_tokens": u.total_tokens,
|
||||
}),
|
||||
);
|
||||
}
|
||||
Value::Object(obj)
|
||||
}
|
||||
|
||||
fn empty_message_item(item_id: &str) -> Value {
|
||||
json!({
|
||||
"type": "message",
|
||||
"id": item_id,
|
||||
"role": "assistant",
|
||||
"content": [],
|
||||
"status": "in_progress",
|
||||
})
|
||||
}
|
||||
|
||||
fn finish_to_status(reason: FinishReason) -> &'static str {
|
||||
match reason {
|
||||
FinishReason::Stop | FinishReason::ToolCalls => "completed",
|
||||
FinishReason::Length => "incomplete",
|
||||
}
|
||||
}
|
||||
|
||||
// ── Non-streaming helpers ────────────────────────────────────────────
|
||||
|
||||
/// Collect a chat-completions response into a non-streaming
|
||||
/// [`ResponsesResponse`]. Used by the `/v1/responses` handler when
|
||||
/// the request doesn't set `stream: true`.
|
||||
pub fn build_response(
|
||||
meta: &ResponseMeta,
|
||||
full_text: String,
|
||||
reason: FinishReason,
|
||||
usage: Option<ResponsesUsage>,
|
||||
) -> ResponsesResponse {
|
||||
let status = finish_to_status(reason).to_string();
|
||||
ResponsesResponse {
|
||||
id: meta.response_id.clone(),
|
||||
object: "response".into(),
|
||||
created_at: meta.created_at,
|
||||
status: status.clone(),
|
||||
model: meta.model_id.clone(),
|
||||
output: vec![ResponsesOutputItem::Message {
|
||||
id: meta.message_item_id.clone(),
|
||||
role: "assistant".into(),
|
||||
content: vec![ResponsesOutputContent::OutputText {
|
||||
text: full_text,
|
||||
annotations: vec![],
|
||||
}],
|
||||
status,
|
||||
}],
|
||||
usage,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use cortex_core::openai::MessageContent;
|
||||
|
||||
fn meta() -> ResponseMeta {
|
||||
ResponseMeta {
|
||||
response_id: "resp_1".into(),
|
||||
created_at: 1700,
|
||||
model_id: "m".into(),
|
||||
message_item_id: "msg_1".into(),
|
||||
}
|
||||
}
|
||||
|
||||
// ── request translator ──────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn translates_text_input_to_single_user_message() {
|
||||
let req = ResponsesRequest {
|
||||
model: "m".into(),
|
||||
input: ResponsesInput::Text("hi".into()),
|
||||
instructions: None,
|
||||
stream: false,
|
||||
max_output_tokens: None,
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
previous_response_id: None,
|
||||
extra: Value::Object(Default::default()),
|
||||
};
|
||||
let chat = request_to_chat(req).unwrap();
|
||||
assert_eq!(chat.messages.len(), 1);
|
||||
assert_eq!(chat.messages[0].role, "user");
|
||||
assert!(matches!(
|
||||
&chat.messages[0].content,
|
||||
MessageContent::Text(t) if t == "hi"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn instructions_become_leading_system_message() {
|
||||
let req = ResponsesRequest {
|
||||
model: "m".into(),
|
||||
input: ResponsesInput::Text("hi".into()),
|
||||
instructions: Some("you are helpful".into()),
|
||||
stream: false,
|
||||
max_output_tokens: None,
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
previous_response_id: None,
|
||||
extra: Value::Object(Default::default()),
|
||||
};
|
||||
let chat = request_to_chat(req).unwrap();
|
||||
assert_eq!(chat.messages.len(), 2);
|
||||
assert_eq!(chat.messages[0].role, "system");
|
||||
assert!(matches!(
|
||||
&chat.messages[0].content,
|
||||
MessageContent::Text(t) if t == "you are helpful"
|
||||
));
|
||||
assert_eq!(chat.messages[1].role, "user");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_previous_response_id() {
|
||||
let req = ResponsesRequest {
|
||||
model: "m".into(),
|
||||
input: ResponsesInput::Text("hi".into()),
|
||||
instructions: None,
|
||||
stream: false,
|
||||
max_output_tokens: None,
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
previous_response_id: Some("resp_prev".into()),
|
||||
extra: Value::Object(Default::default()),
|
||||
};
|
||||
assert!(matches!(
|
||||
request_to_chat(req),
|
||||
Err(TranslateError::ChainedConversationNotSupported)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn translates_input_items_to_chat_messages() {
|
||||
let req = ResponsesRequest {
|
||||
model: "m".into(),
|
||||
input: ResponsesInput::Items(vec![
|
||||
ResponsesInputItem::Message {
|
||||
role: "user".into(),
|
||||
content: ResponsesMessageContent::Text("first".into()),
|
||||
},
|
||||
ResponsesInputItem::Message {
|
||||
role: "assistant".into(),
|
||||
content: ResponsesMessageContent::Text("reply".into()),
|
||||
},
|
||||
ResponsesInputItem::Message {
|
||||
role: "user".into(),
|
||||
content: ResponsesMessageContent::Text("second".into()),
|
||||
},
|
||||
]),
|
||||
instructions: None,
|
||||
stream: false,
|
||||
max_output_tokens: None,
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
previous_response_id: None,
|
||||
extra: Value::Object(Default::default()),
|
||||
};
|
||||
let chat = request_to_chat(req).unwrap();
|
||||
assert_eq!(chat.messages.len(), 3);
|
||||
let roles: Vec<&str> = chat.messages.iter().map(|m| m.role.as_str()).collect();
|
||||
assert_eq!(roles, vec!["user", "assistant", "user"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn image_input_translates_to_chat_parts_array() {
|
||||
let req = ResponsesRequest {
|
||||
model: "m".into(),
|
||||
input: ResponsesInput::Items(vec![ResponsesInputItem::Message {
|
||||
role: "user".into(),
|
||||
content: ResponsesMessageContent::Parts(vec![
|
||||
ResponsesContentPart::InputText {
|
||||
text: "what is this?".into(),
|
||||
},
|
||||
ResponsesContentPart::InputImage {
|
||||
image_url: "data:image/png;base64,AAA=".into(),
|
||||
detail: None,
|
||||
},
|
||||
]),
|
||||
}]),
|
||||
instructions: None,
|
||||
stream: false,
|
||||
max_output_tokens: None,
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
previous_response_id: None,
|
||||
extra: Value::Object(Default::default()),
|
||||
};
|
||||
let chat = request_to_chat(req).unwrap();
|
||||
let parts = match &chat.messages[0].content {
|
||||
MessageContent::Parts(p) => p.clone(),
|
||||
other => panic!("expected Parts, got {other:?}"),
|
||||
};
|
||||
assert_eq!(parts.len(), 2);
|
||||
assert_eq!(parts[0]["type"], "text");
|
||||
assert_eq!(parts[1]["type"], "image_url");
|
||||
assert_eq!(parts[1]["image_url"]["url"], "data:image/png;base64,AAA=");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn text_only_parts_collapse_to_string() {
|
||||
let req = ResponsesRequest {
|
||||
model: "m".into(),
|
||||
input: ResponsesInput::Items(vec![ResponsesInputItem::Message {
|
||||
role: "user".into(),
|
||||
content: ResponsesMessageContent::Parts(vec![
|
||||
ResponsesContentPart::InputText {
|
||||
text: "first".into(),
|
||||
},
|
||||
ResponsesContentPart::InputText {
|
||||
text: "second".into(),
|
||||
},
|
||||
]),
|
||||
}]),
|
||||
instructions: None,
|
||||
stream: false,
|
||||
max_output_tokens: None,
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
previous_response_id: None,
|
||||
extra: Value::Object(Default::default()),
|
||||
};
|
||||
let chat = request_to_chat(req).unwrap();
|
||||
assert!(matches!(
|
||||
&chat.messages[0].content,
|
||||
MessageContent::Text(t) if t == "first\n\nsecond"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reasoning_items_are_silently_dropped() {
|
||||
let req = ResponsesRequest {
|
||||
model: "m".into(),
|
||||
input: ResponsesInput::Items(vec![
|
||||
ResponsesInputItem::Reasoning { content: vec![] },
|
||||
ResponsesInputItem::Message {
|
||||
role: "user".into(),
|
||||
content: ResponsesMessageContent::Text("hi".into()),
|
||||
},
|
||||
]),
|
||||
instructions: None,
|
||||
stream: false,
|
||||
max_output_tokens: None,
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
previous_response_id: None,
|
||||
extra: Value::Object(Default::default()),
|
||||
};
|
||||
let chat = request_to_chat(req).unwrap();
|
||||
assert_eq!(chat.messages.len(), 1);
|
||||
assert_eq!(chat.messages[0].role, "user");
|
||||
}
|
||||
|
||||
// ── streaming projector ─────────────────────────────────────────
|
||||
|
||||
async fn collect(mut rx: mpsc::Receiver<ResponseStreamFrame>) -> Vec<ResponseStreamFrame> {
|
||||
let mut out = Vec::new();
|
||||
while let Some(f) = rx.recv().await {
|
||||
out.push(f);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn full_stream_emits_expected_event_sequence() {
|
||||
let (tx, rx) = mpsc::channel::<InferenceEvent>(8);
|
||||
let out = project_responses_stream(rx, meta());
|
||||
|
||||
tx.send(InferenceEvent::Start).await.unwrap();
|
||||
tx.send(InferenceEvent::TextDelta("hel".into()))
|
||||
.await
|
||||
.unwrap();
|
||||
tx.send(InferenceEvent::TextDelta("lo".into()))
|
||||
.await
|
||||
.unwrap();
|
||||
tx.send(InferenceEvent::Finish {
|
||||
reason: FinishReason::Stop,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
drop(tx);
|
||||
|
||||
let frames = collect(out).await;
|
||||
let names: Vec<&str> = frames.iter().map(|f| f.event_name).collect();
|
||||
assert_eq!(
|
||||
names,
|
||||
vec![
|
||||
events::CREATED,
|
||||
events::IN_PROGRESS,
|
||||
events::OUTPUT_ITEM_ADDED,
|
||||
events::CONTENT_PART_ADDED,
|
||||
events::OUTPUT_TEXT_DELTA,
|
||||
events::OUTPUT_TEXT_DELTA,
|
||||
events::OUTPUT_TEXT_DONE,
|
||||
events::CONTENT_PART_DONE,
|
||||
events::OUTPUT_ITEM_DONE,
|
||||
events::COMPLETED,
|
||||
]
|
||||
);
|
||||
|
||||
// The two deltas should carry the right text. Indices
|
||||
// shifted by one after IN_PROGRESS inserted between
|
||||
// CREATED and OUTPUT_ITEM_ADDED.
|
||||
assert_eq!(frames[4].data["delta"], "hel");
|
||||
assert_eq!(frames[5].data["delta"], "lo");
|
||||
|
||||
// The done event has the full accumulated text.
|
||||
assert_eq!(frames[6].data["text"], "hello");
|
||||
|
||||
// Completed event carries the full message item.
|
||||
let completed = &frames[9].data["response"];
|
||||
assert_eq!(completed["status"], "completed");
|
||||
let output = completed["output"].as_array().unwrap();
|
||||
assert_eq!(output.len(), 1);
|
||||
assert_eq!(output[0]["content"][0]["text"], "hello");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn length_finish_maps_to_incomplete_status() {
|
||||
let (tx, rx) = mpsc::channel::<InferenceEvent>(8);
|
||||
let out = project_responses_stream(rx, meta());
|
||||
tx.send(InferenceEvent::Start).await.unwrap();
|
||||
tx.send(InferenceEvent::Finish {
|
||||
reason: FinishReason::Length,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
drop(tx);
|
||||
let frames = collect(out).await;
|
||||
let completed = frames
|
||||
.iter()
|
||||
.find(|f| f.event_name == events::COMPLETED)
|
||||
.unwrap();
|
||||
assert_eq!(completed.data["response"]["status"], "incomplete");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn synthesises_start_frames_when_producer_skips_start() {
|
||||
// A producer that drops without sending Start (poisoned
|
||||
// model, immediate disconnect, …) should still produce a
|
||||
// coherent stream — the projector synthesises the
|
||||
// mandatory header frames before COMPLETED so the
|
||||
// consumer never sees an output_text.done without a
|
||||
// matching content_part.added.
|
||||
let (tx, rx) = mpsc::channel::<InferenceEvent>(8);
|
||||
let out = project_responses_stream(rx, meta());
|
||||
drop(tx);
|
||||
let frames = collect(out).await;
|
||||
let names: Vec<&str> = frames.iter().map(|f| f.event_name).collect();
|
||||
assert!(names.contains(&events::CREATED));
|
||||
assert!(names.contains(&events::COMPLETED));
|
||||
assert!(names.contains(&events::OUTPUT_TEXT_DONE));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn empty_text_deltas_are_dropped() {
|
||||
let (tx, rx) = mpsc::channel::<InferenceEvent>(8);
|
||||
let out = project_responses_stream(rx, meta());
|
||||
tx.send(InferenceEvent::Start).await.unwrap();
|
||||
tx.send(InferenceEvent::TextDelta(String::new()))
|
||||
.await
|
||||
.unwrap();
|
||||
tx.send(InferenceEvent::TextDelta("real".into()))
|
||||
.await
|
||||
.unwrap();
|
||||
tx.send(InferenceEvent::Finish {
|
||||
reason: FinishReason::Stop,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
drop(tx);
|
||||
let frames = collect(out).await;
|
||||
let delta_count = frames
|
||||
.iter()
|
||||
.filter(|f| f.event_name == events::OUTPUT_TEXT_DELTA)
|
||||
.count();
|
||||
assert_eq!(delta_count, 1, "empty delta must not produce a frame");
|
||||
}
|
||||
|
||||
// ── non-streaming builder ───────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn build_response_produces_completed_message_with_usage() {
|
||||
let r = build_response(
|
||||
&meta(),
|
||||
"hello".into(),
|
||||
FinishReason::Stop,
|
||||
Some(ResponsesUsage {
|
||||
input_tokens: 5,
|
||||
output_tokens: 1,
|
||||
total_tokens: 6,
|
||||
}),
|
||||
);
|
||||
assert_eq!(r.status, "completed");
|
||||
match &r.output[0] {
|
||||
ResponsesOutputItem::Message {
|
||||
role,
|
||||
content,
|
||||
status,
|
||||
..
|
||||
} => {
|
||||
assert_eq!(role, "assistant");
|
||||
assert_eq!(status, "completed");
|
||||
match &content[0] {
|
||||
ResponsesOutputContent::OutputText { text, .. } => {
|
||||
assert_eq!(text, "hello");
|
||||
}
|
||||
}
|
||||
}
|
||||
other => panic!("expected Message, got {other:?}"),
|
||||
}
|
||||
let u = r.usage.unwrap();
|
||||
assert_eq!(u.total_tokens, 6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_response_length_yields_incomplete_status() {
|
||||
let r = build_response(&meta(), "trunc".into(), FinishReason::Length, None);
|
||||
assert_eq!(r.status, "incomplete");
|
||||
}
|
||||
}
|
||||
77
crates/neuron/tests/activation.rs
Normal file
77
crates/neuron/tests/activation.rs
Normal file
@@ -0,0 +1,77 @@
|
||||
//! Activation-time behaviour: load_default_models continues past
|
||||
//! individual failures so a single broken catalogue entry doesn't
|
||||
//! prevent the rest of the fleet from starting.
|
||||
|
||||
use cortex_core::discovery::ActivationState;
|
||||
use cortex_core::harness::{HarnessConfig, ModelSpec};
|
||||
use neuron::activation::ActivationTracker;
|
||||
use neuron::config::HarnessSettings;
|
||||
use neuron::harness::HarnessRegistry;
|
||||
use neuron::startup;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_load_default_models_skips_unknown_harness() {
|
||||
let registry = HarnessRegistry::from_configs(
|
||||
&[HarnessConfig {
|
||||
name: "candle".into(),
|
||||
}],
|
||||
"http://localhost:0",
|
||||
&HarnessSettings::default(),
|
||||
);
|
||||
|
||||
// Both entries fail synchronously inside the registry — no network
|
||||
// call escapes (the harness lookup mismatches before hf-hub is
|
||||
// touched). The function should still return cleanly.
|
||||
let specs = vec![
|
||||
ModelSpec {
|
||||
model_id: "model-a".into(),
|
||||
harness: "no-such-harness".into(),
|
||||
quant: None,
|
||||
tensor_parallel: None,
|
||||
devices: None,
|
||||
},
|
||||
ModelSpec {
|
||||
model_id: "model-b".into(),
|
||||
harness: "no-such-harness".into(),
|
||||
quant: None,
|
||||
tensor_parallel: None,
|
||||
devices: None,
|
||||
},
|
||||
];
|
||||
|
||||
let activation = ActivationTracker::new(&specs);
|
||||
startup::load_default_models(®istry, &specs, &activation).await;
|
||||
|
||||
let listed = registry
|
||||
.list_all_models()
|
||||
.await
|
||||
.expect("list_all_models should succeed");
|
||||
assert!(
|
||||
listed.is_empty(),
|
||||
"no models should be loaded after failed entries"
|
||||
);
|
||||
|
||||
// Both specs should land in `failed`; tracker should flip to ready.
|
||||
let snapshot = activation.snapshot().await;
|
||||
assert_eq!(snapshot.state, ActivationState::Ready);
|
||||
assert!(snapshot.pending.is_empty());
|
||||
assert!(snapshot.in_progress.is_none());
|
||||
assert!(snapshot.completed.is_empty());
|
||||
assert_eq!(snapshot.failed.len(), 2);
|
||||
let failed_ids: Vec<&str> = snapshot
|
||||
.failed
|
||||
.iter()
|
||||
.map(|f| f.model_id.as_str())
|
||||
.collect();
|
||||
assert!(failed_ids.contains(&"model-a"));
|
||||
assert!(failed_ids.contains(&"model-b"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_load_default_models_empty_is_noop() {
|
||||
let registry = HarnessRegistry::new();
|
||||
let activation = ActivationTracker::new(&[]);
|
||||
startup::load_default_models(®istry, &[], &activation).await;
|
||||
let snapshot = activation.snapshot().await;
|
||||
assert_eq!(snapshot.state, ActivationState::Ready);
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
use cortex_core::discovery::{DeviceInfo, DiscoveryResponse};
|
||||
use neuron::activation::ActivationTracker;
|
||||
use neuron::api::{self, NeuronState};
|
||||
use neuron::harness::HarnessRegistry;
|
||||
use neuron::health::HealthCache;
|
||||
@@ -14,6 +15,8 @@ async fn spawn_neuron(discovery: DiscoveryResponse) -> String {
|
||||
discovery,
|
||||
health_cache,
|
||||
registry: RwLock::new(registry),
|
||||
candle: None,
|
||||
activation: Arc::new(ActivationTracker::new(&[])),
|
||||
});
|
||||
|
||||
let app = api::neuron_routes().with_state(state);
|
||||
@@ -135,56 +138,31 @@ async fn test_models_empty_registry() {
|
||||
assert!(body.as_array().unwrap().is_empty());
|
||||
}
|
||||
|
||||
/// Spawn a mock mistral.rs backend and a neuron with the mistralrs harness
|
||||
/// pointing at it, then test the full model lifecycle through neuron's API.
|
||||
/// Verify the candle harness registers, list is empty by default, and a
|
||||
/// load attempt for an obviously-bogus model id returns a 4xx error
|
||||
/// without crashing the daemon. Real load/unload exercising actual GGUF
|
||||
/// download is covered by `tests/candle_lifecycle.rs` (cuda-integration).
|
||||
#[tokio::test]
|
||||
async fn test_models_via_mistralrs_harness() {
|
||||
use axum::routing::{get, post};
|
||||
use axum::{Json, Router};
|
||||
async fn test_candle_harness_registers_and_rejects_bogus_model() {
|
||||
use cortex_core::harness::HarnessConfig;
|
||||
use serde_json::Value;
|
||||
use neuron::config::HarnessSettings;
|
||||
|
||||
// Mock mistral.rs backend.
|
||||
let mock_app = Router::new()
|
||||
.route(
|
||||
"/v1/models",
|
||||
get(|| async {
|
||||
Json(json!({
|
||||
"data": [
|
||||
{"id": "test-model", "status": "loaded"},
|
||||
{"id": "other-model", "status": "unloaded"}
|
||||
]
|
||||
}))
|
||||
}),
|
||||
)
|
||||
.route(
|
||||
"/v1/models/unload",
|
||||
post(|Json(_body): Json<Value>| async { Json(json!({"status": "ok"})) }),
|
||||
)
|
||||
.route(
|
||||
"/v1/models/reload",
|
||||
post(|Json(_body): Json<Value>| async { Json(json!({"status": "ok"})) }),
|
||||
);
|
||||
|
||||
let mock_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let mock_addr = mock_listener.local_addr().unwrap();
|
||||
tokio::spawn(async move {
|
||||
axum::serve(mock_listener, mock_app).await.unwrap();
|
||||
});
|
||||
let mock_url = format!("http://{mock_addr}");
|
||||
|
||||
// Build neuron with mistralrs harness pointing at mock.
|
||||
let registry = HarnessRegistry::from_configs(&[HarnessConfig {
|
||||
name: "mistralrs".into(),
|
||||
endpoint: Some(mock_url.clone()),
|
||||
systemd_unit: None,
|
||||
}]);
|
||||
let registry = HarnessRegistry::from_configs(
|
||||
&[HarnessConfig {
|
||||
name: "candle".into(),
|
||||
}],
|
||||
"http://localhost:13131",
|
||||
&HarnessSettings::default(),
|
||||
);
|
||||
|
||||
let candle = registry.candle();
|
||||
let health_cache = Arc::new(HealthCache::new());
|
||||
let state = Arc::new(NeuronState {
|
||||
discovery: fake_discovery(),
|
||||
health_cache,
|
||||
registry: RwLock::new(registry),
|
||||
candle,
|
||||
activation: Arc::new(ActivationTracker::new(&[])),
|
||||
});
|
||||
|
||||
let app = api::neuron_routes().with_state(state);
|
||||
@@ -197,7 +175,6 @@ async fn test_models_via_mistralrs_harness() {
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// GET /models — should return models from mock mistralrs.
|
||||
let resp = client
|
||||
.get(format!("{neuron_url}/models"))
|
||||
.send()
|
||||
@@ -205,45 +182,308 @@ async fn test_models_via_mistralrs_harness() {
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 200);
|
||||
let models: Vec<serde_json::Value> = resp.json().await.unwrap();
|
||||
assert_eq!(models.len(), 2);
|
||||
assert_eq!(models[0]["id"], "test-model");
|
||||
assert_eq!(models[0]["harness"], "mistralrs");
|
||||
assert_eq!(models[0]["status"], "loaded");
|
||||
assert_eq!(models[1]["id"], "other-model");
|
||||
assert_eq!(models[1]["status"], "unloaded");
|
||||
assert!(models.is_empty());
|
||||
|
||||
// GET /models/test-model/endpoint — should return mock URL.
|
||||
let resp = client
|
||||
.get(format!("{neuron_url}/models/test-model/endpoint"))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 200);
|
||||
let body: serde_json::Value = resp.json().await.unwrap();
|
||||
assert_eq!(body["url"], mock_url);
|
||||
|
||||
// POST /models/unload — should succeed.
|
||||
let resp = client
|
||||
.post(format!("{neuron_url}/models/unload"))
|
||||
.json(&json!({"model_id": "test-model"}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 200);
|
||||
let body: serde_json::Value = resp.json().await.unwrap();
|
||||
assert_eq!(body["status"], "unloaded");
|
||||
|
||||
// POST /models/load — should succeed.
|
||||
// Sending a wrong-harness spec should be rejected synchronously
|
||||
// without touching the network or the model registry.
|
||||
let resp = client
|
||||
.post(format!("{neuron_url}/models/load"))
|
||||
.json(&json!({"model_id": "definitely/not-real", "harness": "not-candle"}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 400);
|
||||
|
||||
// Registry still empty.
|
||||
let resp = client
|
||||
.get(format!("{neuron_url}/models"))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
let models: Vec<serde_json::Value> = resp.json().await.unwrap();
|
||||
assert!(models.is_empty());
|
||||
}
|
||||
|
||||
/// `/v1/chat/completions` returns 503 when no candle harness is registered.
|
||||
#[tokio::test]
|
||||
async fn test_chat_completions_no_candle_harness() {
|
||||
let registry = HarnessRegistry::new();
|
||||
let health_cache = Arc::new(HealthCache::new());
|
||||
let state = Arc::new(NeuronState {
|
||||
discovery: fake_discovery(),
|
||||
health_cache,
|
||||
registry: RwLock::new(registry),
|
||||
candle: None,
|
||||
activation: Arc::new(ActivationTracker::new(&[])),
|
||||
});
|
||||
let app = api::neuron_routes().with_state(state);
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
tokio::spawn(async move {
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
});
|
||||
let url = format!("http://{addr}");
|
||||
|
||||
let resp = reqwest::Client::new()
|
||||
.post(format!("{url}/v1/chat/completions"))
|
||||
.json(&json!({
|
||||
"model_id": "test-model",
|
||||
"harness": "mistralrs"
|
||||
"model": "anything",
|
||||
"messages": [{"role": "user", "content": "hi"}]
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 200);
|
||||
let body: serde_json::Value = resp.json().await.unwrap();
|
||||
assert_eq!(body["status"], "loaded");
|
||||
assert_eq!(resp.status(), 503);
|
||||
}
|
||||
|
||||
/// `/v1/chat/completions` returns 404 when the requested model isn't loaded.
|
||||
#[tokio::test]
|
||||
async fn test_chat_completions_model_not_loaded() {
|
||||
use cortex_core::harness::HarnessConfig;
|
||||
use neuron::config::HarnessSettings;
|
||||
|
||||
let registry = HarnessRegistry::from_configs(
|
||||
&[HarnessConfig {
|
||||
name: "candle".into(),
|
||||
}],
|
||||
"http://localhost:0",
|
||||
&HarnessSettings::default(),
|
||||
);
|
||||
let candle = registry.candle();
|
||||
let health_cache = Arc::new(HealthCache::new());
|
||||
let state = Arc::new(NeuronState {
|
||||
discovery: fake_discovery(),
|
||||
health_cache,
|
||||
registry: RwLock::new(registry),
|
||||
candle,
|
||||
activation: Arc::new(ActivationTracker::new(&[])),
|
||||
});
|
||||
let app = api::neuron_routes().with_state(state);
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
tokio::spawn(async move {
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
});
|
||||
let url = format!("http://{addr}");
|
||||
|
||||
let resp = reqwest::Client::new()
|
||||
.post(format!("{url}/v1/chat/completions"))
|
||||
.json(&json!({
|
||||
"model": "definitely/not-loaded",
|
||||
"messages": [{"role": "user", "content": "hi"}]
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 404);
|
||||
}
|
||||
|
||||
/// `/v1/chat/completions` with `stream: true` returns 404 when the
|
||||
/// model isn't loaded — same surface as the non-streaming path. The
|
||||
/// streaming code only kicks in once the model lookup succeeds.
|
||||
#[tokio::test]
|
||||
async fn test_chat_completions_streaming_model_not_loaded() {
|
||||
use cortex_core::harness::HarnessConfig;
|
||||
use neuron::config::HarnessSettings;
|
||||
|
||||
let registry = HarnessRegistry::from_configs(
|
||||
&[HarnessConfig {
|
||||
name: "candle".into(),
|
||||
}],
|
||||
"http://localhost:0",
|
||||
&HarnessSettings::default(),
|
||||
);
|
||||
let candle = registry.candle();
|
||||
let health_cache = Arc::new(HealthCache::new());
|
||||
let state = Arc::new(NeuronState {
|
||||
discovery: fake_discovery(),
|
||||
health_cache,
|
||||
registry: RwLock::new(registry),
|
||||
candle,
|
||||
activation: Arc::new(ActivationTracker::new(&[])),
|
||||
});
|
||||
let app = api::neuron_routes().with_state(state);
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
tokio::spawn(async move {
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
});
|
||||
let url = format!("http://{addr}");
|
||||
|
||||
let resp = reqwest::Client::new()
|
||||
.post(format!("{url}/v1/chat/completions"))
|
||||
.json(&json!({
|
||||
"model": "definitely/not-loaded",
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"stream": true
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 404);
|
||||
}
|
||||
|
||||
// ── /v1/responses ────────────────────────────────────────────────────
|
||||
|
||||
/// `/v1/responses` returns 503 when no candle harness is registered —
|
||||
/// matches the chat-completions error shape so a client can swap
|
||||
/// endpoints without re-handling 503s.
|
||||
#[tokio::test]
|
||||
async fn test_responses_no_candle_harness() {
|
||||
let registry = HarnessRegistry::new();
|
||||
let health_cache = Arc::new(HealthCache::new());
|
||||
let state = Arc::new(NeuronState {
|
||||
discovery: fake_discovery(),
|
||||
health_cache,
|
||||
registry: RwLock::new(registry),
|
||||
candle: None,
|
||||
activation: Arc::new(ActivationTracker::new(&[])),
|
||||
});
|
||||
let app = api::neuron_routes().with_state(state);
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
tokio::spawn(async move {
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
});
|
||||
let url = format!("http://{addr}");
|
||||
|
||||
let resp = reqwest::Client::new()
|
||||
.post(format!("{url}/v1/responses"))
|
||||
.json(&json!({"model": "anything", "input": "hi"}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 503);
|
||||
}
|
||||
|
||||
/// `previous_response_id` is rejected at translate time with 400 —
|
||||
/// we don't store responses server-side yet, so chained
|
||||
/// conversations can't be honoured.
|
||||
#[tokio::test]
|
||||
async fn test_responses_rejects_previous_response_id() {
|
||||
use cortex_core::harness::HarnessConfig;
|
||||
use neuron::config::HarnessSettings;
|
||||
|
||||
let registry = HarnessRegistry::from_configs(
|
||||
&[HarnessConfig {
|
||||
name: "candle".into(),
|
||||
}],
|
||||
"http://localhost:0",
|
||||
&HarnessSettings::default(),
|
||||
);
|
||||
let candle = registry.candle();
|
||||
let health_cache = Arc::new(HealthCache::new());
|
||||
let state = Arc::new(NeuronState {
|
||||
discovery: fake_discovery(),
|
||||
health_cache,
|
||||
registry: RwLock::new(registry),
|
||||
candle,
|
||||
activation: Arc::new(ActivationTracker::new(&[])),
|
||||
});
|
||||
let app = api::neuron_routes().with_state(state);
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
tokio::spawn(async move {
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
});
|
||||
let url = format!("http://{addr}");
|
||||
|
||||
let resp = reqwest::Client::new()
|
||||
.post(format!("{url}/v1/responses"))
|
||||
.json(&json!({
|
||||
"model": "anything",
|
||||
"input": "hi",
|
||||
"previous_response_id": "resp_prev_42"
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 400);
|
||||
let body: serde_json::Value = resp.json().await.unwrap();
|
||||
assert_eq!(body["code"], "chained_conversation_not_supported");
|
||||
}
|
||||
|
||||
/// `/v1/responses` returns 404 when the model isn't loaded — same
|
||||
/// surface as chat completions.
|
||||
#[tokio::test]
|
||||
async fn test_responses_model_not_loaded() {
|
||||
use cortex_core::harness::HarnessConfig;
|
||||
use neuron::config::HarnessSettings;
|
||||
|
||||
let registry = HarnessRegistry::from_configs(
|
||||
&[HarnessConfig {
|
||||
name: "candle".into(),
|
||||
}],
|
||||
"http://localhost:0",
|
||||
&HarnessSettings::default(),
|
||||
);
|
||||
let candle = registry.candle();
|
||||
let health_cache = Arc::new(HealthCache::new());
|
||||
let state = Arc::new(NeuronState {
|
||||
discovery: fake_discovery(),
|
||||
health_cache,
|
||||
registry: RwLock::new(registry),
|
||||
candle,
|
||||
activation: Arc::new(ActivationTracker::new(&[])),
|
||||
});
|
||||
let app = api::neuron_routes().with_state(state);
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
tokio::spawn(async move {
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
});
|
||||
let url = format!("http://{addr}");
|
||||
|
||||
let resp = reqwest::Client::new()
|
||||
.post(format!("{url}/v1/responses"))
|
||||
.json(&json!({"model": "not-loaded", "input": "hi"}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 404);
|
||||
}
|
||||
|
||||
/// Same model-not-loaded surface on the streaming path. The
|
||||
/// stream is opened only after model lookup succeeds, so a
|
||||
/// missing model fails fast with a non-SSE 404 response.
|
||||
#[tokio::test]
|
||||
async fn test_responses_streaming_model_not_loaded() {
|
||||
use cortex_core::harness::HarnessConfig;
|
||||
use neuron::config::HarnessSettings;
|
||||
|
||||
let registry = HarnessRegistry::from_configs(
|
||||
&[HarnessConfig {
|
||||
name: "candle".into(),
|
||||
}],
|
||||
"http://localhost:0",
|
||||
&HarnessSettings::default(),
|
||||
);
|
||||
let candle = registry.candle();
|
||||
let health_cache = Arc::new(HealthCache::new());
|
||||
let state = Arc::new(NeuronState {
|
||||
discovery: fake_discovery(),
|
||||
health_cache,
|
||||
registry: RwLock::new(registry),
|
||||
candle,
|
||||
activation: Arc::new(ActivationTracker::new(&[])),
|
||||
});
|
||||
let app = api::neuron_routes().with_state(state);
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
tokio::spawn(async move {
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
});
|
||||
let url = format!("http://{addr}");
|
||||
|
||||
let resp = reqwest::Client::new()
|
||||
.post(format!("{url}/v1/responses"))
|
||||
.json(&json!({
|
||||
"model": "not-loaded",
|
||||
"input": "hi",
|
||||
"stream": true
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 404);
|
||||
}
|
||||
|
||||
87
crates/neuron/tests/candle_lifecycle.rs
Normal file
87
crates/neuron/tests/candle_lifecycle.rs
Normal file
@@ -0,0 +1,87 @@
|
||||
//! Real model load/unload lifecycle through the candle harness.
|
||||
//!
|
||||
//! Gated behind the `cuda-integration` feature because it downloads a
|
||||
//! real (small) GGUF from HuggingFace and materialises tensors on the
|
||||
//! configured device. Run on a host with network access and either a
|
||||
//! CUDA GPU (when built with `--features cuda`) or enough CPU RAM to
|
||||
//! hold the model.
|
||||
//!
|
||||
//! Usage:
|
||||
//! cargo test -p neuron --features cuda-integration --test candle_lifecycle
|
||||
//!
|
||||
//! Optional environment variables:
|
||||
//! NEURON_TEST_MODEL_ID — HuggingFace repo to load (default: a small
|
||||
//! public Qwen3 GGUF repo).
|
||||
//! NEURON_TEST_QUANT — quant substring matched against GGUF
|
||||
//! filenames (default: "Q4_K_M").
|
||||
//! HF_HOME — HuggingFace cache directory.
|
||||
|
||||
#![cfg(feature = "cuda-integration")]
|
||||
|
||||
use cortex_core::harness::{HarnessConfig, ModelSpec};
|
||||
use neuron::config::HarnessSettings;
|
||||
use neuron::harness::HarnessRegistry;
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_candle_qwen3_load_unload_lifecycle() {
|
||||
let _ = tracing_subscriber::fmt()
|
||||
.with_test_writer()
|
||||
.with_env_filter("info,neuron=debug")
|
||||
.try_init();
|
||||
|
||||
let model_id = std::env::var("NEURON_TEST_MODEL_ID")
|
||||
.unwrap_or_else(|_| "Qwen/Qwen3-0.6B-GGUF".to_string());
|
||||
let quant = std::env::var("NEURON_TEST_QUANT").unwrap_or_else(|_| "Q4_K_M".to_string());
|
||||
|
||||
let mut settings = HarnessSettings::default();
|
||||
if let Ok(home) = std::env::var("HF_HOME") {
|
||||
settings.candle.hf_cache = Some(PathBuf::from(home));
|
||||
}
|
||||
|
||||
let registry = HarnessRegistry::from_configs(
|
||||
&[HarnessConfig {
|
||||
name: "candle".into(),
|
||||
}],
|
||||
"http://localhost:13131",
|
||||
&settings,
|
||||
);
|
||||
|
||||
let spec = ModelSpec {
|
||||
model_id: model_id.clone(),
|
||||
harness: "candle".into(),
|
||||
quant: Some(quant),
|
||||
tensor_parallel: None,
|
||||
devices: Some(vec![0]),
|
||||
};
|
||||
|
||||
registry
|
||||
.load_model(&spec)
|
||||
.await
|
||||
.expect("load_model should succeed");
|
||||
|
||||
let models = registry.list_all_models().await.expect("list_all_models");
|
||||
assert_eq!(models.len(), 1, "expected exactly one loaded model");
|
||||
assert_eq!(models[0].id, model_id);
|
||||
assert_eq!(models[0].harness, "candle");
|
||||
assert_eq!(models[0].status, "loaded");
|
||||
|
||||
let url = registry.inference_endpoint(&model_id).await;
|
||||
assert_eq!(url, Some("http://localhost:13131".into()));
|
||||
|
||||
// Re-loading the same model should be rejected.
|
||||
let again = registry.load_model(&spec).await;
|
||||
assert!(again.is_err(), "second load should error");
|
||||
|
||||
registry
|
||||
.unload_model(&model_id)
|
||||
.await
|
||||
.expect("unload_model should succeed");
|
||||
|
||||
let models = registry.list_all_models().await.expect("list_all_models");
|
||||
assert!(models.is_empty(), "registry should be empty after unload");
|
||||
|
||||
// Unloading a model that isn't loaded should error.
|
||||
let err = registry.unload_model(&model_id).await;
|
||||
assert!(err.is_err(), "unload of missing model should error");
|
||||
}
|
||||
269
crates/neuron/tests/preflight.rs
Normal file
269
crates/neuron/tests/preflight.rs
Normal file
@@ -0,0 +1,269 @@
|
||||
//! End-to-end preflight tests against a mock HF-compatible server.
|
||||
//!
|
||||
//! Unit tests in `harness/preflight.rs` exercise the classifier and
|
||||
//! feasibility table against synthetic file lists. These tests close
|
||||
//! the loop: spawn an axum server that returns a `RepoInfo`-shaped
|
||||
//! JSON payload at `/api/models/{org}/{name}`, point `hf_hub::Api` at
|
||||
//! it, and assert `preflight()` returns the expected outcome.
|
||||
|
||||
use axum::Router;
|
||||
use axum::extract::Path;
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::{IntoResponse, Json};
|
||||
use axum::routing::get;
|
||||
use cortex_core::harness::ModelSpec;
|
||||
use neuron::harness::preflight::{PreflightError, SourceFormat, preflight};
|
||||
use serde_json::{Value, json};
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
|
||||
/// Per-test mock state: a map from `{org}/{name}` to the JSON body the
|
||||
/// mock server returns at the corresponding `/api/models/{org}/{name}`
|
||||
/// endpoint. `None` means "respond 404".
|
||||
type MockBodies = Arc<Mutex<std::collections::HashMap<String, Option<Value>>>>;
|
||||
|
||||
async fn spawn_mock(bodies: MockBodies) -> String {
|
||||
// hf-hub 0.4 calls /api/models/{org}/{name}/revision/main for
|
||||
// `repo.info()`. We route both shapes so the test stays robust
|
||||
// to a future hf-hub upgrade that drops the `/revision/main`
|
||||
// suffix.
|
||||
let app = Router::new()
|
||||
.route("/api/models/{org}/{name}", get(model_info))
|
||||
.route(
|
||||
"/api/models/{org}/{name}/revision/{rev}",
|
||||
get(model_info_rev),
|
||||
)
|
||||
.with_state(bodies);
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
tokio::spawn(async move {
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
});
|
||||
format!("http://{addr}")
|
||||
}
|
||||
|
||||
async fn model_info(
|
||||
Path((org, name)): Path<(String, String)>,
|
||||
axum::extract::State(bodies): axum::extract::State<MockBodies>,
|
||||
) -> impl IntoResponse {
|
||||
respond(&format!("{org}/{name}"), &bodies)
|
||||
}
|
||||
|
||||
async fn model_info_rev(
|
||||
Path((org, name, _rev)): Path<(String, String, String)>,
|
||||
axum::extract::State(bodies): axum::extract::State<MockBodies>,
|
||||
) -> impl IntoResponse {
|
||||
respond(&format!("{org}/{name}"), &bodies)
|
||||
}
|
||||
|
||||
fn respond(key: &str, bodies: &MockBodies) -> axum::response::Response {
|
||||
let entry = bodies.lock().unwrap().get(key).cloned();
|
||||
match entry {
|
||||
Some(Some(body)) => Json(body).into_response(),
|
||||
Some(None) | None => (StatusCode::NOT_FOUND, "not found").into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
fn build_api(endpoint: &str, cache_dir: &std::path::Path) -> hf_hub::api::tokio::Api {
|
||||
hf_hub::api::tokio::ApiBuilder::new()
|
||||
.with_endpoint(endpoint.to_string())
|
||||
.with_cache_dir(cache_dir.to_path_buf())
|
||||
.build()
|
||||
.expect("build hf-hub Api")
|
||||
}
|
||||
|
||||
fn siblings(filenames: &[&str]) -> Value {
|
||||
json!({
|
||||
"sha": "0000000000000000000000000000000000000000",
|
||||
"siblings": filenames.iter().map(|f| json!({ "rfilename": f })).collect::<Vec<_>>(),
|
||||
})
|
||||
}
|
||||
|
||||
fn spec(model_id: &str, tp: Option<u32>, quant: Option<&str>) -> ModelSpec {
|
||||
ModelSpec {
|
||||
model_id: model_id.into(),
|
||||
harness: "candle".into(),
|
||||
quant: quant.map(String::from),
|
||||
tensor_parallel: tp,
|
||||
devices: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn preflight_gguf_tp_rejected_over_http() {
|
||||
let cache = tempfile::tempdir().expect("tempdir");
|
||||
let bodies: MockBodies = Arc::new(Mutex::new(Default::default()));
|
||||
bodies.lock().unwrap().insert(
|
||||
"HauhauCS/Qwen3.6".to_string(),
|
||||
Some(siblings(&[
|
||||
"README.md",
|
||||
".gitattributes",
|
||||
"Qwen3.6-Q4_K_P.gguf",
|
||||
"Qwen3.6-Q6_K_P.gguf",
|
||||
"Qwen3.6-Q8_K_P.gguf",
|
||||
])),
|
||||
);
|
||||
let endpoint = spawn_mock(bodies).await;
|
||||
|
||||
let api = build_api(&endpoint, cache.path());
|
||||
let s = spec("HauhauCS/Qwen3.6", Some(2), Some("q6k"));
|
||||
let err = preflight(&api, &s).await.unwrap_err();
|
||||
match err {
|
||||
PreflightError::TpRequiresSafetensors {
|
||||
model_id,
|
||||
tp_size,
|
||||
gguf_quants,
|
||||
..
|
||||
} => {
|
||||
assert_eq!(model_id, "HauhauCS/Qwen3.6");
|
||||
assert_eq!(tp_size, 2);
|
||||
assert_eq!(gguf_quants.len(), 3);
|
||||
}
|
||||
other => panic!("expected TpRequiresSafetensors, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn preflight_gguf_quant_suggestion_over_http() {
|
||||
let cache = tempfile::tempdir().expect("tempdir");
|
||||
let bodies: MockBodies = Arc::new(Mutex::new(Default::default()));
|
||||
bodies.lock().unwrap().insert(
|
||||
"HauhauCS/Qwen3.6".to_string(),
|
||||
Some(siblings(&[
|
||||
"Qwen3.6-Q4_K_P.gguf",
|
||||
"Qwen3.6-Q5_K_P.gguf",
|
||||
"Qwen3.6-Q6_K_P.gguf",
|
||||
"Qwen3.6-Q8_K_P.gguf",
|
||||
])),
|
||||
);
|
||||
let endpoint = spawn_mock(bodies).await;
|
||||
|
||||
let api = build_api(&endpoint, cache.path());
|
||||
let s = spec("HauhauCS/Qwen3.6", Some(1), Some("q6k"));
|
||||
let err = preflight(&api, &s).await.unwrap_err();
|
||||
match err {
|
||||
PreflightError::QuantNotFound {
|
||||
requested,
|
||||
nearest,
|
||||
available,
|
||||
..
|
||||
} => {
|
||||
assert_eq!(requested, "q6k");
|
||||
assert_eq!(nearest.as_deref(), Some("q6_k_p"));
|
||||
assert_eq!(available.len(), 4);
|
||||
}
|
||||
other => panic!("expected QuantNotFound, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn preflight_dense_safetensors_tp_ok() {
|
||||
let cache = tempfile::tempdir().expect("tempdir");
|
||||
let bodies: MockBodies = Arc::new(Mutex::new(Default::default()));
|
||||
bodies.lock().unwrap().insert(
|
||||
"Qwen/Q3-30B".to_string(),
|
||||
Some(siblings(&[
|
||||
"config.json",
|
||||
"tokenizer.json",
|
||||
"tokenizer_config.json",
|
||||
"model.safetensors.index.json",
|
||||
"model-00001-of-00006.safetensors",
|
||||
"model-00002-of-00006.safetensors",
|
||||
"model-00003-of-00006.safetensors",
|
||||
])),
|
||||
);
|
||||
let endpoint = spawn_mock(bodies).await;
|
||||
|
||||
let api = build_api(&endpoint, cache.path());
|
||||
let s = spec("Qwen/Q3-30B", Some(2), Some("q5k"));
|
||||
let plan = preflight(&api, &s).await.expect("dense+tp should succeed");
|
||||
assert_eq!(plan.tp_size, 2);
|
||||
assert!(plan.picked_quant_file.is_none());
|
||||
assert!(matches!(
|
||||
plan.format,
|
||||
SourceFormat::DenseSafetensors { sharded: true }
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn preflight_gguf_single_gpu_good_quant() {
|
||||
let cache = tempfile::tempdir().expect("tempdir");
|
||||
let bodies: MockBodies = Arc::new(Mutex::new(Default::default()));
|
||||
bodies.lock().unwrap().insert(
|
||||
"HauhauCS/Qwen3.6".to_string(),
|
||||
Some(siblings(&["Qwen3.6-Q4_K_P.gguf", "Qwen3.6-Q6_K_P.gguf"])),
|
||||
);
|
||||
let endpoint = spawn_mock(bodies).await;
|
||||
|
||||
let api = build_api(&endpoint, cache.path());
|
||||
let s = spec("HauhauCS/Qwen3.6", Some(1), Some("q6_k_p"));
|
||||
let plan = preflight(&api, &s)
|
||||
.await
|
||||
.expect("good quant should succeed");
|
||||
assert_eq!(plan.tp_size, 1);
|
||||
assert_eq!(
|
||||
plan.picked_quant_file.as_deref(),
|
||||
Some("Qwen3.6-Q6_K_P.gguf")
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn preflight_repo_fetch_failed_on_404() {
|
||||
// Mock server has no entry for this id → 404, exercising the
|
||||
// RepoFetchFailed path (the same shape today's HauhauCS scenario
|
||||
// would have produced if we'd added preflight before the cache
|
||||
// download was attempted).
|
||||
let cache = tempfile::tempdir().expect("tempdir");
|
||||
let bodies: MockBodies = Arc::new(Mutex::new(Default::default()));
|
||||
let endpoint = spawn_mock(bodies).await;
|
||||
|
||||
let api = build_api(&endpoint, cache.path());
|
||||
let s = spec("DoesNot/Exist", Some(1), None);
|
||||
let err = preflight(&api, &s).await.unwrap_err();
|
||||
assert!(
|
||||
matches!(err, PreflightError::RepoFetchFailed { .. }),
|
||||
"expected RepoFetchFailed, got {err:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn preflight_empty_repo_rejected() {
|
||||
let cache = tempfile::tempdir().expect("tempdir");
|
||||
let bodies: MockBodies = Arc::new(Mutex::new(Default::default()));
|
||||
bodies.lock().unwrap().insert(
|
||||
"Empty/Repo".to_string(),
|
||||
Some(siblings(&["README.md", "tokenizer.json"])),
|
||||
);
|
||||
let endpoint = spawn_mock(bodies).await;
|
||||
|
||||
let api = build_api(&endpoint, cache.path());
|
||||
let s = spec("Empty/Repo", Some(1), None);
|
||||
let err = preflight(&api, &s).await.unwrap_err();
|
||||
assert!(
|
||||
matches!(err, PreflightError::EmptyRepo { .. }),
|
||||
"expected EmptyRepo, got {err:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn preflight_mixed_repo_prefers_safetensors() {
|
||||
let cache = tempfile::tempdir().expect("tempdir");
|
||||
let bodies: MockBodies = Arc::new(Mutex::new(Default::default()));
|
||||
bodies.lock().unwrap().insert(
|
||||
"Mixed/Repo".to_string(),
|
||||
Some(siblings(&[
|
||||
"config.json",
|
||||
"tokenizer.json",
|
||||
"model.safetensors",
|
||||
"model-Q4_K_M.gguf",
|
||||
])),
|
||||
);
|
||||
let endpoint = spawn_mock(bodies).await;
|
||||
|
||||
let api = build_api(&endpoint, cache.path());
|
||||
// TP=2 + quant should succeed via the dense path even though a
|
||||
// GGUF is present — the dense path handles ISQ.
|
||||
let s = spec("Mixed/Repo", Some(2), Some("q5k"));
|
||||
let plan = preflight(&api, &s).await.expect("mixed should succeed");
|
||||
assert!(matches!(plan.format, SourceFormat::Mixed { .. }));
|
||||
}
|
||||
32
crates/neuron/tests/shutdown.rs
Normal file
32
crates/neuron/tests/shutdown.rs
Normal file
@@ -0,0 +1,32 @@
|
||||
//! Deactivation behaviour: unload_all_models tolerates an empty
|
||||
//! registry and continues past per-model unload failures.
|
||||
|
||||
use cortex_core::harness::HarnessConfig;
|
||||
use neuron::config::HarnessSettings;
|
||||
use neuron::harness::HarnessRegistry;
|
||||
use neuron::startup;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_unload_all_models_empty_registry_is_noop() {
|
||||
let registry = HarnessRegistry::new();
|
||||
startup::unload_all_models(®istry).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_unload_all_models_with_no_loaded_models() {
|
||||
let registry = HarnessRegistry::from_configs(
|
||||
&[HarnessConfig {
|
||||
name: "candle".into(),
|
||||
}],
|
||||
"http://localhost:0",
|
||||
&HarnessSettings::default(),
|
||||
);
|
||||
|
||||
startup::unload_all_models(®istry).await;
|
||||
|
||||
let listed = registry
|
||||
.list_all_models()
|
||||
.await
|
||||
.expect("list_all_models should still succeed after shutdown cleanup");
|
||||
assert!(listed.is_empty());
|
||||
}
|
||||
148
crates/neuron/tests/tp_worker_lifecycle.rs
Normal file
148
crates/neuron/tests/tp_worker_lifecycle.rs
Normal file
@@ -0,0 +1,148 @@
|
||||
//! Stage 7a-i: confirm the TP worker subprocess lifecycle round-trips.
|
||||
//!
|
||||
//! Spawns two worker subprocesses via the leader→worker stdio RPC,
|
||||
//! pings each, and cleanly shuts them down. No CUDA required —
|
||||
//! `Init` and `NcclSanityCheck` are stubbed in 7a-i, so this test
|
||||
//! runs on any host the workspace builds on.
|
||||
|
||||
use neuron::harness::device_worker::DeviceWorkerHandle;
|
||||
use neuron::harness::tp::{WorkerPool, rpc::WorkerResponse};
|
||||
|
||||
/// Path to the neuron binary built by cargo for this test process.
|
||||
/// cargo populates `CARGO_BIN_EXE_neuron` at compile time for sibling-
|
||||
/// binary tests; production paths in main.rs use `/proc/self/exe`.
|
||||
const NEURON_BIN: &str = env!("CARGO_BIN_EXE_neuron");
|
||||
|
||||
/// Two workers (so we spawn one subprocess: rank 0 is in-process,
|
||||
/// rank 1 is the child). Verify the spawned worker responds to Ping
|
||||
/// with its own identity, then shut it down cleanly.
|
||||
#[tokio::test]
|
||||
async fn test_spawn_ping_shutdown() {
|
||||
// cuda_devices: rank 0 → device 0 (leader, unused here),
|
||||
// rank 1 → device 1 (worker; not actually opened in 7a-i).
|
||||
let leader_worker = DeviceWorkerHandle::spawn(0).expect("spawn device worker");
|
||||
let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 2, &[0, 1], leader_worker)
|
||||
.await
|
||||
.expect("spawn worker pool");
|
||||
|
||||
let pongs = pool.ping_all().await.expect("ping all workers");
|
||||
assert_eq!(pongs.len(), 1, "expected one Pong (rank 1 only)");
|
||||
match &pongs[0] {
|
||||
WorkerResponse::Pong {
|
||||
rank,
|
||||
world_size,
|
||||
cuda_device,
|
||||
} => {
|
||||
assert_eq!(*rank, 1);
|
||||
assert_eq!(*world_size, 2);
|
||||
assert_eq!(*cuda_device, 1);
|
||||
}
|
||||
other => panic!("expected Pong, got {other:?}"),
|
||||
}
|
||||
|
||||
pool.shutdown().await.expect("clean shutdown");
|
||||
}
|
||||
|
||||
/// Three workers — exercise the loop in `ping_all` / `shutdown`.
|
||||
#[tokio::test]
|
||||
async fn test_spawn_three_workers() {
|
||||
let leader_worker = DeviceWorkerHandle::spawn(0).expect("spawn device worker");
|
||||
let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 3, &[0, 1, 2], leader_worker)
|
||||
.await
|
||||
.expect("spawn worker pool");
|
||||
|
||||
let pongs = pool.ping_all().await.expect("ping all workers");
|
||||
assert_eq!(pongs.len(), 2, "expected two Pongs (ranks 1 and 2)");
|
||||
for (i, resp) in pongs.iter().enumerate() {
|
||||
match resp {
|
||||
WorkerResponse::Pong {
|
||||
rank,
|
||||
world_size,
|
||||
cuda_device,
|
||||
} => {
|
||||
let expected_rank = (i + 1) as u32;
|
||||
assert_eq!(*rank, expected_rank);
|
||||
assert_eq!(*world_size, 3);
|
||||
assert_eq!(*cuda_device, expected_rank);
|
||||
}
|
||||
other => panic!("expected Pong, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
pool.shutdown().await.expect("clean shutdown");
|
||||
}
|
||||
|
||||
/// 7a-ii: without the cuda feature, Init must fail with a clear
|
||||
/// `cuda_feature_not_enabled` marker rather than silently succeeding.
|
||||
/// This is the local-dev-box test; the real NCCL handshake is exercised
|
||||
/// by `tp_worker_lifecycle_cuda.rs` (gated on `cuda-integration`).
|
||||
#[tokio::test]
|
||||
async fn test_init_returns_cuda_feature_not_enabled_without_cuda() {
|
||||
use neuron::harness::tp::rpc::WorkerRequest;
|
||||
use std::process::Stdio;
|
||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||
use tokio::process::Command;
|
||||
|
||||
// Spawn a single worker by hand to send Init directly (the pool's
|
||||
// public API doesn't expose Init yet — that lands in 7a-ii).
|
||||
let mut child = Command::new(NEURON_BIN)
|
||||
.arg("--worker")
|
||||
.arg("--rank")
|
||||
.arg("1")
|
||||
.arg("--tp-size")
|
||||
.arg("2")
|
||||
.arg("--cuda-device")
|
||||
.arg("1")
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::null())
|
||||
.kill_on_drop(true)
|
||||
.spawn()
|
||||
.expect("spawn worker");
|
||||
|
||||
let mut stdin = child.stdin.take().expect("stdin");
|
||||
let stdout = child.stdout.take().expect("stdout");
|
||||
let mut lines = BufReader::new(stdout).lines();
|
||||
|
||||
let req = WorkerRequest::Init {
|
||||
comm_id: "ff".repeat(128),
|
||||
};
|
||||
let mut payload = serde_json::to_string(&req).unwrap();
|
||||
payload.push('\n');
|
||||
stdin.write_all(payload.as_bytes()).await.unwrap();
|
||||
stdin.flush().await.unwrap();
|
||||
|
||||
let reply = lines
|
||||
.next_line()
|
||||
.await
|
||||
.expect("read line")
|
||||
.expect("got line");
|
||||
let resp: WorkerResponse = serde_json::from_str(&reply).expect("parse reply");
|
||||
match resp {
|
||||
WorkerResponse::Error { kind, .. } => {
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
// With cuda enabled the response depends on whether
|
||||
// CUDA hardware is actually present. Accept either
|
||||
// the success contract or a real NCCL failure.
|
||||
let _ = kind;
|
||||
}
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
assert_eq!(kind, "cuda_feature_not_enabled");
|
||||
}
|
||||
WorkerResponse::InitOk => {
|
||||
// Real NCCL succeeded — only possible with cuda feature
|
||||
// AND a working NCCL stack AND another rank actually
|
||||
// joining. Don't fail; just acknowledge.
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
panic!("InitOk without cuda feature is impossible");
|
||||
}
|
||||
other => panic!("expected Error or InitOk, got {other:?}"),
|
||||
}
|
||||
|
||||
// Clean shutdown.
|
||||
stdin.write_all(b"{\"op\":\"shutdown\"}\n").await.unwrap();
|
||||
stdin.flush().await.unwrap();
|
||||
let _ = lines.next_line().await; // Bye
|
||||
let _ = child.wait().await;
|
||||
}
|
||||
45
crates/neuron/tests/tp_worker_lifecycle_cuda.rs
Normal file
45
crates/neuron/tests/tp_worker_lifecycle_cuda.rs
Normal file
@@ -0,0 +1,45 @@
|
||||
//! Stage 7a-ii: real NCCL handshake across the worker pool.
|
||||
//!
|
||||
//! Gated behind the `cuda-integration` feature because it requires
|
||||
//! libnccl AND multiple CUDA devices on the running host. Run on
|
||||
//! beast (2× RTX 5090) via:
|
||||
//!
|
||||
//! cargo test -p neuron --features cuda-integration \
|
||||
//! --test tp_worker_lifecycle_cuda
|
||||
//!
|
||||
//! Steps: spawn N-1 workers, call `init_nccl`, run `nccl_sanity_check`
|
||||
//! (every rank `all_reduce`s `1u32` with Sum; expected total =
|
||||
//! world_size), shut down cleanly.
|
||||
|
||||
#![cfg(feature = "cuda-integration")]
|
||||
|
||||
use neuron::harness::tp::WorkerPool;
|
||||
|
||||
const NEURON_BIN: &str = env!("CARGO_BIN_EXE_neuron");
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_init_and_sanity_check_two_ranks() {
|
||||
let _ = tracing_subscriber::fmt()
|
||||
.with_test_writer()
|
||||
.with_env_filter("info,neuron=debug")
|
||||
.try_init();
|
||||
|
||||
// 2 ranks: leader = rank 0 on device 0, worker = rank 1 on device 1.
|
||||
let leader_worker = neuron::harness::device_worker::DeviceWorkerHandle::spawn(0)
|
||||
.expect("spawn leader device worker");
|
||||
let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 2, &[0, 1], leader_worker)
|
||||
.await
|
||||
.expect("spawn worker pool");
|
||||
|
||||
pool.ping_all().await.expect("pong all workers");
|
||||
|
||||
pool.init_nccl(0)
|
||||
.await
|
||||
.expect("init_nccl: NCCL handshake across all ranks");
|
||||
|
||||
pool.nccl_sanity_check()
|
||||
.await
|
||||
.expect("nccl_sanity_check: observed_sum == world_size on all ranks");
|
||||
|
||||
pool.shutdown().await.expect("clean shutdown");
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user