1 Commits

Author SHA1 Message Date
123f692203 fix(rpm): drop %attr(,,user) on config files to avoid dnf silent filter
Some checks failed
CI / Build cortex SRPM (push) Has been cancelled
CI / Build neuron SRPM (push) Has been cancelled
CI / Publish cortex to COPR (push) Has been cancelled
CI / Publish neuron to COPR (push) Has been cancelled
CI / Bump version in source (push) Has been cancelled
CI / Format, lint, build, test (push) Has been cancelled
Using %attr(,,cortex) / %attr(,,neuron) on config files caused rpm's
auto-dep-generator to emit Requires: user(name) and group(name) on
each package. When those Requires couldn't be resolved — whether due
to sysusers Provides mismatches, missing GPG keys, or dnf5 cache
state — dnf5 silently filtered the package out of the candidate set
and reported "Nothing to do" rather than an unsatisfied-dep error.

Adopt the pattern that already works reliably across our infra
(grenade/monsoon): ship config files as default root:root with 0644
perms, don't declare user/group ownership in the rpm file list.
systemd-sysusers still creates the service user via the shipped
sysusers.d file; the service drops to that user at runtime via the
User= directive in the unit.

This removes the user(cortex)/user(neuron) Requires entirely, which
is the root cause of the dnf5 filtering. File permission tightening
can be reintroduced later — either via a separate secrets file with
different mode bits, or by moving secret material to /var/lib/<svc>/
where the service drop-privileges account already has write access.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-16 14:33:08 +03:00
115 changed files with 629 additions and 35070 deletions

View File

@@ -1,343 +0,0 @@
name: build-prerelease
# Manually-dispatched workflow that builds CUDA-flavoured neuron binaries
# (and a single cortex binary), packages each as a Fedora RPM, signs
# them, and publishes to the `unstable` channel at rpm.lair.cafe.
#
# Trigger from the Gitea UI: Actions → build-prerelease → Run workflow.
# Optionally provide a `ref` to build from a non-default branch.
#
# The published packages are versioned as e.g.
# helexa-neuron-blackwell-0.1.16-0.1.20260518T140530.gitabcdef0.fc43.x86_64
# ^^^^^^^^^^^^^^^^^^ ^^^^^^^^
# commit time (s) commit sha
# so they sort BELOW the eventual 0.1.16-1 stable release, and so two
# commits on the same day are still strictly ordered by their commit
# timestamps (rather than by RPM-vercmp's alpha-vs-digit precedence
# on the SHA fragment).
on:
# Auto-build on every push to main so the unstable channel tracks
# head without a manual dispatch step.
push:
branches: [main]
# Manual dispatch still available to build from a non-main ref.
workflow_dispatch:
inputs:
ref:
description: "Git ref to build (branch / tag / commit). Defaults to the workflow's branch."
required: false
default: ""
concurrency:
# Share the group with ci.yml so the two workflows can't run
# concurrently on the same `rust` runner (act reuses the workspace
# cache and races destroy each other's build files mid-compile).
# cancel-in-progress=false → workflows queue; if a newer push lands,
# the older run is still picked up by ci.yml's own ref-keyed
# concurrency (same group, queued).
group: cortex-runner-pool-${{ github.ref }}
cancel-in-progress: false
env:
CARGO_INCREMENTAL: "0"
CARGO_TERM_COLOR: "always"
jobs:
prepare:
name: Resolve version stamps
runs-on: rust
outputs:
version: ${{ steps.info.outputs.version }}
release: ${{ steps.info.outputs.release }}
short_sha: ${{ steps.info.outputs.short_sha }}
commit_timestamp: ${{ steps.info.outputs.commit_timestamp }}
steps:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
fetch-depth: 0
- id: info
run: |
set -eux
VERSION=$(awk -F\" '/^version[[:space:]]*=/ { print $2; exit }' Cargo.toml)
SHORT_SHA=$(git rev-parse --short=7 HEAD)
# Second-precise commit timestamp gives the release stamp a
# strictly monotonic numeric prefix. The earlier %Y%m%d-only
# form let same-day builds be ordered by RPM's rpmvercmp
# rules over the SHA, which is non-chronological — e.g.
# "git602e8e1" sorts newer than "gitf9f5fa4" purely because
# rpmvercmp ranks digit-prefixed segments above alpha ones.
# The SHA stays only as a debug identifier; sort order is
# decided entirely by the timestamp.
COMMIT_TIMESTAMP=$(git log -1 --format=%cd --date=format:%Y%m%d%H%M%S HEAD)
RELEASE="0.1.${COMMIT_TIMESTAMP}.git${SHORT_SHA}"
echo "version=${VERSION}" >> "$GITHUB_OUTPUT"
echo "release=${RELEASE}" >> "$GITHUB_OUTPUT"
echo "short_sha=${SHORT_SHA}" >> "$GITHUB_OUTPUT"
echo "commit_timestamp=${COMMIT_TIMESTAMP}" >> "$GITHUB_OUTPUT"
build-cortex:
name: Build cortex binary
needs: prepare
# runner-rust image already provides rust/cargo/clippy/rustfmt via
# dnf — no rustup install step needed.
runs-on: rust
steps:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
- name: Build cortex (release)
run: cargo build --release -p cortex-cli
- name: Stage binary
run: |
mkdir --parents artifacts
cp target/release/cortex artifacts/cortex
./artifacts/cortex --version || true
- uses: actions/upload-artifact@v3
with:
name: cortex-fc43
path: artifacts/cortex
retention-days: 1
build-neuron:
name: Build neuron-${{ matrix.flavour }}
needs: prepare
strategy:
fail-fast: false
matrix:
include:
- flavour: ampere
compute_cap: "86"
runner: cuda-13.0
cuda_home: /usr/local/cuda-13.0
build_jobs: 8
nvcc_threads: 4
cargo_features: "cuda cudnn flash-attn"
- flavour: ada
compute_cap: "89"
runner: cuda-13.0
cuda_home: /usr/local/cuda-13.0
build_jobs: 8
nvcc_threads: 4
cargo_features: "cuda cudnn flash-attn"
- flavour: blackwell
compute_cap: "120"
runner: cuda-13.0
cuda_home: /usr/local/cuda-13.0
build_jobs: 8
nvcc_threads: 4
cargo_features: "cuda cudnn flash-attn"
runs-on: ${{ matrix.runner }}
steps:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
- name: Build neuron with CUDA (${{ matrix.flavour }})
run: |
set -eux
export PATH="${{ matrix.cuda_home }}/bin:${PATH}"
export LD_LIBRARY_PATH="${{ matrix.cuda_home }}/targets/x86_64-linux/lib:${{ matrix.cuda_home }}/lib64:${LD_LIBRARY_PATH:-}"
export LIBRARY_PATH="${{ matrix.cuda_home }}/targets/x86_64-linux/lib:${{ matrix.cuda_home }}/lib64:${LIBRARY_PATH:-}"
cargo build --release -p neuron --features "${{ matrix.cargo_features }}"
env:
CUDA_COMPUTE_CAP: ${{ matrix.compute_cap }}
CARGO_BUILD_JOBS: ${{ matrix.build_jobs }}
NVCC_THREADS: ${{ matrix.nvcc_threads }}
- name: Stage binary
run: |
mkdir --parents artifacts
cp target/release/neuron artifacts/neuron-${{ matrix.flavour }}
file "artifacts/neuron-${{ matrix.flavour }}"
- uses: actions/upload-artifact@v3
with:
name: neuron-${{ matrix.flavour }}-fc43
path: artifacts/neuron-${{ matrix.flavour }}
retention-days: 1
package-cortex:
name: Package cortex RPM
needs: [prepare, build-cortex]
runs-on: rpm
steps:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
- uses: actions/download-artifact@v3
with:
name: cortex-fc43
path: artifacts/
- name: Build RPM
run: |
set -eux
rm -f ~/.rpmmacros
rpmdev-setuptree
cp artifacts/cortex ~/rpmbuild/SOURCES/
cp data/cortex.service ~/rpmbuild/SOURCES/
cp data/cortex-sysusers.conf ~/rpmbuild/SOURCES/
cp data/cortex-firewalld.xml ~/rpmbuild/SOURCES/
cp cortex.example.toml ~/rpmbuild/SOURCES/
cp models.example.toml ~/rpmbuild/SOURCES/
cp LICENSE ~/rpmbuild/SOURCES/
rpmbuild -bb rpm/cortex-prerelease.spec \
--define "cortex_version ${{ needs.prepare.outputs.version }}" \
--define "cortex_prerelease ${{ needs.prepare.outputs.release }}" \
--undefine dist \
--define "dist .fc43"
- uses: actions/upload-artifact@v3
with:
name: rpm-cortex-fc43
path: ~/rpmbuild/RPMS/x86_64/*.rpm
retention-days: 7
package-neuron:
name: Package helexa-neuron-${{ matrix.flavour }} RPM
needs: [prepare, build-neuron]
runs-on: rpm
strategy:
fail-fast: false
matrix:
include:
- flavour: ampere
- flavour: ada
- flavour: blackwell
steps:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
- uses: actions/download-artifact@v3
with:
name: neuron-${{ matrix.flavour }}-fc43
path: artifacts/
- name: Build RPM
run: |
set -eux
rm -f ~/.rpmmacros
rpmdev-setuptree
cp artifacts/neuron-${{ matrix.flavour }} ~/rpmbuild/SOURCES/
cp data/neuron.service ~/rpmbuild/SOURCES/
cp data/neuron-sysusers.conf ~/rpmbuild/SOURCES/
cp data/neuron-firewalld.xml ~/rpmbuild/SOURCES/
cp neuron.example.toml ~/rpmbuild/SOURCES/
cp LICENSE ~/rpmbuild/SOURCES/
rpmbuild -bb rpm/helexa-neuron-prerelease.spec \
--define "neuron_version ${{ needs.prepare.outputs.version }}" \
--define "neuron_flavour ${{ matrix.flavour }}" \
--define "neuron_prerelease ${{ needs.prepare.outputs.release }}" \
--undefine dist \
--define "dist .fc43"
- uses: actions/upload-artifact@v3
with:
name: rpm-neuron-${{ matrix.flavour }}-fc43
path: ~/rpmbuild/RPMS/x86_64/*.rpm
retention-days: 7
publish:
name: Publish to rpm.lair.cafe (unstable)
needs: [package-cortex, package-neuron]
runs-on: rpm
concurrency:
group: rpm-publish
cancel-in-progress: false
env:
RPM_REPO_HOST: oolon.kosherinata.internal
FEDORA_VERSION: "43"
steps:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
- name: Download all built RPMs
uses: actions/download-artifact@v3
with:
path: rpms/
pattern: rpm-*-fc43
- name: Flatten RPM artifacts
run: |
set -eux
find rpms/ -name '*.rpm' -exec mv --target-directory=rpms/ {} +
find rpms/ -mindepth 1 -type d -empty -delete
ls -la rpms/
- name: Check for sequoia-sq
run: |
if ! command -v sq &> /dev/null; then
echo "ERROR: sequoia-sq is not installed. Install with: sudo dnf install sequoia-sq"
exit 1
fi
- name: Import signing key
env:
# Pass secrets via env so values stay out of the rendered shell
# script (which Gitea includes in step logs). Template
# expansion of ${{ secrets.X }} inside `run:` writes the literal
# value into the script and depends on Gitea's log masker to
# scrub it — fragile for multi-line keys.
RPM_SIGNING_KEY: ${{ secrets.RPM_SIGNING_KEY }}
RPM_SIGNING_KEY_ID: ${{ secrets.RPM_SIGNING_KEY_ID }}
run: |
echo "$RPM_SIGNING_KEY" | gpg --batch --import
fpr=$(gpg --batch --with-colons --list-keys "$RPM_SIGNING_KEY_ID" | awk -F: '/^fpr:/ { print $10; exit }')
echo "${fpr}:6:" | gpg --batch --import-ownertrust
sed "s/@GPG_NAME@/$RPM_SIGNING_KEY_ID/" rpm/rpmmacros > ~/.rpmmacros
- name: Sign RPMs
run: |
set -eux
for rpm in rpms/*.rpm; do
echo "signing ${rpm}..."
rpm --addsign "${rpm}"
done
- name: Set up SSH for rsync
run: |
install --directory --mode 700 ~/.ssh
echo "${RSYNC_SSH_KEY}" | install --mode 600 /dev/stdin ~/.ssh/id_ed25519
env:
RSYNC_SSH_KEY: ${{ secrets.RSYNC_SSH_KEY }}
- name: Test SSH connectivity
run: |
ssh -o StrictHostKeyChecking=accept-new "gitea_ci@${RPM_REPO_HOST}" exit
- name: Ensure unstable repo directory exists
run: |
ssh "gitea_ci@${RPM_REPO_HOST}" \
"mkdir --parents /var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable"
- name: Sync RPMs to unstable repo
run: |
rsync \
--archive \
--verbose \
--chmod D755,F644 \
rpms/*.rpm \
"gitea_ci@${RPM_REPO_HOST}:/var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable/"
- name: Update unstable repo metadata
run: |
ssh "gitea_ci@${RPM_REPO_HOST}" \
"cd /var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable && createrepo_c --update ."
- name: Generate packages.json manifest
run: |
scp script/generate-packages-json.py "gitea_ci@${RPM_REPO_HOST}:/tmp/"
ssh "gitea_ci@${RPM_REPO_HOST}" \
"python3 /tmp/generate-packages-json.py \
--repodata-dir /var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable/repodata \
--output /var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable/packages.json \
--base-url https://rpm.lair.cafe/fedora/${FEDORA_VERSION}/x86_64/unstable"

View File

@@ -7,16 +7,6 @@ on:
pull_request:
branches: [main]
# Share a concurrency group with build-prerelease.yml so the two
# workflows don't race on the same `rust` runner workspace (act's
# /root/.cache/act/<hash>/hostexecutor/ is shared across concurrent
# jobs and one job's checkout step nukes another's in-flight build
# files). cancel-in-progress=false → they queue; same-ref pushes
# coalesce per workflow via cancel-in-progress on each.
concurrency:
group: cortex-runner-pool-${{ github.ref }}
cancel-in-progress: false
env:
CARGO_INCREMENTAL: "0"
RUSTC_WRAPPER: sccache
@@ -26,139 +16,56 @@ env:
SCCACHE_S3_USE_SSL: "false"
AWS_ACCESS_KEY_ID: ${{ secrets.SCCACHE_S3_ACCESS_KEY }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.SCCACHE_S3_SECRET_KEY }}
# fmt, clippy, and test all run in parallel on the same `rust` runner
# and would otherwise share /root/.cache/act/<hash>/hostexecutor/target/,
# racing each other's cargo temp files (.tmpXXXXXX) and failing builds
# mid-compile. Give each job its own target directory so the invocations
# don't collide. sccache still backs the actual rustc cache, so the
# rebuild penalty is small.
CARGO_TARGET_DIR: target-${{ github.job }}
jobs:
fmt:
name: Format
runs-on: rust
check:
name: Format, lint, build, test
runs-on: fedora
steps:
- uses: actions/checkout@v4
- run: cargo fmt --check --all
clippy:
name: Clippy
runs-on: rust
steps:
- uses: actions/checkout@v4
# sccache occasionally fails with spurious race-condition errors;
# retrying the same invocation succeeds without code changes.
# Allow up to 3 attempts before declaring real failure.
- name: Clippy (with retry)
run: |
for attempt in 1 2 3; do
echo "::group::clippy attempt ${attempt}"
if cargo clippy --workspace -- -D warnings; then
echo "::endgroup::"
exit 0
fi
echo "::endgroup::"
echo "clippy failed on attempt ${attempt}"
if [ "${attempt}" -lt 3 ]; then
sleep 5
fi
done
echo "clippy failed after 3 attempts"
exit 1
- run: sccache --show-stats
- name: Cache cargo registry and target
uses: actions/cache@v4
with:
path: |
~/.cargo/bin
~/.cargo/registry/index
~/.cargo/registry/cache
~/.cargo/git/db
target
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
restore-keys: |
${{ runner.os }}-cargo-
test:
name: Test
runs-on: rust
steps:
- uses: actions/checkout@v4
# See the clippy job for why this is retried.
- name: Test (with retry)
- name: Ensure sccache with S3 support
env:
RUSTC_WRAPPER: ""
run: |
for attempt in 1 2 3; do
echo "::group::test attempt ${attempt}"
if cargo test --workspace; then
echo "::endgroup::"
exit 0
fi
echo "::endgroup::"
echo "test failed on attempt ${attempt}"
if [ "${attempt}" -lt 3 ]; then
sleep 5
fi
done
echo "test failed after 3 attempts"
exit 1
- run: sccache --show-stats
if sccache --version 2>/dev/null && sccache --show-stats 2>/dev/null; then
echo "sccache with S3 support already installed"
else
cargo install sccache --features s3 --locked
fi
# Type-check the CUDA-only code path. Borrow-check-only — we
# never run the tests here (the runner has no GPU). This catches
# the category of bug where a refactor compiles fine under the
# default feature set (which is what the `clippy` and `test` jobs
# exercise) but fails inside a `#[cfg(feature = "cuda")]` block.
# `runs-on: cuda-13.0` selects the runner that ships nvcc /
# cudarc's build prerequisites. The generic `rust` and `rpm`
# runners don't have them (the previous label `rpm` was tried
# first and tripped cudarc's `nvcc --version` build script —
# see commit history).
cuda-check:
name: CUDA type-check
runs-on: cuda-13.0
# The workflow-level env sets `RUSTC_WRAPPER: sccache` for the
# `rust` runner (where fmt/clippy/test live and sccache is
# installed). The `cuda-13.0` runner doesn't have sccache on
# PATH, so inheriting the wrapper makes cargo bail with
# `could not execute process `sccache rustc -vV` (never executed)`
# before borrow-check even starts. Clear it locally. Also clear
# SCCACHE_* so cargo doesn't try to contact the cache (the
# remote auth headers come from secrets that aren't present on
# this runner either). Lose the cache, keep the gate.
env:
RUSTC_WRAPPER: ""
SCCACHE_BUCKET: ""
SCCACHE_ENDPOINT: ""
SCCACHE_REGION: ""
SCCACHE_S3_USE_SSL: ""
AWS_ACCESS_KEY_ID: ""
AWS_SECRET_ACCESS_KEY: ""
steps:
- uses: actions/checkout@v4
- name: cargo check --features cuda (with retry)
run: |
# act launches the step shell without /etc/profile, so the
# gitea_runner user's inherited PATH lacks /usr/local/cuda-13.0/bin.
# cudarc's build.rs:157 shells out to `nvcc --version` (because
# the neuron crate enables cuda-version-from-build-system) and
# panics with ENOENT if nvcc isn't resolvable. build-prerelease.yml
# does the same export — keep them in sync.
export PATH="/usr/local/cuda-13.0/bin:${PATH}"
export LD_LIBRARY_PATH="/usr/local/cuda-13.0/targets/x86_64-linux/lib:/usr/local/cuda-13.0/lib64:${LD_LIBRARY_PATH:-}"
export LIBRARY_PATH="/usr/local/cuda-13.0/targets/x86_64-linux/lib:/usr/local/cuda-13.0/lib64:${LIBRARY_PATH:-}"
for attempt in 1 2 3; do
echo "::group::cuda-check attempt ${attempt}"
if cargo check -p neuron --features cuda --all-targets; then
echo "::endgroup::"
exit 0
fi
echo "::endgroup::"
echo "cuda-check failed on attempt ${attempt}"
if [ "${attempt}" -lt 3 ]; then
sleep 5
fi
done
echo "cuda-check failed after 3 attempts"
exit 1
- name: Check formatting
run: cargo fmt --check --all
- name: Clippy
run: cargo clippy --workspace -- -D warnings
- name: Test
run: cargo test --workspace
- name: Show sccache stats
run: sccache --show-stats
srpm-cortex:
name: Build cortex SRPM
runs-on: rpm
needs: [fmt, clippy, test, cuda-check]
runs-on: fedora
needs: check
if: startsWith(github.ref, 'refs/tags/v')
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Determine version
id: version
@@ -172,12 +79,6 @@ jobs:
sed -i '/\[workspace\.package\]/,/\[/{ s/^version = ".*"/version = "'"${VERSION}"'"/ }' Cargo.toml
sed -i "s/^Version:.*/Version: ${VERSION}/" cortex.spec
- name: Generate changelog entry
uses: https://git.lair.cafe/actions/rpm-changelog@v1
with:
spec: cortex.spec
version: ${{ steps.version.outputs.VERSION }}
- name: Generate source tarball
run: |
set -ex
@@ -212,13 +113,11 @@ jobs:
srpm-neuron:
name: Build neuron SRPM
runs-on: rpm
needs: [fmt, clippy, test, cuda-check]
runs-on: fedora
needs: check
if: startsWith(github.ref, 'refs/tags/v')
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Determine version
id: version
@@ -230,37 +129,31 @@ jobs:
run: |
VERSION="${{ steps.version.outputs.VERSION }}"
sed -i '/\[workspace\.package\]/,/\[/{ s/^version = ".*"/version = "'"${VERSION}"'"/ }' Cargo.toml
sed -i "s/^Version:.*/Version: ${VERSION}/" helexa-neuron.spec
- name: Generate changelog entry
uses: https://git.lair.cafe/actions/rpm-changelog@v1
with:
spec: helexa-neuron.spec
version: ${{ steps.version.outputs.VERSION }}
sed -i "s/^Version:.*/Version: ${VERSION}/" neuron.spec
- name: Generate source tarball
run: |
set -ex
VERSION="${{ steps.version.outputs.VERSION }}"
tar czf /tmp/helexa-neuron-${VERSION}.tar.gz \
--transform "s,^\.,helexa-neuron-${VERSION}," \
tar czf /tmp/neuron-${VERSION}.tar.gz \
--transform "s,^\.,neuron-${VERSION}," \
--exclude='./target' \
--exclude='./.git' \
--exclude='*.tar.gz' \
--exclude='*.src.rpm' \
.
mv /tmp/helexa-neuron-${VERSION}.tar.gz .
mv /tmp/neuron-${VERSION}.tar.gz .
- name: Vendor Rust dependencies
run: |
VERSION="${{ steps.version.outputs.VERSION }}"
cargo vendor vendor/
tar czf helexa-neuron-${VERSION}-vendor.tar.gz vendor/
tar czf neuron-${VERSION}-vendor.tar.gz vendor/
rm -rf vendor/
- name: Build SRPM
run: |
rpmbuild -bs helexa-neuron.spec \
rpmbuild -bs neuron.spec \
--define "_sourcedir $(pwd)" \
--define "_srcrpmdir $(pwd)"
@@ -272,7 +165,7 @@ jobs:
copr-cortex:
name: Publish cortex to COPR
runs-on: fedora-43
runs-on: fedora
needs: srpm-cortex
steps:
- name: Download SRPM
@@ -283,13 +176,13 @@ jobs:
- name: Publish to COPR
uses: https://git.lair.cafe/actions/copr-publish@v1
with:
project: helexa/helexa
project: helexa/cortex
srpm: "*.src.rpm"
copr-config: ${{ secrets.COPR_CONFIG }}
copr-neuron:
name: Publish neuron to COPR
runs-on: fedora-43
runs-on: fedora
needs: srpm-neuron
steps:
- name: Download SRPM
@@ -300,53 +193,31 @@ jobs:
- name: Publish to COPR
uses: https://git.lair.cafe/actions/copr-publish@v1
with:
project: helexa/helexa
project: helexa/neuron
srpm: "*.src.rpm"
copr-config: ${{ secrets.COPR_CONFIG }}
bump-version:
name: Bump version in source
runs-on: rust
runs-on: fedora
needs: [copr-cortex, copr-neuron]
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Determine version
id: version
run: echo "VERSION=${GITHUB_REF#refs/tags/v}" >> "$GITHUB_OUTPUT"
- name: Stamp version
run: |
VERSION="${{ steps.version.outputs.VERSION }}"
sed -i '/\[workspace\.package\]/,/\[/{ s/^version = ".*"/version = "'"${VERSION}"'"/ }' Cargo.toml
sed -i "s/^Version:.*/Version: ${VERSION}/" cortex.spec
sed -i "s/^Version:.*/Version: ${VERSION}/" helexa-neuron.spec
cargo check --workspace 2>/dev/null || true
- name: Generate cortex changelog entry
uses: https://git.lair.cafe/actions/rpm-changelog@v1
with:
spec: cortex.spec
version: ${{ steps.version.outputs.VERSION }}
- name: Generate helexa-neuron changelog entry
uses: https://git.lair.cafe/actions/rpm-changelog@v1
with:
spec: helexa-neuron.spec
version: ${{ steps.version.outputs.VERSION }}
- name: Commit and push
- name: Stamp version and push
env:
GITEA_TOKEN: ${{ secrets.GITEA_TOKEN }}
run: |
VERSION="${{ steps.version.outputs.VERSION }}"
VERSION="${GITHUB_REF#refs/tags/v}"
sed -i '/\[workspace\.package\]/,/\[/{ s/^version = ".*"/version = "'"${VERSION}"'"/ }' Cargo.toml
sed -i "s/^Version:.*/Version: ${VERSION}/" cortex.spec
sed -i "s/^Version:.*/Version: ${VERSION}/" neuron.spec
cargo check --workspace 2>/dev/null || true
git config user.name "Gitea Actions"
git config user.email "actions@git.lair.cafe"
git add Cargo.toml Cargo.lock cortex.spec helexa-neuron.spec
git add Cargo.toml Cargo.lock cortex.spec neuron.spec
if git diff --cached --quiet; then
echo "Nothing to commit for ${VERSION}"
echo "Version already at ${VERSION}"
else
git commit -m "chore: bump version to ${VERSION}"
git remote set-url origin "https://gitea-actions:${GITEA_TOKEN}@git.lair.cafe/helexa/cortex.git"

2
.gitignore vendored
View File

@@ -4,6 +4,4 @@
.idea/
.vscode/
cortex.toml
models.toml
doc/plan/*
/target-cuda/

212
CLAUDE.md
View File

@@ -84,63 +84,6 @@ Per-request: model, node, prompt_tokens, completion_tokens, total_tokens,
tok_per_sec, time_to_first_token_ms, total_latency_ms.
Exposed as Prometheus histograms/counters on a separate port.
### Per-device worker thread (neuron)
The neuron daemon dedicates one OS thread per CUDA device it loads
onto. That thread binds the device's `CudaContext` once at startup and
owns it for the daemon's lifetime; every model load, forward step,
KV-cache reset, VRAM query, NCCL init/sanity, NCCL all_reduce, and
model drop on that device routes through this thread via a
`std::sync::mpsc` job channel. Replies cross back via
`tokio::sync::oneshot`.
Three properties this gives us, in order of weight:
1. **Context locality.** cudarc binds the CUDA context per OS thread
via `cuCtxSetCurrent`. Before this refactor, ad-hoc
`tokio::task::spawn_blocking` calls bound the context onto a
different thread per request — and `device_vram_mb()` from an
async task bound it onto whichever tokio worker happened to be
running. Pinning the context to one named thread ends that.
2. **Drop safety.** Every `CudaSlice` in a `Tensor`, every
`cudarc::nccl::Comm`, and the `CudaContext` itself call `cuMemFree` /
`ncclCommDestroy` / `cuCtxDestroy` during `Drop` — and require the
right context current. With the worker owning the model slab,
`Drop` always runs on the right thread. The cudarc Drop constraint
is structurally enforced.
3. **Poisoning blast radius.** When a CUDA driver error makes the
context unrecoverable, the poison flag lives on the
`DeviceWorkerHandle` itself. Subsequent `submit()` calls fast-reject
at the channel boundary with a clear "device worker is poisoned"
error before any further CUDA work is attempted. The thread doesn't
exit (dropping the slab would re-touch the broken context) — it
enters a drain-only mode and replies error to everything until the
daemon restarts.
Tensors never escape the worker thread alive. Inference replies carry
`Vec<f32>` CPU-side logits; the async caller wraps them in a CPU
candle tensor and runs `apply_repeat_penalty` + `LogitsProcessor::sample`
without ever rebinding the device context. Sampled tokens come back as
`u32`; VRAM queries as `(u64, u64)`. The opaque `ArchHandle(u64)` and
`TpHandle(u64)` are the only "references" callers hold to loaded
models — they're indices into the worker's state slab, not pointers.
The TP worker subprocesses in `harness/tp/worker.rs` are the same
pattern out-of-process — a dedicated context-owning process per
non-zero NCCL rank. The in-process worker in `harness/device_worker/`
brings the discipline to rank 0.
CPU loads (`Device::Cpu` fallback when CUDA is unavailable) keep the
legacy `tokio::task::spawn_blocking + Arc<Mutex<ModelArch>>` path —
there's no context to own and the channel hop would only add latency.
Four `spawn_blocking` references in `harness/candle.rs` are deliberate
CPU fallback.
Canonical narrative lives in
`crates/neuron/src/harness/device_worker/mod.rs`'s module
doc-comment; touch points (the `Job` enum, the dispatch handlers, the
`DeviceWorkerState` struct) are in the sibling `jobs.rs` and
`dispatch.rs`.
## Tech stack
- **Rust 2024 edition** — workspace with 4 crates
@@ -182,8 +125,7 @@ automatically. Clippy warnings must be resolved, not suppressed with
- One or more GPU nodes running mistral.rs on port 8080
- Optionally a metrics-only node (no GPU) for Prometheus/Grafana
- Each node runs `mistralrs serve` on port 8080
- Gateway listens on port 31313 (API) and 31314 (metrics)
- neuron listens on port 13131 on each GPU host
- Gateway listens on port 8000 (API) and 9100 (metrics)
- TLS terminated at gateway or via nginx; internal traffic is plaintext over WireGuard
## Conventions
@@ -438,7 +380,7 @@ processes (one process per loaded model, each on its own port).
## neuron API
neuron exposes an HTTP API on port 13131 that cortex polls and calls.
neuron exposes an HTTP API on port 9090 that cortex polls and calls.
```
GET /discovery
@@ -482,8 +424,8 @@ endpoint. cortex.toml shrinks to:
```toml
[gateway]
listen = "0.0.0.0:31313"
metrics_listen = "0.0.0.0:31314"
listen = "0.0.0.0:8000"
metrics_listen = "0.0.0.0:9100"
[eviction]
strategy = "lru"
@@ -491,15 +433,15 @@ defrag_after_cycles = 50
[[neurons]]
name = "beast"
endpoint = "http://beast.hanzalova.internal:13131"
endpoint = "http://beast.hanzalova.internal:9090"
[[neurons]]
name = "benjy"
endpoint = "http://benjy.hanzalova.internal:13131"
endpoint = "http://benjy.kosherinata.internal:9090"
[[neurons]]
name = "quadbrat"
endpoint = "http://quadbrat.hanzalova.internal:13131"
endpoint = "http://quadbrat.hanzalova.internal:9090"
```
On startup and periodically, cortex calls `GET /discovery` and
@@ -579,7 +521,7 @@ cortex/
│ │ └── metrics.rs # prometheus exporter (unchanged)
│ ├── neuron/ # node plane (replaces cortex-agent)
│ │ └── src/
│ │ ├── main.rs # binary entrypoint, axum server on :13131
│ │ ├── main.rs # binary entrypoint, axum server on :9090
│ │ ├── discovery.rs # nvidia-smi, device enumeration
│ │ ├── health.rs # runtime GPU polling
│ │ ├── api.rs # HTTP handlers for /discovery, /models, etc.
@@ -653,104 +595,70 @@ placement matching can be added incrementally.
Completed. Both packages have RPM specs, systemd units, and example configs.
CI builds parallel SRPMs on tag push and publishes to separate COPR repos.
- `cortex.spec` — installs the `cortex` binary. Package name keeps the
short `cortex` because no Fedora package collides with it.
- `helexa-neuron.spec` — installs the `neuron` binary under package name
`helexa-neuron`. Renamed from bare `neuron` to avoid collision with
Fedora's NEURON neural-simulation package
(https://src.fedoraproject.org/rpms/neuron); binary, systemd unit,
system user, and config dir all stay named `neuron` since those are
project-local contexts.
- `cortex.spec` `helexa/cortex` COPR: binary, systemd unit, config files
- `neuron.spec``helexa/neuron` COPR: binary, systemd unit, config
- `data/cortex.service`, `data/neuron.service` — systemd units
- `cortex.example.toml`, `neuron.example.toml`, `models.example.toml`
- CI: parallel `srpm-cortex` + `srpm-neuron` jobs, then parallel COPR
publish to a single project `helexa/helexa` hosting both packages.
- CI: parallel `srpm-cortex` + `srpm-neuron` jobs, then parallel COPR publish
Install:
```sh
dnf copr enable helexa/helexa
dnf install cortex # gateway host
dnf install helexa-neuron # GPU nodes
dnf copr enable helexa/cortex && dnf install cortex # gateway host
dnf copr enable helexa/neuron && dnf install neuron # GPU nodes
```
## 2026-05-18 addendum: candle-native pivot
### Phase 11: llama.cpp harness stub
Phases 11 (llama.cpp harness) and 12 (mistral.rs COPR) below are
**superseded**. The project no longer treats mistral.rs or llama.cpp as
dependencies — both are conceptually out of scope. neuron becomes a
candle-native inference daemon, with `Harness` retained as an
internal seam for adding future engines (vision/audio/diffusion) but
its only implementation being in-process candle.
**Goal:** Prove the harness abstraction works with a second engine.
The full staged plan for this pivot lives at
`~/.claude/plans/create-a-more-aggressive-calm-naur.md`. Summary:
**Steps:**
1. `crates/neuron/src/harness/llamacpp.rs` — implement the `Harness`
trait for llama.cpp's `llama-server`.
- `start()` — launch `llama-server` with the correct model path,
`--port`, `--n-gpu-layers`, `--tensor-split` args. Track the
child process.
- `stop()` — send SIGTERM to the child process.
- `list_models()` — llama-server serves one model per process, so
return a single-element list.
- `load_model()` — start a new llama-server process for this model.
- `unload_model()` — stop the process.
- `inference_endpoint()` — return `http://localhost:{assigned_port}`.
2. Port allocation: neuron assigns ports from a range (e.g. 8100-8199)
to llama-server instances.
3. Register in `HarnessRegistry` when configured:
```toml
[[harnesses]]
name = "llamacpp"
binary = "/usr/local/bin/llama-server"
port_range = [8100, 8199]
```
4. Tests: mock llama-server (simple HTTP server returning canned
responses), test load/unload/endpoint lifecycle.
- **Stage 1 (this commit):** delete `mistralrs.rs` and `llamacpp.rs`,
scaffold inert `CandleHarness`, drop `endpoint`/`systemd_unit` from
`HarnessConfig`, default no-op `start`/`stop` on the `Harness` trait.
- **Stages 24:** wire up candle model load/unload (quantized Qwen3
first), add OpenAI-compatible inference endpoint in neuron, then SSE
streaming.
- **Stages 56:** load-on-activation (default models in config) and
unload-on-deactivation (graceful shutdown).
- **Stages 78:** multi-GPU tensor parallelism and broader model/quant
coverage.
**Done when:** A model with `harness = "llamacpp"` in `models.toml` can
be loaded and served through cortex. Tests pass with mock llama-server.
Sections of this document that describe mistral.rs HTTP behaviour
("mistral.rs API gotchas") are retained as historical context for
Phases 110 — they document what was true while the project depended
on mistral.rs. They do not describe current behaviour.
### Phase 12 (lower priority): mistral.rs COPR packaging
---
**Goal:** Fedora RPMs for mistral.rs built against specific CUDA versions.
### Phase 11 (superseded): llama.cpp harness stub
**Steps:**
1. `mistralrs-cuda.spec` — RPM spec that clones a pinned mistral.rs git
tag, builds with `--features cuda`, links against the system CUDA
toolkit. Produces `mistralrs-cuda13-server` (CUDA 13.x / sm_120) and
`mistralrs-cuda12-server` (CUDA 12.x / sm_89). Install binary to
`/usr/local/bin/mistralrs`.
2. COPR build config: enable the NVIDIA CUDA repo as a build dependency.
Pin the CUDA toolkit version in `BuildRequires`.
3. Gitea Actions or manual workflow: bump the mistral.rs tag in the spec,
trigger COPR rebuild.
4. neuron's mistralrs harness config references which binary/package
provides the mistral.rs binary. neuron could warn at startup if the
installed mistral.rs CUDA version doesn't match the discovered driver.
~~Originally planned as a second engine to prove the harness
abstraction.~~ Replaced by the candle harness work in the 2026-05-18
addendum above. llama.cpp's any-model/any-hardware breadth is no
longer in scope for helexa.
**Done when:** `dnf install mistralrs-cuda13-server` on beast provides a
working `mistralrs` binary built for Blackwell GPUs. `dnf install
mistralrs-cuda12-server` on benjy provides one built for Ada GPUs.
### Phase 12 (superseded): mistral.rs COPR packaging
~~Originally planned to ship CUDA-versioned mistral.rs RPMs.~~ Replaced
by the candle harness work in the 2026-05-18 addendum above. With
mistral.rs out of the dependency tree, there is nothing to package.
## 2026-05-27 addendum: per-device worker thread
Replaced the ad-hoc `tokio::task::spawn_blocking` pattern that drove
every leader-side CUDA op with one dedicated OS thread per CUDA device,
permanently bound to that device's `CudaContext`. All leader-side
inference work (GGUF + dense + TP shard load, forward, kv-cache clear,
NCCL init/sanity, NCCL all_reduce, VRAM query, model drop) routes
through the worker via a `std::sync::mpsc` channel; tensors never
escape the worker thread alive. See "Per-device worker thread (neuron)"
above and `crates/neuron/src/harness/device_worker/mod.rs` for the
canonical narrative.
Motivated by the 2026-05-26 silent-hang on beast: a CUDA OOM cascade
poisoned the device context on whichever spawn_blocking thread caught
it, and subsequent requests stalled invisibly on the pool lock. After
the refactor, the same failure mode shows up in journalctl as
`prefill sample failed; logits unhealthy nan: 248320/248320` followed
by `failed, model marked poisoned`. The thread stays alive and rejects
subsequent requests at the channel boundary.
Landed in four PRs:
- **Phase 1** (`081b532`) — device_worker module + 8 VRAM-query sites
route through the worker. CPU build only; smoke on beast confirmed
a persistent `cuda-dev-0` thread.
- **Phase 2** (`b179204`) — single-GPU forward + clear_kv + drop via
the worker. `LoadedModel.arch_handle: Option<ArchHandle>` replaces
`Arc<Mutex<ModelArch>>` for CUDA loads. CPU keeps the legacy path.
- **Phase 3** (`76ab24d`) — TP forward + NCCL init/sanity + leader
KV-clear routed through the worker. `WorkerPool.leader_nccl` moves
into the worker's state. `TpLoadedModel.leader_handle: TpHandle`
replaces `Arc<Mutex<TpLeaderModel>>`. CUDA-only TP smoke deferred to
next deploy.
- **Phase 4** (`b4f3576`) — GGUF + dense + TP shard loads move onto
the worker. The `Job::TransferIn` / `Job::CloneLeaderComm` bridges
from Phases 2/3 deleted; `SendComm` newtype no longer needed in the
load path. `grep -rn spawn_blocking crates/neuron/src/harness/`
returns only deliberate CPU-fallback hits after this PR.
This is a separate repo/spec — not part of the cortex workspace — but
tightly coupled operationally. Track it as a sibling project.

2418
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -5,11 +5,10 @@ members = [
"crates/cortex-gateway",
"crates/cortex-cli",
"crates/neuron",
"crates/helexa-acp",
]
[workspace.package]
version = "0.1.16"
version = "0.1.7"
edition = "2024"
license = "GPL-3.0-or-later"
repository = "https://git.lair.cafe/helexa/cortex"
@@ -28,7 +27,7 @@ serde = { version = "1", features = ["derive"] }
serde_json = "1"
toml = "0.8"
# http client (for proxying to neuron backends)
# http client (for proxying to mistralrs backends)
reqwest = { version = "0.12", features = ["json", "stream"] }
# observability

113
README.md
View File

@@ -1,23 +1,22 @@
# cortex
A Rust reverse-proxy and fleet management layer for multi-node GPU inference
clusters. Cortex sits in front of one or more `neuron` daemons (each running
candle-based inference on a local GPU host) and presents a unified OpenAI +
Anthropic compatible API surface.
A Rust reverse-proxy and fleet management layer for multi-node
[mistral.rs](https://github.com/EricLBuehler/mistral.rs) inference clusters.
## Problem
Running local LLMs across multiple GPU nodes (different VRAM tiers, different
model affinities) requires a unified API surface that:
- Presents a **single `/v1/models` catalogue** merging every model that can be
served by any neuron in the fleet.
- **Routes requests** to the correct node based on where a model is loaded
(or can be loaded), handling cold-load and eviction transparently.
- Manages **model lifecycle** load on demand, unload cold models, pin
critical ones — by calling each neuron's `/models/{load,unload}` API.
- Presents a **single `/v1/models` catalogue** merging every model across every
node.
- **Routes requests** to the correct node based on where a model is loaded (or
*can* be loaded).
- Manages **model lifecycle** — unload cold models, reload on demand, pin
critical ones — using the mistral.rs
`/v1/models/{unload,reload,status}` HTTP API (PR #1828+).
- Translates between **OpenAI and Anthropic** request/response envelopes so
every client speaks whichever dialect it prefers.
every client in the homelab speaks whichever dialect it prefers.
- Captures **per-request metrics** (tokens, tok/s, TTFT, latency) and exposes
them as Prometheus counters/histograms.
@@ -31,17 +30,18 @@ model affinities) requires a unified API surface that:
└────────────────┴──────┬───────┴───────────────┘
┌──────────▼──────────┐
cortex
│ (cortex-gateway) │
│ cortex │
(cortex-gateway)
│ │
│ Router · Metrics │
│ Evictor · Translate│
└──┬──────┬────────┬──┘
│ │ │
┌──────────▼┐ ┌──▼─────┐ ┌▼──────────┐
neuron │ │ neuron │ │ neuron
:13131 │ │ :13131 │ │ :13131
candle │ │ candle │ │ candle
gpu-large │ │gpu-med │ │ gpu-small
mistralrs │ │mistral │ │ mistralrs
serve │ │rs serve│ │ serve
│ :8080 │ │ :8080 │ │ :8080 │
└───────────┘ └────────┘ └───────────┘
private network (.internal)
```
@@ -50,58 +50,70 @@ model affinities) requires a unified API surface that:
| Crate | Purpose |
|---|---|
| `cortex-core` | Shared types: config, node/model state, metrics, OpenAI/Anthropic envelopes, harness trait, discovery types |
| `cortex-gateway` | Axum HTTP server: proxy, router, evictor, poller, metrics exporter |
| `neuron` | Per-node daemon: GPU discovery, in-process candle inference, model lifecycle API |
| `cortex-core` | Shared types: config, node/model state, metrics, OpenAI/Anthropic request/response envelopes |
| `cortex-gateway` | Axum HTTP server: proxy, router, evictor, metrics exporter |
| `cortex-agent` | Per-node sidecar: polls local mistralrs, reports to gateway, handles restart/defrag |
| `cortex-cli` | CLI entrypoint (`cortex serve`, `cortex status`, etc.) |
## Node setup
Each GPU node runs `neuron` (listening on `:13131`). Neuron uses
huggingface/candle for in-process inference — there is no external
inference subprocess to manage.
Each GPU node runs `mistralrs serve` with a multi-model config. Models are
declared but start **unloaded** — mistral.rs lazy-loads on first request and
the gateway can explicitly unload/reload via the HTTP API.
Inside the daemon, every CUDA device gets one dedicated OS thread
(named `cuda-dev-N`) that owns the device's CUDA context for the
daemon's lifetime. Model loads, forward passes, KV-cache resets,
NCCL collectives, VRAM queries, and unloads all route through that
thread via a job channel; tensors never escape it alive. This pins
context binding to a known thread, makes the CUDA Drop contract
structurally safe, and isolates driver-error poisoning to one worker
rather than the whole process. See `CLAUDE.md` for the design
rationale and `crates/neuron/src/harness/device_worker/` for the code.
Example node systemd unit:
The neuron RPM (`helexa-neuron`) ships a systemd unit:
```ini
# /etc/systemd/system/mistralrs.service
[Unit]
Description=mistral.rs inference server
After=network-online.target
Wants=network-online.target
```sh
dnf copr enable helexa/helexa
dnf install helexa-neuron
systemctl enable --now neuron
[Service]
Type=simple
ExecStart=/usr/local/bin/mistralrs serve \
--from-config /etc/mistralrs/config.toml \
--port 8080
Restart=on-failure
RestartSec=5
Environment=CUDA_VISIBLE_DEVICES=0,1
[Install]
WantedBy=multi-user.target
```
## Gateway config
```toml
# /etc/cortex/cortex.toml
# cortex.toml
[gateway]
listen = "0.0.0.0:31313"
metrics_listen = "0.0.0.0:31314"
listen = "0.0.0.0:8000"
metrics_listen = "0.0.0.0:9100"
[eviction]
strategy = "lru" # lru | priority
defrag_after_cycles = 50
[[neurons]]
name = "beast"
endpoint = "http://beast.internal:13131"
[[nodes]]
name = "gpu-large"
endpoint = "http://gpu-large.internal:8080"
vram_mb = 49_152 # e.g. 2x RTX 4090
pinned = ["your-org/large-model"]
[[neurons]]
name = "benjy"
endpoint = "http://benjy.internal:13131"
[[nodes]]
name = "gpu-medium"
endpoint = "http://gpu-medium.internal:8080"
vram_mb = 24_576 # e.g. RTX 4090
pinned = ["your-org/medium-model"]
[[nodes]]
name = "gpu-small"
endpoint = "http://gpu-small.internal:8080"
vram_mb = 12_288 # e.g. RTX 3060
pinned = ["your-org/embedding-model"]
```
Model placement profiles live in `models.toml` — see `models.example.toml`.
## Building
```sh
@@ -119,20 +131,19 @@ cargo clippy --workspace -- -D warnings # warnings are errors
cargo test --workspace # all tests must pass
```
Tagged releases (`v*`) additionally build SRPMs for both `cortex` and
`helexa-neuron` and publish to COPR.
Tagged releases (`v*`) additionally build an SRPM and publish to COPR.
## Running
```sh
# start the gateway
cortex serve --config /etc/cortex/cortex.toml
cortex serve --config cortex.toml
# check fleet status
cortex status
# list all models across nodes
curl http://localhost:31313/v1/models
curl http://localhost:8000/v1/models
```
## License

View File

@@ -1,30 +0,0 @@
# Helexa fleet manifest.
#
# Drives rolling deploys via script/deploy.sh and serves as the source
# of truth for which hosts run cortex vs neuron, and which CUDA
# compute-capability flavour each neuron host needs.
#
# Flavour ↔ NVIDIA generation ↔ compute cap:
# ampere sm_86 (RTX 30 series — e.g. 3060)
# ada sm_89 (RTX 40 series — e.g. 4090)
# blackwell sm_120 (RTX 50 series — e.g. 5090)
#
# The flavour determines which RPM is installed on a given neuron host:
# helexa-neuron-<flavour>. Only one flavour may be installed at a time
# (the packages Conflict: with each other).
cortex:
host: hanzalova.internal
neurons:
- host: beast.hanzalova.internal
flavour: blackwell
gpu: "2x RTX 5090"
- host: benjy.hanzalova.internal
flavour: ada
gpu: "RTX 4090"
- host: quadbrat.hanzalova.internal
flavour: ampere
gpu: "RTX 3060"

View File

@@ -1,24 +0,0 @@
# neuron.toml for beast.hanzalova.internal
#
# 2x RTX 5090 (32 GB each) — TP-2 capable. Pre-warms Qwen3.6-27B with
# q5k ISQ across both GPUs at activation, matching the validate-neuron
# invocation: `validate-neuron.sh beast.hanzalova.internal
# Qwen/Qwen3.6-27B q5k 2`.
#
# Synced by script/deploy.sh from asset/neuron/<short-host>.toml. Edits
# take effect on the next deploy.sh run (which stops + restarts the
# service so default_models is re-read at activation).
port = 13131
[[harnesses]]
name = "candle"
[harness.candle]
[[default_models]]
model_id = "Qwen/Qwen3.6-27B"
harness = "candle"
quant = "q6k"
tensor_parallel = 2
devices = [0, 1]

View File

@@ -1,19 +0,0 @@
# neuron.toml for benjy.hanzalova.internal
#
# 1x RTX 4090 (24 GB) — largest single-GPU host on the fleet. Pre-warms
# Qwen3-8B (bf16, ~18 GB), leaving ~6 GB for KV cache + activations on
# moderate-length contexts.
#
# Synced by script/deploy.sh from asset/neuron/<short-host>.toml.
port = 13131
[[harnesses]]
name = "candle"
[harness.candle]
[[default_models]]
model_id = "Qwen/Qwen3-8B"
harness = "candle"
devices = [0]

View File

@@ -1,19 +0,0 @@
# neuron.toml for quadbrat.hanzalova.internal
#
# 1x RTX 3060 (12 GB) — small / quantised tier. Pre-warms Qwen3-1.7B
# (bf16, ~4 GB), leaving ~7 GB for KV cache so long contexts on a small
# model still have plenty of room.
#
# Synced by script/deploy.sh from asset/neuron/<short-host>.toml.
port = 13131
[[harnesses]]
name = "candle"
[harness.candle]
[[default_models]]
model_id = "Qwen/Qwen3-1.7B"
harness = "candle"
devices = [0]

View File

@@ -3,22 +3,22 @@
# Copy to cortex.toml and adjust for your environment.
#
# Environment variable overrides use CORTEX_ prefix with __ separators:
# CORTEX_GATEWAY__LISTEN=0.0.0.0:31313
# CORTEX_GATEWAY__LISTEN=0.0.0.0:9000
[gateway]
listen = "0.0.0.0:31313"
metrics_listen = "0.0.0.0:31314"
listen = "0.0.0.0:8000"
metrics_listen = "0.0.0.0:9100"
[eviction]
strategy = "lru"
# Restart neurons after this many load/unload cycles to defragment VRAM.
# Restart mistralrs after this many load/unload cycles to defragment VRAM.
# Set to 0 to disable.
defrag_after_cycles = 50
# -- Nodes ---------------------------------------------------------------
# Each [[nodes]] entry declares a neuron daemon in the fleet.
# Models are discovered by polling the neuron's /models endpoint.
# Pinned models (see models.toml) are never evicted.
# Each [[nodes]] entry declares a mistral.rs instance in the fleet.
# Models are discovered by polling the node's /v1/models endpoint.
# Pinned models are never evicted.
[[nodes]]
name = "gpu-large"

View File

@@ -1,5 +1,5 @@
Name: cortex
Version: 0.1.16
Version: 0.1.7
Release: 1%{?dist}
Summary: Inference gateway for multi-node GPU clusters
@@ -21,16 +21,6 @@ BuildRequires: systemd-rpm-macros
Requires(pre): shadow-utils
Requires: systemd
Requires: firewalld-filesystem
# systemd-rpm-macros ships a unit dep generator that parses User=/Group=
# from our .service file and emits Requires: user(cortex)/group(cortex).
# rpm's sysusers provides-generator emits the unversioned form for groups
# but only a versioned user(cortex) = <base64> for users with GECOS/home/
# shell. Provide the unversioned user(cortex) explicitly so dnf can resolve
# the auto-generated Requires. Without this, dnf5 silently filters the
# package and reports "Nothing to do".
Provides: user(cortex)
%description
Cortex is a Rust reverse-proxy that sits in front of multiple inference
@@ -57,7 +47,6 @@ cargo build --release -p cortex-cli
install -Dm755 target/release/cortex %{buildroot}%{_bindir}/cortex
install -Dm644 data/cortex.service %{buildroot}%{_unitdir}/cortex.service
install -Dm644 data/cortex-sysusers.conf %{buildroot}%{_sysusersdir}/cortex.conf
install -Dm644 data/cortex-firewalld.xml %{buildroot}%{_prefix}/lib/firewalld/services/cortex.xml
install -dm755 %{buildroot}%{_sysconfdir}/cortex
install -Dm644 cortex.example.toml %{buildroot}%{_sysconfdir}/cortex/cortex.toml
install -Dm644 models.example.toml %{buildroot}%{_sysconfdir}/cortex/models.toml
@@ -74,53 +63,16 @@ install -Dm644 models.example.toml %{buildroot}%{_sysconfdir}/cortex/models.toml
%postun
%systemd_postun_with_restart cortex.service
%posttrans
# Migration: older cortex packages shipped the firewalld service as
# `helexa-cortex` and (in some build streams) with wrong port numbers
# (9301/9302/9304). Operators who enabled that legacy service in their
# zone end up with the wrong-port override taking precedence over the
# vendor `cortex.xml` now in /usr/lib/firewalld/services/. Clean up the
# stale /etc/ override here and migrate any zone bindings to the new
# service name.
if [ -f /etc/firewalld/services/helexa-cortex.xml ]; then
rm -f /etc/firewalld/services/helexa-cortex.xml
fi
if [ -x /usr/bin/firewall-cmd ] && /usr/bin/firewall-cmd --state >/dev/null 2>&1; then
# Drop the legacy service name from every zone where it was enabled
# and add the new `cortex` service in its place. Operators who never
# ran firewall-cmd against either name see no zone change.
for zone in $(/usr/bin/firewall-cmd --get-active-zones 2>/dev/null \
| awk '!/^[[:space:]]/ {print $1}'); do
if /usr/bin/firewall-cmd --permanent --zone="$zone" --query-service=helexa-cortex >/dev/null 2>&1; then
/usr/bin/firewall-cmd --permanent --zone="$zone" --remove-service=helexa-cortex >/dev/null 2>&1 || :
/usr/bin/firewall-cmd --permanent --zone="$zone" --add-service=cortex >/dev/null 2>&1 || :
fi
done
/usr/bin/firewall-cmd --reload >/dev/null 2>&1 || :
fi
:
%files
%license LICENSE
%doc README.md
%{_bindir}/cortex
%{_unitdir}/cortex.service
%{_sysusersdir}/cortex.conf
%{_prefix}/lib/firewalld/services/cortex.xml
%dir %{_sysconfdir}/cortex
%config(noreplace) %{_sysconfdir}/cortex/cortex.toml
%config(noreplace) %{_sysconfdir}/cortex/models.toml
%changelog
* Thu Apr 16 2026 Gitea Actions <actions@git.lair.cafe> - 0.1.16-1
- chore: ignore local deploy script
- chore: move default ports out of common-collision ranges
- ci: drop actions/cache for cargo registry and target
* Thu Apr 16 2026 Gitea Actions <actions@git.lair.cafe> - 0.1.14-1
- ci: publish both packages to a single helexa/helexa COPR project
- fix(rpm): rename neuron package to helexa-neuron
- ci: commit generated %changelog entries back to main
* Wed Apr 15 2026 Rob Thijssen <grenade@rob.tn> - 0.1.0-1
* Tue Apr 15 2026 Rob Thijssen <grenade@rob.tn> - 0.1.0-1
- Initial package

View File

@@ -5,7 +5,7 @@ use tracing_subscriber::EnvFilter;
#[derive(Parser)]
#[command(name = "cortex")]
#[command(about = "Unified inference gateway for multi-node GPU clusters")]
#[command(about = "Unified inference gateway for multi-node mistral.rs clusters")]
#[command(version)]
struct Cli {
#[command(subcommand)]
@@ -23,7 +23,7 @@ enum Commands {
/// Print the fleet status (models, nodes, health).
Status {
/// Gateway API endpoint to query.
#[arg(short, long, default_value = "http://localhost:31313")]
#[arg(short, long, default_value = "http://localhost:8000")]
endpoint: String,
},
}

View File

@@ -2,7 +2,7 @@
//!
//! These mirror the `/v1/messages` format used by the Anthropic API.
//! The gateway accepts these, translates to OpenAI format, proxies to
//! the inference backend (neuron), then translates the response back.
//! mistral.rs, then translates the response back.
use serde::{Deserialize, Serialize};
use serde_json::Value;

View File

@@ -1,8 +1,6 @@
//! Model catalogue — profiles describing how to serve each model.
use crate::discovery::DeviceInfo;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
/// A model serving profile loaded from models.toml.
@@ -35,14 +33,6 @@ fn default_min_devices() -> u32 {
pub struct ModelCatalogue {
#[serde(default)]
pub models: Vec<ModelProfile>,
/// Tier aliases — clients can send a request with `model: "helexa/small"`
/// and the gateway transparently rewrites + routes to the concrete
/// model id this maps to. Lets operators define latency/quality
/// tiers (`small`/`balanced`/`large`, `fast`/`thinking`, etc.)
/// without imposing knowledge of specific model ids on clients.
/// Loaded from the `[aliases]` table in models.toml.
#[serde(default)]
pub aliases: HashMap<String, String>,
}
impl ModelCatalogue {
@@ -74,138 +64,4 @@ impl ModelCatalogue {
.iter()
.any(|p| p.id == model_id && p.pinned_on.contains(&neuron_name.to_string()))
}
/// Find a profile by model id.
pub fn get(&self, model_id: &str) -> Option<&ModelProfile> {
self.models.iter().find(|p| p.id == model_id)
}
/// Resolve an alias to its concrete model id. Returns `id` verbatim
/// when it isn't an alias. Aliases never chain — operator config
/// is treated as flat — so this is a single lookup.
pub fn resolve_alias<'a>(&'a self, id: &'a str) -> &'a str {
self.aliases.get(id).map(String::as_str).unwrap_or(id)
}
}
impl ModelProfile {
/// True iff this profile's placement constraints can be satisfied
/// by the named neuron with the given device topology.
///
/// Constraints checked:
/// - `pinned_on`: non-empty → neuron must be on the list.
/// - `min_devices`: neuron must have at least this many devices.
/// - `min_device_vram_mb`: at least `min_devices` of the neuron's
/// devices must each meet this VRAM floor.
pub fn is_feasible_on(&self, neuron_name: &str, devices: &[DeviceInfo]) -> bool {
if !self.pinned_on.is_empty() && !self.pinned_on.iter().any(|n| n == neuron_name) {
return false;
}
if (devices.len() as u32) < self.min_devices {
return false;
}
if let Some(min_vram) = self.min_device_vram_mb {
let big_enough = devices
.iter()
.filter(|d| d.vram_total_mb >= min_vram)
.count() as u32;
if big_enough < self.min_devices {
return false;
}
}
true
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::discovery::DeviceInfo;
fn device(idx: u32, vram_mb: u64) -> DeviceInfo {
DeviceInfo {
index: idx,
name: format!("DEV-{idx}"),
vram_total_mb: vram_mb,
compute_capability: "8.6".into(),
}
}
fn profile() -> ModelProfile {
ModelProfile {
id: "Qwen/Qwen3.6-27B".into(),
harness: "candle".into(),
quant: None,
vram_mb: Some(45_000),
min_devices: 2,
min_device_vram_mb: Some(24_000),
pinned_on: vec![],
}
}
#[test]
fn feasible_when_two_devices_meet_vram_floor() {
let p = profile();
let devices = [device(0, 32_000), device(1, 32_000)];
assert!(p.is_feasible_on("beast", &devices));
}
#[test]
fn infeasible_when_only_one_device() {
let p = profile();
let devices = [device(0, 64_000)];
assert!(!p.is_feasible_on("benjy", &devices));
}
#[test]
fn infeasible_when_one_device_underspec() {
let p = profile();
let devices = [device(0, 32_000), device(1, 12_000)];
assert!(!p.is_feasible_on("mixed", &devices));
}
#[test]
fn pinned_on_excludes_other_neurons() {
let mut p = profile();
p.pinned_on = vec!["beast".into()];
let devices = [device(0, 32_000), device(1, 32_000)];
assert!(p.is_feasible_on("beast", &devices));
assert!(!p.is_feasible_on("benjy", &devices));
}
#[test]
fn no_vram_floor_just_needs_min_devices() {
let mut p = profile();
p.min_device_vram_mb = None;
let devices = [device(0, 1_000), device(1, 1_000)];
assert!(p.is_feasible_on("anywhere", &devices));
}
#[test]
fn resolve_alias_returns_target_when_alias_present() {
let mut cat = ModelCatalogue::default();
cat.aliases
.insert("helexa/small".into(), "Qwen/Qwen3-1.7B".into());
assert_eq!(cat.resolve_alias("helexa/small"), "Qwen/Qwen3-1.7B");
}
#[test]
fn resolve_alias_passes_through_when_not_an_alias() {
let mut cat = ModelCatalogue::default();
cat.aliases
.insert("helexa/small".into(), "Qwen/Qwen3-1.7B".into());
assert_eq!(cat.resolve_alias("Qwen/Qwen3-8B"), "Qwen/Qwen3-8B");
}
#[test]
fn aliases_table_round_trips_through_toml() {
let src = r#"
[aliases]
"helexa/small" = "Qwen/Qwen3-1.7B"
"helexa/large" = "Qwen/Qwen3.6-27B"
"#;
let cat: ModelCatalogue = toml::from_str(src).expect("parse aliases table");
assert_eq!(cat.resolve_alias("helexa/small"), "Qwen/Qwen3-1.7B");
assert_eq!(cat.resolve_alias("helexa/large"), "Qwen/Qwen3.6-27B");
}
}

View File

@@ -22,9 +22,9 @@ fn default_models_path() -> String {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewaySettings {
/// Address to listen on for API requests (e.g. "0.0.0.0:31313")
/// Address to listen on for API requests (e.g. "0.0.0.0:8000")
pub listen: String,
/// Address to listen on for Prometheus metrics (e.g. "0.0.0.0:31314")
/// Address to listen on for Prometheus metrics (e.g. "0.0.0.0:9100")
pub metrics_listen: String,
}
@@ -50,7 +50,7 @@ pub enum EvictionStrategy {
pub struct NeuronEndpoint {
/// Human-readable node name (e.g. "beast")
pub name: String,
/// Base URL of the neuron daemon (e.g. "http://beast.internal:13131")
/// Base URL of the neuron daemon (e.g. "http://beast.internal:9090")
pub endpoint: String,
}
@@ -70,8 +70,8 @@ impl Default for GatewayConfig {
fn default() -> Self {
Self {
gateway: GatewaySettings {
listen: "0.0.0.0:31313".into(),
metrics_listen: "0.0.0.0:31314".into(),
listen: "0.0.0.0:8000".into(),
metrics_listen: "0.0.0.0:9100".into(),
},
eviction: EvictionSettings {
strategy: EvictionStrategy::Lru,

View File

@@ -36,72 +36,8 @@ pub struct DeviceHealth {
/// Runtime health response from a neuron endpoint.
/// Returned by `GET /health`.
///
/// `activation` was added in 2026-05-26 to distinguish "process is up
/// and reachable" from "process is ready to serve traffic". A `Type=simple`
/// systemd unit reports `active` the moment the binary starts — but a
/// neuron whose `default_models` list takes minutes to materialise
/// won't bind its listener (or, in the new flow, won't have any models
/// loaded) until pre-warm completes. The new field is `#[serde(default)]`
/// so a pre-2026-05-26 gateway polling a new neuron — or vice versa —
/// keeps working.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthResponse {
pub uptime_secs: u64,
pub devices: Vec<DeviceHealth>,
#[serde(default)]
pub activation: ActivationStatus,
}
/// High-level activation state of the neuron daemon. The HTTP listener
/// is bound during both states; what differs is whether the configured
/// `default_models` have finished loading.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ActivationState {
/// At least one `default_models` entry is still loading. The
/// neuron's other endpoints work, but inference against
/// not-yet-loaded models will 404.
PreWarming,
/// Every `default_models` entry has either loaded or failed; the
/// neuron is steady-state. Subsequent on-demand loads via
/// `/models/load` don't flip back to PreWarming — that field
/// reflects the activation-time set only.
#[default]
Ready,
}
/// Per-model failure record surfaced in [`ActivationStatus::failed`].
/// The error string is the rendered anyhow chain at the time of the
/// failure; operators read it from `/health` to decide whether to
/// retry, edit the spec, or unload+reload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PreWarmFailure {
pub model_id: String,
pub error: String,
}
/// Activation-time progress snapshot. All four lists are populated by
/// the neuron's pre-warm task and read by the `/health` handler. The
/// snapshot is consistent: a model id appears in exactly one of
/// `pending`, `in_progress` (as `Option<String>`), `completed`, or
/// `failed` at any point in time.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ActivationStatus {
pub state: ActivationState,
/// Model ids queued but not yet started. Empty in `Ready` state.
#[serde(default)]
pub pending: Vec<String>,
/// Model id currently materialising. None when between models or
/// in `Ready` state.
#[serde(default)]
pub in_progress: Option<String>,
/// Model ids that finished loading successfully during this
/// activation. Cleared on process restart.
#[serde(default)]
pub completed: Vec<String>,
/// Model ids that failed during this activation, with the rendered
/// error chain. Cleared on process restart.
#[serde(default)]
pub failed: Vec<PreWarmFailure>,
}

View File

@@ -9,13 +9,13 @@ use async_trait::async_trait;
use serde::{Deserialize, Serialize};
/// Configuration for a harness instance on a neuron.
///
/// All current harnesses are in-process (candle); per-harness tuning
/// (cache paths, device policies, etc.) lives in dedicated config
/// blocks rather than on this struct.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HarnessConfig {
pub name: String,
/// Base URL of the harness (e.g. "http://localhost:8080" for mistral.rs).
pub endpoint: Option<String>,
/// Systemd unit name, if the harness is managed via systemd.
pub systemd_unit: Option<String>,
}
/// Health status of a harness process.
@@ -47,24 +47,16 @@ pub struct ModelInfo {
}
/// What an inference harness must do, from neuron's perspective.
///
/// All current harnesses are in-process — they share neuron's address
/// space and lifecycle. `start`/`stop` therefore default to no-ops; a
/// future process-supervising harness would override them.
#[async_trait]
pub trait Harness: Send + Sync {
/// Human-readable name (e.g. "candle").
/// Human-readable name (e.g. "mistralrs", "llamacpp", "comfyui").
fn name(&self) -> &str;
/// Start the harness. Default no-op for in-process harnesses.
async fn start(&self, _config: &HarnessConfig) -> Result<()> {
Ok(())
}
/// Start the harness process if it is not already running.
async fn start(&self, config: &HarnessConfig) -> Result<()>;
/// Stop the harness. Default no-op for in-process harnesses.
async fn stop(&self) -> Result<()> {
Ok(())
}
/// Stop the harness process gracefully.
async fn stop(&self) -> Result<()>;
/// Health check. Returns the harness process status.
async fn health(&self) -> HarnessHealth;

View File

@@ -6,5 +6,4 @@ pub mod harness;
pub mod metrics;
pub mod node;
pub mod openai;
pub mod responses;
pub mod translate;

View File

@@ -1,4 +1,3 @@
use crate::discovery::{ActivationStatus, DiscoveryResponse};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
@@ -7,25 +6,13 @@ use std::collections::HashMap;
#[derive(Debug, Clone)]
pub struct NodeState {
pub name: String,
/// Base URL of the neuron daemon (e.g. "http://beast.internal:13131").
/// Base URL of the neuron daemon (e.g. "http://beast.internal:9090").
pub endpoint: String,
pub healthy: bool,
pub models: HashMap<String, ModelEntry>,
/// Number of load/unload cycles since last process restart.
pub lifecycle_cycles: u32,
pub last_poll: Option<DateTime<Utc>>,
/// Result of the most recent successful `GET /discovery` against
/// this neuron. Cached forever once obtained — device topology is
/// invariant for a given neuron process. `None` until the first
/// successful poll. Used by the router and `/v1/models` to do
/// catalogue × topology feasibility checks.
pub discovery: Option<DiscoveryResponse>,
/// Last-seen pre-warm progress from this neuron's `/health`
/// endpoint. `None` until the first /health poll succeeds. The
/// `/v1/models` handler reads `in_progress` + `pending` from here
/// to synthesize `Loading` locations so clients see a catalogued
/// model that's mid-prewarm as "loading", not "missing".
pub activation: Option<ActivationStatus>,
}
/// A model registered on a node, with its runtime status.
@@ -40,50 +27,21 @@ pub struct ModelEntry {
}
/// Model lifecycle status.
///
/// `Loading` is a gateway-side synthetic status: neurons never emit it
/// on `/models` (that endpoint only knows about already-loaded handles).
/// The gateway populates it from a neuron's `/health` activation
/// snapshot so the unified `/v1/models` can distinguish "model is
/// catalogued but no one has it" from "model is materialising on
/// neuron N right now". Other status values are reported verbatim by
/// neurons.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ModelStatus {
Loaded,
Unloaded,
Reloading,
Loading,
}
/// Unified model entry as exposed by the gateway's `/v1/models` endpoint.
///
/// The first four fields (`id`, `object`, `created`, `owned_by`) match
/// OpenAI's `/v1/models` shape verbatim, so existing OpenAI-aware
/// tooling deserialises this without custom code. The remaining fields
/// are helexa-specific extensions — OpenAI clients ignore unknown
/// fields and other consumers can read them for placement / debugging.
/// Includes which node(s) host this model and their status.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CortexModelEntry {
pub id: String,
/// Always `"model"` per OpenAI's contract.
pub object: String,
/// Unix-second timestamp; cortex stamps this at response time.
pub created: u64,
/// OpenAI's "publisher" field — `"helexa"` for everything we serve.
pub owned_by: String,
/// True if any neuron currently has this model loaded. False for
/// catalogue entries that are feasible but not yet loaded.
pub loaded: bool,
/// Neurons whose discovered topology can satisfy this model's
/// catalogue placement constraints. Empty for models that are
/// loaded somewhere but not present in the catalogue (cortex has
/// no feasibility opinion on those).
pub feasible_on: Vec<String>,
/// Where this model is actually loaded right now. Subset of (or
/// disjoint from) `feasible_on` depending on whether the catalogue
/// covers this model.
/// Which nodes have this model (and their status).
pub locations: Vec<ModelLocation>,
}

View File

@@ -3,7 +3,7 @@
//! These are a subset sufficient for chat completions (streaming + non-streaming).
//! Fields not relevant to proxying are captured as `serde_json::Value` via
//! `#[serde(flatten)]` so we forward them without needing to enumerate every
//! extension field a backend might support.
//! extension field mistral.rs supports.
use serde::{Deserialize, Serialize};
use serde_json::Value;
@@ -22,7 +22,7 @@ pub struct ChatCompletionRequest {
pub max_tokens: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
/// All other fields (tools, response_format, backend extensions, etc.)
/// All other fields (tools, response_format, mistral.rs extensions, etc.)
#[serde(flatten)]
pub extra: Value,
}

View File

@@ -1,346 +0,0 @@
//! OpenAI Responses API (`POST /v1/responses`) envelope types.
//!
//! This is OpenAI's newer chat surface, distinct from
//! `/v1/chat/completions` in three ways that matter for us:
//!
//! 1. **Input shape**. Instead of a `messages` array, the request
//! carries `input` — either a plain string (single user turn)
//! or an array of typed items (messages, function calls,
//! function-call outputs, reasoning blocks, …).
//! 2. **Output shape**. The response carries a single `output`
//! array of items, each typed. We always emit one
//! `OutputItem::Message` containing the assistant's reply (plus,
//! when we get there, separate `function_call` items).
//! 3. **Streaming events**. Where chat completions stream
//! structurally-identical `chat.completion.chunk` frames over
//! `data:` lines, Responses streams *named* events
//! (`response.created`, `response.output_text.delta`,
//! `response.completed`, …) over `event:` + `data:` SSE pairs.
//! The wire projector in `neuron::wire::openai_responses` builds
//! these from the same [`crate::openai`]-shaped
//! `InferenceEvent` stream the chat projector consumes.
//!
//! Scope cuts for this first cut:
//!
//! - **`previous_response_id` is rejected at parse time**. Stateful
//! chained conversations need a persistence layer we don't have.
//! - **Reasoning items are accepted-and-ignored** (no Qwen3
//! `<think>` routing yet). Audio and embedded resources are
//! rejected as unsupported.
//! - **Tool calls** (function_call / function_call_output) are
//! carried as round-trip types but the candle harness doesn't
//! emit them yet — wired so the surface is in place for the
//! day we add proper tool-call extraction.
use serde::{Deserialize, Serialize};
use serde_json::Value;
// ── Request ──────────────────────────────────────────────────────────
/// Body of a `POST /v1/responses` request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponsesRequest {
pub model: String,
pub input: ResponsesInput,
/// System-prompt-style instructions. The Responses API
/// separates these from input so a caller doesn't have to
/// build a `system` message item by hand.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub instructions: Option<String>,
#[serde(default)]
pub stream: bool,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_output_tokens: Option<u64>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>,
/// Chained-conversation identifier. We don't store responses
/// server-side yet; if this is `Some`, the handler returns 400.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub previous_response_id: Option<String>,
/// Catch-all for anything we don't model yet (tools, tool_choice,
/// reasoning, response_format, …). Lets a client send a
/// forward-compatible request without our parser rejecting it.
#[serde(flatten)]
pub extra: Value,
}
/// `input` is either a single string or an array of typed items.
/// `#[serde(untagged)]` so the wire shape `"input": "hi"` and
/// `"input": [{...}]` both deserialize.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ResponsesInput {
Text(String),
Items(Vec<ResponsesInputItem>),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponsesInputItem {
/// A user / assistant / system turn.
Message {
role: String,
content: ResponsesMessageContent,
},
/// Assistant emitted a tool call. Round-trip only — neuron
/// doesn't synthesise these yet.
FunctionCall {
call_id: String,
name: String,
arguments: String,
},
/// User is feeding a tool result back into the model.
FunctionCallOutput { call_id: String, output: String },
/// Reasoning items emitted by o-series models. Accepted but
/// not forwarded to the model — neuron's candle path doesn't
/// surface reasoning separately yet.
Reasoning {
#[serde(default)]
content: Vec<Value>,
},
}
/// Inside a `Message` item, content is either a plain string or an
/// array of typed parts. Mirrors the chat-completions Parts shape.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ResponsesMessageContent {
Text(String),
Parts(Vec<ResponsesContentPart>),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponsesContentPart {
/// Plain text inside a user / system turn.
InputText { text: String },
/// An image. `image_url` is either a remote URL or a
/// `data:image/png;base64,…` URI; the request translator just
/// forwards the string.
InputImage {
image_url: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
detail: Option<String>,
},
/// Returned text inside an assistant turn — only relevant when
/// the caller is feeding an assistant turn back in to continue
/// a conversation manually (no `previous_response_id`).
OutputText {
text: String,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
annotations: Vec<Value>,
},
}
// ── Response (non-streaming) ─────────────────────────────────────────
/// Body of a `POST /v1/responses` response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponsesResponse {
pub id: String,
/// Always `"response"`.
pub object: String,
pub created_at: u64,
/// `"completed"`, `"incomplete"`, or — for the initial event of
/// a streaming response — `"in_progress"`.
pub status: String,
pub model: String,
pub output: Vec<ResponsesOutputItem>,
/// Populated on completion; `None` while streaming.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub usage: Option<ResponsesUsage>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponsesOutputItem {
Message {
id: String,
/// Always `"assistant"` for model output.
role: String,
/// Output content parts. We always emit a single
/// `OutputText` today; multi-part output would land here
/// once we have e.g. image generation.
content: Vec<ResponsesOutputContent>,
/// Item-level status. `"in_progress"` while streaming the
/// content parts, `"completed"` when done.
#[serde(default = "default_item_status")]
status: String,
},
/// Reserved for the day tool-call extraction lands. The wire
/// shape mirrors `ResponsesInputItem::FunctionCall`.
FunctionCall {
id: String,
call_id: String,
name: String,
arguments: String,
#[serde(default = "default_item_status")]
status: String,
},
}
fn default_item_status() -> String {
"completed".into()
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponsesOutputContent {
OutputText {
text: String,
/// Citations / inline annotations. Empty today; reserved
/// for the day we wire in web search / file search.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
annotations: Vec<Value>,
},
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponsesUsage {
pub input_tokens: u64,
pub output_tokens: u64,
pub total_tokens: u64,
}
// ── Streaming event names ────────────────────────────────────────────
/// Event names the SSE projector emits, hoisted as constants so
/// the projector and the wire shape stay in sync without
/// string-typos. The strings are dictated by OpenAI's published
/// Responses API.
pub mod events {
pub const CREATED: &str = "response.created";
/// Fired between `response.created` and the first output-item
/// event. Marks "request validated, model is generating" —
/// some clients use it to differentiate the "warming up" state
/// from "streaming tokens" in their UI.
pub const IN_PROGRESS: &str = "response.in_progress";
pub const OUTPUT_ITEM_ADDED: &str = "response.output_item.added";
pub const CONTENT_PART_ADDED: &str = "response.content_part.added";
pub const OUTPUT_TEXT_DELTA: &str = "response.output_text.delta";
pub const OUTPUT_TEXT_DONE: &str = "response.output_text.done";
pub const CONTENT_PART_DONE: &str = "response.content_part.done";
pub const OUTPUT_ITEM_DONE: &str = "response.output_item.done";
pub const COMPLETED: &str = "response.completed";
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn deserialises_input_string_form() {
let raw = r#"{"model": "m", "input": "hello"}"#;
let req: ResponsesRequest = serde_json::from_str(raw).unwrap();
match req.input {
ResponsesInput::Text(s) => assert_eq!(s, "hello"),
other => panic!("expected Text, got {other:?}"),
}
}
#[test]
fn deserialises_input_items_form() {
let raw = r#"{
"model": "m",
"input": [
{"type": "message", "role": "user", "content": "hi"}
]
}"#;
let req: ResponsesRequest = serde_json::from_str(raw).unwrap();
match req.input {
ResponsesInput::Items(items) => {
assert_eq!(items.len(), 1);
match &items[0] {
ResponsesInputItem::Message { role, content } => {
assert_eq!(role, "user");
match content {
ResponsesMessageContent::Text(t) => assert_eq!(t, "hi"),
other => panic!("expected Text content, got {other:?}"),
}
}
other => panic!("expected Message item, got {other:?}"),
}
}
other => panic!("expected Items, got {other:?}"),
}
}
#[test]
fn deserialises_input_with_image() {
let raw = r#"{
"model": "m",
"input": [
{"type": "message", "role": "user", "content": [
{"type": "input_text", "text": "what is this"},
{"type": "input_image", "image_url": "data:image/png;base64,AAA="}
]}
]
}"#;
let req: ResponsesRequest = serde_json::from_str(raw).unwrap();
let items = match req.input {
ResponsesInput::Items(i) => i,
other => panic!("expected Items, got {other:?}"),
};
let parts = match &items[0] {
ResponsesInputItem::Message {
content: ResponsesMessageContent::Parts(p),
..
} => p,
other => panic!("expected Parts, got {other:?}"),
};
assert_eq!(parts.len(), 2);
assert!(matches!(
&parts[0],
ResponsesContentPart::InputText { text } if text == "what is this"
));
assert!(matches!(
&parts[1],
ResponsesContentPart::InputImage { image_url, .. }
if image_url == "data:image/png;base64,AAA="
));
}
#[test]
fn unknown_fields_round_trip_via_extra() {
let raw = r#"{
"model": "m",
"input": "hi",
"tools": [{"type": "web_search"}],
"reasoning": {"effort": "medium"}
}"#;
let req: ResponsesRequest = serde_json::from_str(raw).unwrap();
assert!(req.extra.get("tools").is_some());
assert!(req.extra.get("reasoning").is_some());
}
#[test]
fn response_round_trips_through_serde() {
let r = ResponsesResponse {
id: "resp_1".into(),
object: "response".into(),
created_at: 1700,
status: "completed".into(),
model: "m".into(),
output: vec![ResponsesOutputItem::Message {
id: "msg_1".into(),
role: "assistant".into(),
content: vec![ResponsesOutputContent::OutputText {
text: "hi there".into(),
annotations: vec![],
}],
status: "completed".into(),
}],
usage: Some(ResponsesUsage {
input_tokens: 5,
output_tokens: 3,
total_tokens: 8,
}),
};
let json = serde_json::to_string(&r).unwrap();
let parsed: ResponsesResponse = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.id, "resp_1");
assert_eq!(parsed.output.len(), 1);
}
}

View File

@@ -24,7 +24,6 @@ tokio-stream.workspace = true
eventsource-stream.workspace = true
bytes = "1"
urlencoding = "2"
url = "2"
[dev-dependencies]
tokio = { workspace = true, features = ["test-util"] }

View File

@@ -20,7 +20,6 @@ pub fn api_routes() -> Router<Arc<CortexState>> {
Router::new()
.route("/v1/chat/completions", post(chat_completions))
.route("/v1/completions", post(completions))
.route("/v1/responses", post(responses))
.route("/v1/models", get(list_models))
.route("/v1/messages", post(anthropic_messages))
.route("/health", get(health))
@@ -35,94 +34,23 @@ async fn chat_completions(
) -> Response {
let model_id = match extract_model(&body) {
Some(m) => m,
None => {
tracing::warn!(
handler = "chat_completions",
"rejected: missing 'model' field in request body"
);
return error_response(400, "missing 'model' field in request body");
}
None => return error_response(400, "missing 'model' field in request body"),
};
let route = match router::resolve(&fleet, &model_id).await {
Ok(r) => r,
Err(e) => {
tracing::warn!(
handler = "chat_completions",
model = %model_id,
error = %e,
"route resolve failed"
);
// RouteError's Display strings are short and informative
// ("model 'X' not found...", "no healthy nodes available")
// — fine to surface to the caller. The warn above carries
// any extra context for operators.
return error_response(404, &e.to_string());
}
Err(e) => return error_response(404, &e.to_string()),
};
touch_model(&fleet, &route.node_name, &route.resolved_model_id).await;
touch_model(&fleet, &route.node_name, &model_id).await;
let body = rewrite_model_in_body(body, &route.resolved_model_id);
proxy_with_metrics(
&fleet,
&route,
"/v1/chat/completions",
headers,
body,
&route.resolved_model_id,
)
.await
}
/// `POST /v1/responses` — proxy to the appropriate backend node.
///
/// Same routing shape as [`chat_completions`]: extract `model` from
/// the body, resolve to a node, forward verbatim. No translation —
/// neuron speaks the Responses API natively (see
/// `crates/neuron/src/wire/openai_responses.rs`), so the gateway is
/// a pass-through. Streaming and non-streaming are handled
/// identically; the upstream `Content-Type` (text/event-stream vs.
/// application/json) propagates through the proxy.
async fn responses(
State(fleet): State<Arc<CortexState>>,
headers: HeaderMap,
body: Bytes,
) -> Response {
let model_id = match extract_model(&body) {
Some(m) => m,
None => {
tracing::warn!(
handler = "responses",
"rejected: missing 'model' field in request body"
);
return error_response(400, "missing 'model' field in request body");
}
};
let route = match router::resolve(&fleet, &model_id).await {
Ok(r) => r,
Err(e) => {
tracing::warn!(
handler = "responses",
model = %model_id,
error = %e,
"route resolve failed"
);
return error_response(404, &e.to_string());
}
};
touch_model(&fleet, &route.node_name, &route.resolved_model_id).await;
let body = rewrite_model_in_body(body, &route.resolved_model_id);
proxy_with_metrics(
&fleet,
&route,
"/v1/responses",
headers,
body,
&route.resolved_model_id,
&model_id,
)
.await
}
@@ -135,44 +63,17 @@ async fn completions(
) -> Response {
let model_id = match extract_model(&body) {
Some(m) => m,
None => {
tracing::warn!(
handler = "completions",
"rejected: missing 'model' field in request body"
);
return error_response(400, "missing 'model' field in request body");
}
None => return error_response(400, "missing 'model' field in request body"),
};
let route = match router::resolve(&fleet, &model_id).await {
Ok(r) => r,
Err(e) => {
tracing::warn!(
handler = "completions",
model = %model_id,
error = %e,
"route resolve failed"
);
// RouteError's Display strings are short and informative
// ("model 'X' not found...", "no healthy nodes available")
// — fine to surface to the caller. The warn above carries
// any extra context for operators.
return error_response(404, &e.to_string());
}
Err(e) => return error_response(404, &e.to_string()),
};
touch_model(&fleet, &route.node_name, &route.resolved_model_id).await;
touch_model(&fleet, &route.node_name, &model_id).await;
let body = rewrite_model_in_body(body, &route.resolved_model_id);
proxy_with_metrics(
&fleet,
&route,
"/v1/completions",
headers,
body,
&route.resolved_model_id,
)
.await
proxy_with_metrics(&fleet, &route, "/v1/completions", headers, body, &model_id).await
}
/// `POST /v1/messages` — accept Anthropic format, translate, proxy, translate back.
@@ -184,14 +85,7 @@ async fn anthropic_messages(
// Parse as Anthropic request.
let anth_req: cortex_core::anthropic::MessagesRequest = match serde_json::from_slice(&body) {
Ok(r) => r,
Err(e) => {
tracing::warn!(
handler = "anthropic_messages",
error = %e,
"rejected: invalid Anthropic request body"
);
return error_response(400, "invalid Anthropic request body");
}
Err(e) => return error_response(400, &format!("invalid Anthropic request: {e}")),
};
let model_id = anth_req.model.clone();
@@ -201,43 +95,18 @@ async fn anthropic_messages(
let openai_req = cortex_core::translate::anthropic_to_openai(anth_req);
let openai_body = match serde_json::to_vec(&openai_req) {
Ok(b) => Bytes::from(b),
Err(e) => {
tracing::error!(
handler = "anthropic_messages",
model = %model_id,
error = %e,
"internal: failed to serialise translated OpenAI request"
);
return error_response(500, "internal translation error");
}
Err(e) => return error_response(500, &format!("translation error: {e}")),
};
let route = match router::resolve(&fleet, &model_id).await {
Ok(r) => r,
Err(e) => {
tracing::warn!(
handler = "anthropic_messages",
model = %model_id,
error = %e,
"route resolve failed"
);
// RouteError's Display strings are short and informative
// ("model 'X' not found...", "no healthy nodes available")
// — fine to surface to the caller. The warn above carries
// any extra context for operators.
return error_response(404, &e.to_string());
}
Err(e) => return error_response(404, &e.to_string()),
};
touch_model(&fleet, &route.node_name, &route.resolved_model_id).await;
// Swap the alias for the concrete id in the translated body so
// neuron's harness sees a model name that matches what it has
// loaded.
let openai_body = rewrite_model_in_body(openai_body, &route.resolved_model_id);
touch_model(&fleet, &route.node_name, &model_id).await;
let labels = [
("model", route.resolved_model_id.clone()),
("model", model_id.clone()),
("node", route.node_name.clone()),
];
metrics::counter!("cortex_requests_total", &labels).increment(1);
@@ -264,25 +133,14 @@ async fn anthropic_messages(
Ok(resp) => resp,
Err(e) => {
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
// forward_request already warn'd with the wire-level
// detail; no need to log again here.
e.into_response()
}
}
} else {
// Non-streaming: proxy, buffer full response, translate back to Anthropic.
let target_url = format!("{}/v1/chat/completions", route.endpoint);
tracing::info!(
handler = "anthropic_messages",
model = %model_id,
node = %route.node_name,
url = %target_url,
cold_start = route.cold_start,
"proxying request"
);
let upstream_resp = fleet
.http_client
.post(&target_url)
.post(format!("{}/v1/chat/completions", route.endpoint))
.body(openai_body)
.header("content-type", "application/json")
.send()
@@ -292,49 +150,22 @@ async fn anthropic_messages(
Ok(r) => r,
Err(e) => {
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
tracing::warn!(
handler = "anthropic_messages",
model = %model_id,
node = %route.node_name,
url = %target_url,
error = %e,
"upstream request failed (network)"
);
return error_response(502, "upstream request failed");
return error_response(502, &format!("upstream request failed: {e}"));
}
};
let upstream_status = upstream_resp.status();
if !upstream_status.is_success() {
if !upstream_resp.status().is_success() {
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
let status = upstream_status.as_u16();
let status = upstream_resp.status().as_u16();
let body = upstream_resp.text().await.unwrap_or_default();
let body_snippet = body.chars().take(512).collect::<String>();
tracing::warn!(
handler = "anthropic_messages",
model = %model_id,
node = %route.node_name,
url = %target_url,
status,
body = %body_snippet,
"upstream returned non-2xx"
);
return error_response(status, &format!("upstream returned {status}"));
return error_response(status, &format!("upstream error: {body}"));
}
let body_bytes = match upstream_resp.bytes().await {
Ok(b) => b,
Err(e) => {
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
tracing::warn!(
handler = "anthropic_messages",
model = %model_id,
node = %route.node_name,
url = %target_url,
error = %e,
"failed to read upstream response body"
);
return error_response(502, "failed to read upstream response");
return error_response(502, &format!("failed to read upstream response: {e}"));
}
};
@@ -343,20 +174,7 @@ async fn anthropic_messages(
Ok(r) => r,
Err(e) => {
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
let body_snippet = String::from_utf8_lossy(&body_bytes)
.chars()
.take(512)
.collect::<String>();
tracing::warn!(
handler = "anthropic_messages",
model = %model_id,
node = %route.node_name,
url = %target_url,
error = %e,
body = %body_snippet,
"failed to parse upstream response as OpenAI ChatCompletionResponse"
);
return error_response(502, "malformed upstream response");
return error_response(502, &format!("failed to parse upstream response: {e}"));
}
};
@@ -367,62 +185,12 @@ async fn anthropic_messages(
}
}
/// `GET /v1/models` — union of (catalogue × topology feasibility) and
/// (currently loaded somewhere). The result is what the fleet *could*
/// serve, not just what's already loaded — so OpenAI-compatible tools
/// see every model the operator has provisioned, and cortex
/// transparently cold-loads the first time one is requested.
/// `GET /v1/models` — aggregate models from all nodes.
async fn list_models(State(fleet): State<Arc<CortexState>>) -> Json<Value> {
use std::collections::HashMap;
let now = Utc::now().timestamp() as u64;
let nodes = fleet.nodes.read().await;
let catalogue = &fleet.catalogue;
let mut model_map: std::collections::HashMap<String, CortexModelEntry> =
std::collections::HashMap::new();
let mut entries: HashMap<String, CortexModelEntry> = HashMap::new();
// Pass 1: catalogue × topology. For every catalogue profile, find
// healthy neurons whose discovered devices satisfy the profile.
// Catalogue-defined models surface here even if nothing has loaded
// them yet — that's the point of the unified endpoint.
for profile in &catalogue.models {
let mut feasible_on = Vec::new();
for node in nodes.values() {
if !node.healthy {
continue;
}
let Some(disc) = node.discovery.as_ref() else {
continue;
};
if profile.is_feasible_on(&node.name, &disc.devices) {
feasible_on.push(node.name.clone());
}
}
if feasible_on.is_empty() {
// The catalogue lists this model but no neuron's topology
// matches — surface it as not-loaded with no feasible
// location. Hides nothing; lets operators see why a
// configured model isn't reachable.
feasible_on.clear();
}
entries.insert(
profile.id.clone(),
CortexModelEntry {
id: profile.id.clone(),
object: "model".into(),
created: now,
owned_by: "helexa".into(),
loaded: false,
feasible_on,
locations: Vec::new(),
},
);
}
// Pass 2: layer the actually-loaded state on top. For each
// (node, model) entry, attach a ModelLocation. If the model isn't
// in the catalogue, create a new CortexModelEntry from scratch —
// cortex doesn't refuse to surface a manually-loaded model just
// because the operator didn't enumerate it in models.toml.
for node in nodes.values() {
for (model_id, entry) in &node.models {
let location = ModelLocation {
@@ -430,108 +198,19 @@ async fn list_models(State(fleet): State<Arc<CortexState>>) -> Json<Value> {
status: entry.status,
vram_estimate_mb: entry.vram_estimate_mb,
};
let was_loaded = matches!(entry.status, cortex_core::node::ModelStatus::Loaded);
entries
model_map
.entry(model_id.clone())
.and_modify(|e| {
e.locations.push(location.clone());
if was_loaded {
e.loaded = true;
}
})
.and_modify(|e| e.locations.push(location.clone()))
.or_insert_with(|| CortexModelEntry {
id: model_id.clone(),
object: "model".into(),
created: now,
owned_by: "helexa".into(),
loaded: was_loaded,
// Not in catalogue — cortex has no opinion on
// feasibility; leave empty.
feasible_on: Vec::new(),
locations: vec![location],
});
}
}
// Pass 3: surface pre-warming models. Each neuron's `/health`
// activation snapshot (polled separately from /models) reports
// `in_progress` (the model currently materialising) and `pending`
// (queued behind it). Neither appears on the neuron's `/models`
// yet — that endpoint only knows about fully-loaded handles — so
// without this pass a client polling `/v1/models` during pre-warm
// sees Qwen3.6-27B with no location and concludes "not there".
// Synthesising a Loading location instead tells clients the model
// is on its way. Idempotent against Pass 2: if a Loading location
// for this node already exists (shouldn't, but be safe) we skip.
for node in nodes.values() {
let Some(activation) = node.activation.as_ref() else {
continue;
};
let mut loading_ids: Vec<&str> = Vec::new();
if let Some(id) = activation.in_progress.as_deref() {
loading_ids.push(id);
}
for id in &activation.pending {
loading_ids.push(id.as_str());
}
for model_id in loading_ids {
let location = ModelLocation {
node: node.name.clone(),
status: cortex_core::node::ModelStatus::Loading,
vram_estimate_mb: None,
};
entries
.entry(model_id.to_string())
.and_modify(|e| {
let already = e.locations.iter().any(|l| {
l.node == node.name && l.status == cortex_core::node::ModelStatus::Loading
});
if !already {
e.locations.push(location.clone());
}
})
.or_insert_with(|| CortexModelEntry {
id: model_id.to_string(),
object: "model".into(),
created: now,
owned_by: "helexa".into(),
loaded: false,
feasible_on: Vec::new(),
locations: vec![location],
});
}
}
let data: Vec<Value> = model_map.values().map(|e| json!(e)).collect();
// Pass 4: surface aliases as their own entries pointing at the
// same locations as the target id, so a client browsing /v1/models
// sees "helexa/small" / "helexa/balanced" / "helexa/large" (or
// whatever the operator defined) and can request inference
// against them directly. Aliases that point at unknown targets
// are skipped — surfacing a dead alias would be misleading.
for (alias, target) in &catalogue.aliases {
let Some(target_entry) = entries.get(target).cloned() else {
tracing::warn!(
alias = alias,
target = target,
"alias points at a model not present in catalogue or fleet; skipping"
);
continue;
};
entries.insert(
alias.clone(),
CortexModelEntry {
id: alias.clone(),
object: "model".into(),
created: now,
owned_by: "helexa".into(),
loaded: target_entry.loaded,
feasible_on: target_entry.feasible_on,
locations: target_entry.locations,
},
);
}
let data: Vec<Value> = entries.values().map(|e| json!(e)).collect();
Json(json!({
"object": "list",
"data": data,
@@ -586,9 +265,6 @@ async fn proxy_with_metrics(
}
Err(e) => {
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
// proxy::forward_request already warn'd with wire-level
// detail (target URL, error, status). ProxyError::into_response
// now returns a generic message — no body leak.
e.into_response()
}
}
@@ -609,38 +285,6 @@ fn extract_model(body: &[u8]) -> Option<String> {
v.get("model")?.as_str().map(|s| s.to_string())
}
/// Rewrite the `model` field of an OpenAI-style JSON request body to
/// the resolved concrete id. Returns the original bytes if `new_model`
/// matches what's already there or the body fails to parse — the
/// caller has already extracted `model` via `extract_model`, so a
/// parse failure here would only happen on a body the client crafted
/// to defeat us, and we'd rather proxy it unchanged than 500.
///
/// Needed because neuron rejects requests whose `model` field doesn't
/// match a loaded model, so a client that sends `model: "helexa/small"`
/// would hit a 404 at the harness unless we swap it for the concrete
/// id the alias resolved to.
fn rewrite_model_in_body(body: Bytes, new_model: &str) -> Bytes {
let Ok(mut v) = serde_json::from_slice::<Value>(&body) else {
return body;
};
let needs_rewrite = v
.get("model")
.and_then(|m| m.as_str())
.map(|m| m != new_model)
.unwrap_or(false);
if !needs_rewrite {
return body;
}
if let Value::Object(obj) = &mut v {
obj.insert("model".into(), Value::String(new_model.to_string()));
}
match serde_json::to_vec(&v) {
Ok(bytes) => Bytes::from(bytes),
Err(_) => body,
}
}
fn error_response(status: u16, message: &str) -> Response {
let code = axum::http::StatusCode::from_u16(status)
.unwrap_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR);

View File

@@ -3,7 +3,6 @@
use crate::state::CortexState;
use chrono::Utc;
use cortex_core::discovery::{DiscoveryResponse, HealthResponse};
use cortex_core::harness::ModelInfo;
use cortex_core::node::{ModelEntry, ModelStatus};
use std::sync::Arc;
@@ -26,59 +25,7 @@ pub async fn poll_once(fleet: &CortexState) {
}
}
/// One-shot fetch of `GET /discovery`. Cached on the NodeState forever
/// after the first success — topology is invariant for a given neuron
/// process. Skipped when the cache is already populated.
async fn maybe_poll_discovery(fleet: &CortexState, name: &str, endpoint: &str) {
{
let nodes = fleet.nodes.read().await;
match nodes.get(name) {
Some(n) if n.discovery.is_some() => return,
_ => {}
}
}
let url = format!("{endpoint}/discovery");
let resp = match fleet
.http_client
.get(&url)
.timeout(Duration::from_secs(5))
.send()
.await
{
Ok(r) if r.status().is_success() => r,
Ok(r) => {
tracing::debug!(node = name, status = %r.status(), "discovery probe non-success");
return;
}
Err(e) => {
tracing::debug!(node = name, error = %e, "discovery probe unreachable");
return;
}
};
match resp.json::<DiscoveryResponse>().await {
Ok(d) => {
let mut nodes = fleet.nodes.write().await;
if let Some(node) = nodes.get_mut(name) {
tracing::info!(
node = name,
hostname = %d.hostname,
devices = d.devices.len(),
"discovery cached"
);
node.discovery = Some(d);
}
}
Err(e) => {
tracing::warn!(node = name, error = %e, "failed to parse /discovery response");
}
}
}
async fn poll_neuron(fleet: &CortexState, name: &str, endpoint: &str) {
// Topology first — cheap once cached, and the router needs it to
// route requests against catalogue entries that aren't loaded yet.
maybe_poll_discovery(fleet, name, endpoint).await;
let url = format!("{endpoint}/models");
let result = fleet
@@ -142,51 +89,6 @@ async fn poll_neuron(fleet: &CortexState, name: &str, endpoint: &str) {
node.healthy = false;
}
}
// Release the write lock before the next HTTP call.
drop(nodes);
// Poll /health for the activation snapshot. We don't want this to
// flip the node to unhealthy on its own — a neuron that's serving
// /models fine is still operational even if /health is briefly
// unavailable — so failures are debug-level and leave the existing
// activation reading in place.
poll_health(fleet, name, endpoint).await;
}
/// Fetch `/health` and stash the activation snapshot on NodeState.
/// Decoupled from the /models poll so a /health glitch doesn't mark
/// the neuron unhealthy or evict the model list.
async fn poll_health(fleet: &CortexState, name: &str, endpoint: &str) {
let url = format!("{endpoint}/health");
let resp = match fleet
.http_client
.get(&url)
.timeout(Duration::from_secs(5))
.send()
.await
{
Ok(r) if r.status().is_success() => r,
Ok(r) => {
tracing::debug!(node = name, status = %r.status(), "/health probe non-success");
return;
}
Err(e) => {
tracing::debug!(node = name, error = %e, "/health probe failed");
return;
}
};
match resp.json::<HealthResponse>().await {
Ok(h) => {
let mut nodes = fleet.nodes.write().await;
if let Some(node) = nodes.get_mut(name) {
node.activation = Some(h.activation);
}
}
Err(e) => {
tracing::debug!(node = name, error = %e, "failed to parse /health response");
}
}
}
fn parse_status(s: &str) -> ModelStatus {
@@ -194,7 +96,6 @@ fn parse_status(s: &str) -> ModelStatus {
"loaded" => ModelStatus::Loaded,
"unloaded" => ModelStatus::Unloaded,
"reloading" => ModelStatus::Reloading,
"loading" => ModelStatus::Loading,
_ => ModelStatus::Loaded,
}
}

View File

@@ -1,4 +1,4 @@
//! Streaming HTTP reverse proxy to neuron backends.
//! Streaming HTTP reverse proxy to mistral.rs backends.
//!
//! For streaming requests, SSE chunks are forwarded as they arrive.
//! The proxy captures timing information for metrics but does not
@@ -12,13 +12,6 @@ use axum::response::{IntoResponse, Response};
use reqwest::Client;
/// Proxy a request body to the resolved backend node and stream the response.
///
/// Logging contract: every call emits exactly one structured event at
/// info / warn level for operator visibility, regardless of outcome.
/// Network-level failures and non-2xx upstream statuses are warn'd here
/// (closest to the wire); the user-facing response carries only the
/// status code and a generic message — implementation detail (body,
/// error chain) lives in the log, never in the API surface.
pub async fn forward_request(
client: &Client,
route: &RouteDecision,
@@ -44,33 +37,10 @@ pub async fn forward_request(
req_builder = req_builder.header(key, value);
}
let upstream_resp = match req_builder.send().await {
Ok(r) => r,
Err(e) => {
tracing::warn!(
node = %route.node_name,
url = %url,
error = %e,
"proxy: upstream request failed (network)"
);
return Err(ProxyError::Upstream(e));
}
};
let upstream_resp = req_builder.send().await.map_err(ProxyError::Upstream)?;
let upstream_status = upstream_resp.status();
if !upstream_status.is_success() {
// Streaming body — can't snippet without breaking the stream
// pass-through. Log status + URL; the client still gets the
// upstream status, just without the leaked body.
tracing::warn!(
node = %route.node_name,
url = %url,
status = upstream_status.as_u16(),
"proxy: upstream returned non-2xx"
);
}
let status = StatusCode::from_u16(upstream_status.as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
let status =
StatusCode::from_u16(upstream_resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
let resp_headers = upstream_resp.headers().clone();
let stream = upstream_resp.bytes_stream();
@@ -82,37 +52,28 @@ pub async fn forward_request(
response = response.header(key, value);
}
response.body(body).map_err(|e| {
tracing::warn!(
node = %route.node_name,
url = %url,
error = %e,
"proxy: failed to build response"
);
ProxyError::ResponseBuild(e.to_string())
})
response
.body(body)
.map_err(|e| ProxyError::ResponseBuild(e.to_string()))
}
#[derive(Debug, thiserror::Error)]
pub enum ProxyError {
#[error("upstream request failed")]
#[error("upstream request failed: {0}")]
Upstream(reqwest::Error),
#[error("failed to build response")]
#[error("failed to build response: {0}")]
ResponseBuild(String),
}
impl IntoResponse for ProxyError {
fn into_response(self) -> Response {
let (status, message) = match &self {
ProxyError::Upstream(_) => (StatusCode::BAD_GATEWAY, "upstream request failed"),
ProxyError::ResponseBuild(_) => (
StatusCode::INTERNAL_SERVER_ERROR,
"failed to build response",
),
let status = match &self {
ProxyError::Upstream(_) => StatusCode::BAD_GATEWAY,
ProxyError::ResponseBuild(_) => StatusCode::INTERNAL_SERVER_ERROR,
};
let body = serde_json::json!({
"error": {
"message": message,
"message": self.to_string(),
"type": "proxy_error",
}
});

View File

@@ -2,21 +2,13 @@
//!
//! Given a model ID from an inbound request, determine which node should
//! handle it. Priority:
//! 1. Node where the model is currently `Loaded` → use it.
//! 2. Node where the model is `Unloaded` → use it; neuron's existing
//! lazy-load behaviour will reload before serving the request.
//! 3. Model is in the catalogue → pick a feasible neuron, call
//! `POST /models/load`, wait for the load to complete, then
//! proxy. First-request cold-load latency is acceptable per the
//! unified-endpoint contract.
//! 4. Not in catalogue, not loaded anywhere → 404.
//! 1. Node where the model is currently `Loaded`
//! 2. Node where the model is `Unloaded` (will lazy-load on request)
//! 3. Error: model not found on any node
use crate::state::CortexState;
use cortex_core::catalogue::ModelProfile;
use cortex_core::harness::ModelSpec;
use cortex_core::node::ModelStatus;
use std::sync::Arc;
use std::time::Duration;
/// The routing decision: which node endpoint to proxy the request to.
#[derive(Debug, Clone)]
@@ -24,292 +16,62 @@ pub struct RouteDecision {
pub node_name: String,
/// The inference endpoint to proxy to (from neuron's /models/{id}/endpoint).
pub endpoint: String,
/// Whether the model will need to load (cold start). Set to true
/// when we proxied to an `Unloaded` node (lazy load on neuron) or
/// when we just triggered an explicit cold-load via the catalogue
/// path.
/// Whether the model will need to load (cold start).
pub cold_start: bool,
/// The concrete model id we actually routed to. Equal to the
/// caller's requested id unless an alias was resolved (e.g. caller
/// asked for `helexa/small`, this carries `Qwen/Qwen3-1.7B`). The
/// handler uses this to rewrite the request body's `model` field
/// before proxying — neurons reject requests where the body's
/// model name doesn't match a loaded model.
pub resolved_model_id: String,
}
#[derive(Debug, thiserror::Error)]
pub enum RouteError {
#[error("model '{0}' not found on any node and not in catalogue")]
#[error("model '{0}' not found on any node")]
ModelNotFound(String),
#[error("no healthy nodes available")]
NoHealthyNodes,
#[error("failed to resolve inference endpoint for model '{0}' on node '{1}'")]
EndpointResolveFailed(String, String),
#[error(
"model '{model_id}' is in the catalogue but no healthy neuron's topology satisfies its constraints"
)]
NoFeasibleNeuron { model_id: String },
#[error("cold-load of '{model_id}' on '{node}' failed: {message}")]
ColdLoadFailed {
model_id: String,
node: String,
message: String,
},
}
/// Resolve which node should serve a request for the given model.
/// Asks the neuron for the inference endpoint after selecting a node.
pub async fn resolve(
fleet: &Arc<CortexState>,
requested_model_id: &str,
model_id: &str,
) -> Result<RouteDecision, RouteError> {
// Alias resolution first — swap `helexa/small` (etc.) for the
// concrete id before any node lookups so the rest of routing,
// loading, and metrics deal in concrete ids only. `resolve_alias`
// returns the input verbatim when it isn't an alias.
let model_id = fleet.catalogue.resolve_alias(requested_model_id);
if model_id != requested_model_id {
tracing::debug!(
requested = requested_model_id,
resolved = model_id,
"alias resolved"
);
}
// Snapshot loaded / unloaded state from the poller cache.
let (loaded_route, unloaded_route, any_healthy) = {
let (node_name, neuron_endpoint, cold_start) = {
let nodes = fleet.nodes.read().await;
let mut loaded_route = None;
let mut unloaded_route = None;
let mut any_healthy = false;
let mut loaded_candidate = None;
let mut unloaded_candidate = None;
for node in nodes.values() {
if !node.healthy {
continue;
}
any_healthy = true;
if let Some(entry) = node.models.get(model_id) {
match entry.status {
ModelStatus::Loaded | ModelStatus::Reloading => {
loaded_route = Some((node.name.clone(), node.endpoint.clone(), false));
loaded_candidate = Some((node.name.clone(), node.endpoint.clone(), false));
break;
}
ModelStatus::Unloaded => {
if unloaded_route.is_none() {
unloaded_route = Some((node.name.clone(), node.endpoint.clone(), true));
if unloaded_candidate.is_none() {
unloaded_candidate =
Some((node.name.clone(), node.endpoint.clone(), true));
}
}
// Loading is gateway-synthesised from neuron's
// activation snapshot; it never appears on the
// wire from neuron's `/models`. Skip — the model
// isn't actually servable yet. The pre-existing
// race (catalogue cold_load fires a parallel
// /models/load against the in-flight load) is no
// worse than before; fixing it needs neuron-side
// in-flight tracking on /models/load itself.
ModelStatus::Loading => {}
}
}
}
(loaded_route, unloaded_route, any_healthy)
};
if !any_healthy {
return Err(RouteError::NoHealthyNodes);
}
// Priority 1: already loaded.
if let Some((node_name, neuron_endpoint, cold_start)) = loaded_route {
return finish(fleet, &node_name, &neuron_endpoint, model_id, cold_start).await;
}
// Priority 2: known to neuron but unloaded (neuron's lazy load).
if let Some((node_name, neuron_endpoint, cold_start)) = unloaded_route {
return finish(fleet, &node_name, &neuron_endpoint, model_id, cold_start).await;
}
// Priority 3: catalogue × topology cold-load.
if let Some(profile) = fleet.catalogue.get(model_id) {
let (node_name, neuron_endpoint) = pick_feasible_neuron(fleet, profile).await?;
cold_load(fleet, &node_name, &neuron_endpoint, profile).await?;
return finish(fleet, &node_name, &neuron_endpoint, model_id, true).await;
}
Err(RouteError::ModelNotFound(model_id.to_string()))
}
/// Pick a healthy neuron whose discovered topology satisfies the
/// profile. Preference order:
/// 1. A neuron from `profile.pinned_on` that is healthy + feasible.
/// 2. Otherwise, any healthy + feasible neuron, stable by name.
async fn pick_feasible_neuron(
fleet: &Arc<CortexState>,
profile: &ModelProfile,
) -> Result<(String, String), RouteError> {
let nodes = fleet.nodes.read().await;
let mut candidates: Vec<(String, String, bool)> = Vec::new();
for node in nodes.values() {
if !node.healthy {
continue;
}
let Some(disc) = node.discovery.as_ref() else {
continue;
};
if !profile.is_feasible_on(&node.name, &disc.devices) {
continue;
}
let pinned = profile.pinned_on.iter().any(|n| n == &node.name);
candidates.push((node.name.clone(), node.endpoint.clone(), pinned));
}
candidates.sort_by(|a, b| {
b.2.cmp(&a.2) // pinned first (true > false)
.then(a.0.cmp(&b.0))
});
let pick = candidates.into_iter().next();
pick.map(|(n, e, _)| (n, e))
.ok_or_else(|| RouteError::NoFeasibleNeuron {
model_id: profile.id.clone(),
})
}
/// Issue `POST {endpoint}/models/load` for this profile on this neuron,
/// blocking until the load completes (neuron's load endpoint is
/// synchronous — it returns 200 once VRAM is materialised). On success
/// also inserts a `Loaded` entry into the local NodeState cache so the
/// caller's subsequent endpoint lookup sees the new model without
/// waiting for the next poll cycle.
async fn cold_load(
fleet: &Arc<CortexState>,
node_name: &str,
neuron_endpoint: &str,
profile: &ModelProfile,
) -> Result<(), RouteError> {
let spec = profile_to_spec(fleet, node_name, profile).await;
let url = format!("{neuron_endpoint}/models/load");
tracing::info!(model = %profile.id, node = node_name, "cold-loading via /models/load");
// Generous timeout: a fresh download + safetensors mmap + device
// copy for a 30B-class dense model can comfortably exceed 5 min on
// a slow link. The HTTP client's own default already covers most
// of this; pin a longer per-request bound just here.
let resp = match fleet
.http_client
.post(&url)
.timeout(Duration::from_secs(1800))
.json(&spec)
.send()
.await
{
Ok(r) => r,
Err(e) => {
return Err(RouteError::ColdLoadFailed {
model_id: profile.id.clone(),
node: node_name.to_string(),
message: format!("HTTP request failed: {e}"),
});
}
};
let status = resp.status();
if !status.is_success() {
let body = resp.text().await.unwrap_or_default();
// Neuron returns 400 "already loaded" when two concurrent
// requests race the same model. Treat that as success — both
// requests effectively achieved the same end state.
if body.contains("already loaded") {
tracing::info!(
model = %profile.id,
node = node_name,
"cold-load saw 'already loaded' — treating as success"
);
} else {
return Err(RouteError::ColdLoadFailed {
model_id: profile.id.clone(),
node: node_name.to_string(),
message: format!("HTTP {status}: {body}"),
});
}
} else {
tracing::info!(model = %profile.id, node = node_name, "cold-load returned 200");
}
// Warm the cache: insert a Loaded ModelEntry so the next
// resolve() finds the model without waiting for the poll loop.
{
let mut nodes = fleet.nodes.write().await;
if let Some(node) = nodes.get_mut(node_name) {
node.models.insert(
profile.id.clone(),
cortex_core::node::ModelEntry {
id: profile.id.clone(),
status: ModelStatus::Loaded,
last_accessed: Some(chrono::Utc::now()),
vram_estimate_mb: profile.vram_mb,
},
);
}
}
Ok(())
}
/// Translate a `ModelProfile` to a `ModelSpec` neuron's /models/load
/// accepts. Devices are picked from the neuron's discovered topology —
/// the first `min_devices` indices that meet `min_device_vram_mb`.
async fn profile_to_spec(
fleet: &Arc<CortexState>,
node_name: &str,
profile: &ModelProfile,
) -> ModelSpec {
let devices = {
let nodes = fleet.nodes.read().await;
let mut picked: Vec<u32> = Vec::new();
if let Some(node) = nodes.get(node_name)
&& let Some(disc) = &node.discovery
{
let min_vram = profile.min_device_vram_mb.unwrap_or(0);
for d in &disc.devices {
if d.vram_total_mb >= min_vram {
picked.push(d.index);
if picked.len() as u32 >= profile.min_devices {
break;
}
}
loaded_candidate.or(unloaded_candidate).ok_or_else(|| {
if nodes.values().any(|n| n.healthy) {
RouteError::ModelNotFound(model_id.to_string())
} else {
RouteError::NoHealthyNodes
}
}
if picked.is_empty() {
// Fall back to a 0..min_devices default; pick_feasible_neuron
// already verified the topology satisfies the constraints,
// so this only fires if discovery raced or was lost.
(0..profile.min_devices).collect()
} else {
picked
}
})?
};
let tensor_parallel = if profile.min_devices > 1 {
Some(profile.min_devices)
} else {
None
};
ModelSpec {
model_id: profile.id.clone(),
harness: profile.harness.clone(),
quant: profile.quant.clone(),
tensor_parallel,
devices: Some(devices),
}
}
/// Resolve neuron's `/models/{id}/endpoint` to its inference URL and
/// build the final `RouteDecision`. Shared by all three priority
/// branches above.
async fn finish(
fleet: &Arc<CortexState>,
node_name: &str,
neuron_endpoint: &str,
model_id: &str,
cold_start: bool,
) -> Result<RouteDecision, RouteError> {
// Ask the neuron for the inference endpoint for this model.
let endpoint_url = format!(
"{}/models/{}/endpoint",
neuron_endpoint,
@@ -327,83 +89,13 @@ async fn finish(
_ => None,
};
let raw = inference_endpoint.ok_or_else(|| {
RouteError::EndpointResolveFailed(model_id.to_string(), node_name.to_string())
let endpoint = inference_endpoint.ok_or_else(|| {
RouteError::EndpointResolveFailed(model_id.to_string(), node_name.clone())
})?;
// Rewrite loopback inference URLs to use the configured neuron host.
// Neuron's default bind_url is `http://localhost:13131` (it can't
// reliably know its own externally-resolvable name). Cortex sees a
// URL that's only meaningful from the neuron host's own perspective;
// proxying directly to localhost from a different cortex host would
// hit nothing. Keep neuron's port and path (a future harness could
// serve inference on a different port than the management API), but
// swap the host for the one in cortex.toml.
let endpoint = rewrite_loopback_host(&raw, neuron_endpoint).unwrap_or(raw);
Ok(RouteDecision {
node_name: node_name.to_string(),
node_name,
endpoint,
cold_start,
resolved_model_id: model_id.to_string(),
})
}
/// If `inference_url`'s host is a loopback name (localhost / 127.0.0.1 /
/// 0.0.0.0 / ::1), return a copy with the host replaced by
/// `neuron_endpoint`'s host. Otherwise return None and the caller falls
/// back to the inference URL as-is.
fn rewrite_loopback_host(inference_url: &str, neuron_endpoint: &str) -> Option<String> {
let inf = url::Url::parse(inference_url).ok()?;
let inf_host = inf.host_str()?;
let is_loopback = matches!(inf_host, "localhost" | "127.0.0.1" | "0.0.0.0" | "::1");
if !is_loopback {
return None;
}
let neuron = url::Url::parse(neuron_endpoint).ok()?;
let new_host = neuron.host_str()?;
let mut out = inf.clone();
out.set_host(Some(new_host)).ok()?;
// url::Url::to_string normalises an empty path to "/", which then
// breaks downstream callers that do format!("{endpoint}/v1/...")
// and produce a double slash. The proxy URL is treated as a base
// string that the caller appends paths to, so strip the trailing
// slash here.
let s = out.to_string();
Some(s.trim_end_matches('/').to_string())
}
#[cfg(test)]
mod tests {
use super::rewrite_loopback_host;
#[test]
fn rewrites_localhost_keeps_port_and_path() {
let out = rewrite_loopback_host(
"http://localhost:13131",
"http://beast.hanzalova.internal:13131",
);
assert_eq!(
out.as_deref(),
Some("http://beast.hanzalova.internal:13131")
);
}
#[test]
fn rewrites_loopback_with_distinct_inference_port() {
let out = rewrite_loopback_host("http://127.0.0.1:8080", "http://beast.lan:13131");
assert_eq!(out.as_deref(), Some("http://beast.lan:8080"));
}
#[test]
fn leaves_non_loopback_alone() {
let out = rewrite_loopback_host("http://other.host:1234", "http://beast.lan:13131");
assert_eq!(out, None);
}
#[test]
fn malformed_inference_url_returns_none() {
let out = rewrite_loopback_host("not a url", "http://beast.lan:13131");
assert_eq!(out, None);
}
}

View File

@@ -26,8 +26,6 @@ impl CortexState {
models: HashMap::new(),
lifecycle_cycles: 0,
last_poll: None,
discovery: None,
activation: None,
},
);
}

View File

@@ -1,265 +0,0 @@
//! Alias resolution: a client request with `model: "helexa/small"`
//! routes to the concrete model id (e.g. `Qwen/Qwen3-1.7B`), with the
//! proxied request body rewritten so the upstream neuron sees a model
//! name that matches its loaded handle.
mod common;
use cortex_core::config::{
EvictionSettings, EvictionStrategy, GatewayConfig, GatewaySettings, NeuronEndpoint,
};
use cortex_core::node::{ModelEntry, ModelStatus};
use cortex_gateway::state::CortexState;
use serde_json::json;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::net::TcpListener;
/// Write a `models.toml` with one alias to a unique temp path. Returns
/// the path; the file persists for the test process and gets reaped by
/// the OS at exit. Using $XDG_RUNTIME_DIR fallback for the temp dir
/// keeps the file off shared /tmp on CI without pulling in tempfile.
fn write_models_toml(alias: &str, target: &str) -> PathBuf {
let contents = format!(
r#"
[aliases]
"{alias}" = "{target}"
"#
);
let mut path = std::env::temp_dir();
let pid = std::process::id();
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_nanos();
path.push(format!("cortex-test-models-{pid}-{now}.toml"));
std::fs::write(&path, contents).expect("write temp models.toml");
path
}
#[tokio::test]
async fn test_alias_resolves_in_chat_completions() {
let mock_url = common::spawn_mock_neuron().await;
let models_path = write_models_toml("helexa/small", "test-model");
let config = GatewayConfig {
gateway: GatewaySettings {
listen: "127.0.0.1:0".into(),
metrics_listen: "127.0.0.1:0".into(),
},
eviction: EvictionSettings {
strategy: EvictionStrategy::Lru,
defrag_after_cycles: 0,
},
neurons: vec![NeuronEndpoint {
name: "mock-node".into(),
endpoint: mock_url,
}],
models_config: models_path.to_string_lossy().to_string(),
};
let fleet = Arc::new(CortexState::from_config(&config));
// Seed the node as healthy with the concrete model loaded under
// the target id. The poller doesn't run in this test; we just
// populate state manually.
{
let mut nodes = fleet.nodes.write().await;
let node = nodes.get_mut("mock-node").expect("node must exist");
node.healthy = true;
node.models.insert(
"test-model".into(),
ModelEntry {
id: "test-model".into(),
status: ModelStatus::Loaded,
last_accessed: None,
vram_estimate_mb: None,
},
);
}
// Sanity: the catalogue actually picked up the alias.
assert_eq!(
fleet.catalogue.resolve_alias("helexa/small"),
"test-model",
"alias should resolve to target id"
);
// Spawn the gateway against this fleet.
let app = cortex_gateway::build_app(Arc::clone(&fleet));
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let gateway_addr = listener.local_addr().unwrap();
tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});
let gateway_url = format!("http://{gateway_addr}");
// Send a chat completion against the alias. The mock backend
// echoes back the `model` field it received — so a body whose
// model wasn't rewritten would come back as "helexa/small", and a
// properly-rewritten one as "test-model".
let client = reqwest::Client::new();
let resp = client
.post(format!("{gateway_url}/v1/chat/completions"))
.json(&json!({
"model": "helexa/small",
"messages": [{"role": "user", "content": "hi"}],
}))
.send()
.await
.expect("gateway should respond");
assert!(resp.status().is_success(), "gateway returned non-2xx");
let body: serde_json::Value = resp.json().await.expect("response is JSON");
assert_eq!(
body.get("model").and_then(|m| m.as_str()),
Some("test-model"),
"mock backend should have seen the resolved model id, not the alias"
);
}
#[tokio::test]
async fn test_aliases_surface_in_v1_models() {
let mock_url = common::spawn_mock_neuron().await;
let models_path = write_models_toml("helexa/small", "test-model");
let config = GatewayConfig {
gateway: GatewaySettings {
listen: "127.0.0.1:0".into(),
metrics_listen: "127.0.0.1:0".into(),
},
eviction: EvictionSettings {
strategy: EvictionStrategy::Lru,
defrag_after_cycles: 0,
},
neurons: vec![NeuronEndpoint {
name: "mock-node".into(),
endpoint: mock_url,
}],
models_config: models_path.to_string_lossy().to_string(),
};
let fleet = Arc::new(CortexState::from_config(&config));
// Seed the target as loaded so the alias's mirrored entry shows
// loaded=true.
{
let mut nodes = fleet.nodes.write().await;
let node = nodes.get_mut("mock-node").expect("node must exist");
node.healthy = true;
node.models.insert(
"test-model".into(),
ModelEntry {
id: "test-model".into(),
status: ModelStatus::Loaded,
last_accessed: None,
vram_estimate_mb: Some(2000),
},
);
}
let app = cortex_gateway::build_app(Arc::clone(&fleet));
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let gateway_addr = listener.local_addr().unwrap();
tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});
let gateway_url = format!("http://{gateway_addr}");
let resp = reqwest::get(format!("{gateway_url}/v1/models"))
.await
.expect("gateway should respond");
let body: serde_json::Value = resp.json().await.unwrap();
let entries = body
.get("data")
.and_then(|d| d.as_array())
.expect("data array");
// Both the alias and the target should be present.
let ids: Vec<&str> = entries
.iter()
.filter_map(|e| e.get("id").and_then(|v| v.as_str()))
.collect();
assert!(ids.contains(&"test-model"), "target should be listed");
assert!(ids.contains(&"helexa/small"), "alias should be listed");
// The alias's `loaded` flag and locations should mirror the target.
let alias_entry = entries
.iter()
.find(|e| e.get("id").and_then(|v| v.as_str()) == Some("helexa/small"))
.expect("alias entry");
assert_eq!(alias_entry.get("loaded"), Some(&json!(true)));
let locations = alias_entry
.get("locations")
.and_then(|l| l.as_array())
.expect("locations array");
assert_eq!(locations.len(), 1);
assert_eq!(
locations[0].get("node").and_then(|n| n.as_str()),
Some("mock-node")
);
}
#[tokio::test]
async fn test_alias_falls_through_for_unmapped_model() {
// Catalogue has an alias for some-other-thing but the request
// model "test-model" isn't an alias; resolution should be a no-op.
let mock_url = common::spawn_mock_neuron().await;
let models_path = write_models_toml("helexa/large", "definitely-not-loaded");
let config = GatewayConfig {
gateway: GatewaySettings {
listen: "127.0.0.1:0".into(),
metrics_listen: "127.0.0.1:0".into(),
},
eviction: EvictionSettings {
strategy: EvictionStrategy::Lru,
defrag_after_cycles: 0,
},
neurons: vec![NeuronEndpoint {
name: "mock-node".into(),
endpoint: mock_url,
}],
models_config: models_path.to_string_lossy().to_string(),
};
let fleet = Arc::new(CortexState::from_config(&config));
{
let mut nodes = fleet.nodes.write().await;
let node = nodes.get_mut("mock-node").expect("node must exist");
node.healthy = true;
node.models.insert(
"test-model".into(),
ModelEntry {
id: "test-model".into(),
status: ModelStatus::Loaded,
last_accessed: None,
vram_estimate_mb: None,
},
);
}
let app = cortex_gateway::build_app(Arc::clone(&fleet));
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let gateway_addr = listener.local_addr().unwrap();
tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});
let gateway_url = format!("http://{gateway_addr}");
let resp = reqwest::Client::new()
.post(format!("{gateway_url}/v1/chat/completions"))
.json(&json!({
"model": "test-model",
"messages": [{"role": "user", "content": "hi"}],
}))
.send()
.await
.unwrap();
assert!(resp.status().is_success());
let body: serde_json::Value = resp.json().await.unwrap();
assert_eq!(
body.get("model").and_then(|m| m.as_str()),
Some("test-model")
);
}

View File

@@ -22,7 +22,6 @@ use tokio::net::TcpListener;
/// - GET /models/:id/endpoint (returns the inference URL)
/// - POST /models/unload (accepts unload requests)
/// - GET /v1/chat/completions + POST /v1/chat/completions (inference)
///
/// Returns the neuron base URL.
pub async fn spawn_mock_neuron() -> String {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
@@ -44,7 +43,6 @@ pub async fn spawn_mock_neuron() -> String {
post(|Json(_body): Json<Value>| async { Json(json!({"status": "unloaded"})) }),
)
.route("/v1/chat/completions", post(mock_chat_completions))
.route("/v1/responses", post(mock_responses))
.route("/v1/models", get(mock_v1_models));
tokio::spawn(async move {
@@ -56,7 +54,7 @@ pub async fn spawn_mock_neuron() -> String {
async fn mock_neuron_list_models() -> Json<Value> {
Json(json!([
{"id": "test-model", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": 8000}
{"id": "test-model", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": 8000}
]))
}
@@ -94,39 +92,6 @@ async fn mock_chat_completions(Json(body): Json<Value>) -> Json<Value> {
}))
}
async fn mock_responses(Json(body): Json<Value>) -> Json<Value> {
let model = body
.get("model")
.and_then(|v| v.as_str())
.unwrap_or("unknown");
// Echo the model field back and synthesise a tiny ResponsesResponse.
// Mirrors the shape neuron's /v1/responses handler emits so the
// gateway test only needs to assert the proxy round-tripped it.
Json(json!({
"id": "resp-test-001",
"object": "response",
"created_at": 1700000000_u64,
"status": "completed",
"model": model,
"output": [{
"type": "message",
"id": "msg-test-001",
"role": "assistant",
"content": [{
"type": "output_text",
"text": "Hello from mock backend",
"annotations": []
}],
"status": "completed"
}],
"usage": {
"input_tokens": 5,
"output_tokens": 5,
"total_tokens": 10
}
}))
}
/// Spawns a mock neuron that returns SSE streaming responses for chat completions.
pub async fn spawn_streaming_mock_neuron(chunk_count: usize, chunk_delay: Duration) -> String {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
@@ -198,33 +163,6 @@ pub async fn spawn_streaming_mock_neuron(chunk_count: usize, chunk_delay: Durati
/// Spawns a mock neuron with a custom models list.
pub async fn spawn_mock_neuron_with_models(models_response: Value) -> String {
spawn_mock_neuron_with_models_and_health(models_response, default_health_response()).await
}
/// Default `/health` response used by mocks that don't care about the
/// activation field — empty devices, no in-flight pre-warm, state=ready.
pub fn default_health_response() -> Value {
json!({
"uptime_secs": 0,
"devices": [],
"activation": {
"state": "ready",
"pending": [],
"in_progress": null,
"completed": [],
"failed": []
}
})
}
/// Variant of `spawn_mock_neuron_with_models` that also serves a
/// `/health` body. Used by tests that drive the gateway's activation
/// surface (poller reading /health, /v1/models synthesising Loading
/// locations from in_progress / pending).
pub async fn spawn_mock_neuron_with_models_and_health(
models_response: Value,
health_response: Value,
) -> String {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let base_url = format!("http://{addr}");
@@ -238,13 +176,6 @@ pub async fn spawn_mock_neuron_with_models_and_health(
async move { Json(resp) }
}),
)
.route(
"/health",
get(move || {
let resp = health_response.clone();
async move { Json(resp) }
}),
)
.route(
"/models/{model_id}/endpoint",
get(move |Path(_model_id): Path<String>| {

View File

@@ -12,8 +12,8 @@ use std::sync::Arc;
async fn test_poller_discovers_models() {
// Mock neuron reports 2 models via /models endpoint (neuron format).
let mock_url = common::spawn_mock_neuron_with_models(json!([
{"id": "model-a", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": 8000},
{"id": "model-b", "harness": "candle", "status": "unloaded", "devices": [], "vram_used_mb": null}
{"id": "model-a", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": 8000},
{"id": "model-b", "harness": "mistralrs", "status": "unloaded", "devices": [], "vram_used_mb": null}
]))
.await;
@@ -63,8 +63,8 @@ async fn test_poller_discovers_models() {
#[tokio::test]
async fn test_poller_updates_gateway_models_endpoint() {
let mock_url = common::spawn_mock_neuron_with_models(json!([
{"id": "model-x", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": null},
{"id": "model-y", "harness": "candle", "status": "loaded", "devices": [1], "vram_used_mb": null}
{"id": "model-x", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": null},
{"id": "model-y", "harness": "mistralrs", "status": "loaded", "devices": [1], "vram_used_mb": null}
]))
.await;
@@ -152,8 +152,8 @@ async fn test_poller_marks_unreachable_node_unhealthy() {
#[tokio::test]
async fn test_poller_removes_stale_models() {
let mock_url = common::spawn_mock_neuron_with_models(json!([
{"id": "keep-me", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": null},
{"id": "drop-me", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": null}
{"id": "keep-me", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": null},
{"id": "drop-me", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": null}
]))
.await;
@@ -183,7 +183,7 @@ async fn test_poller_removes_stale_models() {
// New mock with only one model.
let new_mock_url = common::spawn_mock_neuron_with_models(json!([
{"id": "keep-me", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": null}
{"id": "keep-me", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": null}
]))
.await;
@@ -237,58 +237,3 @@ async fn test_poller_removes_stale_models() {
assert!(node.models.contains_key("keep-me"));
assert!(!node.models.contains_key("drop-me"));
}
#[tokio::test]
async fn test_poller_captures_activation_from_health() {
// Mock neuron is mid-prewarm: /models reports nothing (the loading
// model hasn't been inserted into the harness map yet), but
// /health's activation says model-x is in_progress and model-y is
// queued behind it.
let mock_url = common::spawn_mock_neuron_with_models_and_health(
json!([]),
json!({
"uptime_secs": 30,
"devices": [],
"activation": {
"state": "pre_warming",
"pending": ["Qwen/model-y"],
"in_progress": "Qwen/model-x",
"completed": [],
"failed": []
}
}),
)
.await;
let config = GatewayConfig {
gateway: GatewaySettings {
listen: "127.0.0.1:0".into(),
metrics_listen: "127.0.0.1:0".into(),
},
eviction: EvictionSettings {
strategy: EvictionStrategy::Lru,
defrag_after_cycles: 0,
},
neurons: vec![NeuronEndpoint {
name: "prewarm-node".into(),
endpoint: mock_url,
}],
models_config: "/dev/null".into(),
};
let fleet = Arc::new(CortexState::from_config(&config));
cortex_gateway::poller::poll_once(&fleet).await;
let nodes = fleet.nodes.read().await;
let node = nodes.get("prewarm-node").unwrap();
assert!(node.healthy);
// /models was empty — no entries in the per-node model map.
assert!(node.models.is_empty());
// But /health's activation should be captured.
let activation = node
.activation
.as_ref()
.expect("activation should be populated after /health poll");
assert_eq!(activation.in_progress.as_deref(), Some("Qwen/model-x"));
assert_eq!(activation.pending, vec!["Qwen/model-y".to_string()]);
}

View File

@@ -1,91 +0,0 @@
//! Integration tests for the `/v1/responses` proxy route.
//!
//! The gateway forwards the request body to whichever neuron has the
//! model loaded. These tests exercise the routing decision (200 on a
//! known model, 404 on an unknown model, 400 on a missing model
//! field) and confirm the response body round-trips verbatim.
mod common;
use serde_json::json;
/// Happy path: gateway routes a `/v1/responses` request to the neuron
/// that has the model loaded, and the neuron's response body
/// arrives at the client unchanged.
#[tokio::test]
async fn test_responses_proxy() {
let mock_url = common::spawn_mock_neuron().await;
let gw_url = common::spawn_gateway(&mock_url).await;
let client = reqwest::Client::new();
let resp = client
.post(format!("{gw_url}/v1/responses"))
.header("content-type", "application/json")
.json(&json!({
"model": "test-model",
"input": "Hi"
}))
.send()
.await
.expect("request should succeed");
assert_eq!(resp.status(), 200);
let body: serde_json::Value = resp.json().await.expect("valid JSON response");
assert_eq!(body["id"], "resp-test-001");
assert_eq!(body["object"], "response");
assert_eq!(body["model"], "test-model");
assert_eq!(body["status"], "completed");
assert_eq!(
body["output"][0]["content"][0]["text"],
"Hello from mock backend"
);
// Usage shape is the Responses-specific (input/output_tokens),
// not the chat-completions one (prompt/completion_tokens). Asserts
// the proxy didn't accidentally route through the wrong handler.
assert_eq!(body["usage"]["total_tokens"], 10);
assert!(body["usage"].get("input_tokens").is_some());
}
/// A request that targets a model not present in the catalogue gets
/// 404 from the router. This matches the chat-completions handler's
/// behaviour — same error path, same status code, so a client can
/// share retry logic across the two routes.
#[tokio::test]
async fn test_responses_model_not_found() {
let mock_url = common::spawn_mock_neuron().await;
let gw_url = common::spawn_gateway(&mock_url).await;
let client = reqwest::Client::new();
let resp = client
.post(format!("{gw_url}/v1/responses"))
.json(&json!({
"model": "not-in-catalogue",
"input": "Hi"
}))
.send()
.await
.unwrap();
assert_eq!(resp.status(), 404);
}
/// A request body without a `model` field can't be routed; the
/// gateway returns 400 before reaching a backend. Same as the
/// chat-completions handler — extracted via the same `extract_model`
/// helper.
#[tokio::test]
async fn test_responses_missing_model_field() {
let mock_url = common::spawn_mock_neuron().await;
let gw_url = common::spawn_gateway(&mock_url).await;
let client = reqwest::Client::new();
let resp = client
.post(format!("{gw_url}/v1/responses"))
.json(&json!({
"input": "Hi"
}))
.send()
.await
.unwrap();
assert_eq!(resp.status(), 400);
}

View File

@@ -51,18 +51,18 @@ async fn test_streaming_sse_passthrough() {
}
assert!(
chunks.len() > chunk_count,
"expected more than {} chunks (got {}): {:?}",
chunk_count,
chunks.len() >= chunk_count + 1,
"expected at least {} chunks (got {}): {:?}",
chunk_count + 1,
chunks.len(),
chunks,
);
assert_eq!(chunks.last().unwrap(), "[DONE]");
for (i, chunk) in chunks.iter().enumerate().take(chunk_count) {
for i in 0..chunk_count {
let chunk_json: serde_json::Value =
serde_json::from_str(chunk).expect("chunk should be valid JSON");
serde_json::from_str(&chunks[i]).expect("chunk should be valid JSON");
assert_eq!(
chunk_json["choices"][0]["delta"]["content"],
format!("token{i}")

View File

@@ -1,48 +0,0 @@
[package]
name = "helexa-acp"
version = "0.1.16"
edition = "2024"
license = "Apache-2.0"
repository = "https://git.lair.cafe/helexa/cortex"
description = """
Agent Client Protocol bridge for the helexa self-hosted LLM stack.
Speaks ACP to ACP-compatible editor clients (Zed, etc.) and forwards
the conversation to any OpenAI-compatible HTTP endpoint — defaulting
to cortex (helexa's reverse-proxy / fleet gateway).
"""
# This crate is intentionally self-contained — no dependencies on other
# workspace crates (cortex-core, cortex-gateway, neuron). The goal is
# a painless migration to a dedicated GitHub repo in the future if the
# project grows beyond helexa's needs. All deps are crates.io.
[dependencies]
# `unstable_session_model` flips on the SessionModelState type and the
# session/set_model RPC the model-picker dropdown in Zed needs. The
# feature is upstream-marked unstable; we accept that risk because the
# model picker is core UX and the alternative (rolling our own
# extension method) drifts further from spec each time it moves.
agent-client-protocol = { version = "0.12", features = ["unstable_session_model"] }
tokio = { version = "1", features = ["rt-multi-thread", "macros", "sync", "io-util", "process", "signal"] }
reqwest = { version = "0.12", features = ["json", "stream", "rustls-tls"], default-features = false }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
toml = "0.8"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
anyhow = "1"
thiserror = "2"
async-trait = "0.1"
futures = "0.3"
tokio-stream = "0.1"
tokio-util = { version = "0.7", features = ["rt"] }
eventsource-stream = "0.2"
async-stream = "0.3"
url = { version = "2", features = ["serde"] }
# Already transitively pulled via the ACP SDK; declared directly so we
# can format ISO 8601 timestamps for `SessionInfo.updated_at` in the
# session/list response.
chrono = { version = "0.4", default-features = false, features = ["std"] }
[[bin]]
name = "helexa-acp"
path = "src/main.rs"

View File

@@ -1,546 +0,0 @@
# helexa-acp
ACP (Agent Client Protocol) bridge for editors like
[Zed](https://zed.dev). Lets you point your editor's agent panel at
**any combination** of OpenAI-compatible, OpenAI Responses, and
Anthropic Messages endpoints — public APIs, private LAN deployments,
local Ollama / LM Studio — and switch between them per session via a
model dropdown.
The "missing ACP binary" for users who don't want to be locked into
one vendor's agent client.
```
┌───────────────────────────────────┐
│ Zed (or any ACP editor client) │
└────────────┬──────────────────────┘
│ stdio JSON-RPC (ACP)
┌─────────────────┐
│ helexa-acp │ ← one binary, multi-endpoint
└─────┬───────────┘
│ HTTP / SSE
┌────────┼─────────────┬──────────────┬──────────────┐
▼ ▼ ▼ ▼ ▼
cortex/ OpenAI Anthropic OpenRouter LM Studio
neuron Responses Messages
(self- (gpt-5,…) (Claude)
hosted)
```
## What it does
- **Speaks ACP** over stdio to editor clients (Zed today; any future
ACP client tomorrow).
- **Multi-endpoint** — one config file lists every LLM endpoint
you want available; pick one per session via the model dropdown
(`endpoint:model` selector).
- **Three wire formats**: `openai-chat` (the broadly compatible
default), `openai-responses` (newer OpenAI surface), and
`anthropic-messages` (Claude). Each is a separate provider impl
in `src/provider/`; adding a fourth (Gemini, Ollama native, …) is
one file plus a `WireApi` enum variant.
- **Built-in tools**: `read_file`, `write_file`, `edit_file`,
`list_dir`, `bash`. Permission-gated by default; the editor user
approves writes/shell per-call.
- **Three session modes**: Default (gated), Bypass Permissions
(auto-allow), and Plan (write-only-to-plan-dir, no shell).
- **Vision** — drag-drop images into the agent panel against any
vision-capable model.
- **Session resume** — multi-day conversations survive editor
restarts via on-disk transcript persistence.
- **Context compaction** — rolling history stays inside the model's
context window automatically so long sessions on small-context
local models don't fall over.
## Install
### From source
```sh
git clone https://git.lair.cafe/helexa/cortex.git
cd cortex
cargo install --path crates/helexa-acp
# Binary lands at ~/.cargo/bin/helexa-acp
```
### Pre-built RPM (Fedora 43)
```sh
dnf copr enable helexa/helexa
dnf install helexa-acp
```
The COPR project bundles helexa-acp alongside the cortex gateway
and helexa-neuron flavours; install only the package(s) you need.
## Quick start
The fastest path: env-var single-endpoint config.
```sh
export HELEXA_ACP_BASE_URL=http://hanzalova.internal:31313/v1
export HELEXA_ACP_MODEL=Qwen/Qwen3.6-27B
helexa-acp # speaks ACP over stdin/stdout; not interactive
```
Then in Zed (`~/.config/zed/settings.json`):
```jsonc
{
"agent_servers": {
"helexa": {
"command": "helexa-acp",
"args": []
}
}
}
```
Restart Zed → open the agent panel → pick "helexa" → start
chatting. Tool calls (file reads, writes, bash) prompt for
permission per-call in Default mode.
That's the minimum. The full config story below is what unlocks
the multi-endpoint dropdown.
## Multi-endpoint config
Copy `helexa-acp.example.toml` from this repo to
`$XDG_CONFIG_HOME/helexa-acp/config.toml` (typically
`~/.config/helexa-acp/config.toml`) and edit:
```toml
default_endpoint = "helexa"
[[endpoints]]
name = "helexa"
base_url = "http://hanzalova.internal:31313/v1"
wire_api = "openai-chat"
default_model = "Qwen/Qwen3.6-27B"
max_tokens = 8192
context_window = 32768
[[endpoints]]
name = "openrouter"
base_url = "https://openrouter.ai/api/v1"
wire_api = "openai-chat"
api_key_env = "OPENROUTER_API_KEY"
default_model = "anthropic/claude-opus-4"
[[endpoints]]
name = "anthropic"
base_url = "https://api.anthropic.com/v1"
wire_api = "anthropic-messages"
api_key_env = "ANTHROPIC_API_KEY"
default_model = "claude-opus-4"
```
Restart Zed. The model dropdown lists every model from every
configured endpoint with the `endpoint:model` selector
(`helexa:Qwen/Qwen3.6-27B`, `openrouter:anthropic/claude-opus-4`,
…). Switch mid-session; the next prompt routes to the new endpoint.
When only one endpoint is configured the prefix is dropped (model
ids appear bare).
### Selector syntax
The `model` field on every internal request is parsed as
`<endpoint>:<model>`:
- `openrouter:gpt-4o` → routes to the `openrouter` endpoint,
model `gpt-4o`.
- `helexa/large` → no colon → falls through to whichever endpoint
is named in `default_endpoint`, model `helexa/large`.
- `:gpt-5` → leading colon → also falls through to default.
## Endpoint cookbook
Copy-pasteable blocks. Mix and match.
### cortex / neuron (self-hosted)
```toml
[[endpoints]]
name = "helexa"
base_url = "http://hanzalova.internal:31313/v1"
wire_api = "openai-chat"
default_model = "Qwen/Qwen3.6-27B"
max_tokens = 8192
context_window = 32768
```
Use `openai-responses` instead of `openai-chat` once cortex 0.1.16+
is deployed and you want the Responses API surface (vision item
shape, structured reasoning items, etc.).
### OpenAI directly
```toml
[[endpoints]]
name = "openai"
base_url = "https://api.openai.com/v1"
wire_api = "openai-responses"
api_key_env = "OPENAI_API_KEY"
default_model = "gpt-5"
```
`openai-responses` is the right choice for current OpenAI models;
`openai-chat` works against legacy GPT-3.5/4 deployments and
anything labelled "chat completions".
### Anthropic directly
```toml
[[endpoints]]
name = "anthropic"
base_url = "https://api.anthropic.com/v1"
wire_api = "anthropic-messages"
api_key_env = "ANTHROPIC_API_KEY"
default_model = "claude-opus-4"
```
helexa-acp sends `x-api-key` + `anthropic-version: 2023-06-01`
automatically. The `api_key_env` indirection keeps your key out of
the config file.
### OpenRouter (multi-vendor proxy)
```toml
[[endpoints]]
name = "openrouter"
base_url = "https://openrouter.ai/api/v1"
wire_api = "openai-chat"
api_key_env = "OPENROUTER_API_KEY"
default_model = "anthropic/claude-opus-4"
```
OpenRouter speaks OpenAI-compat for every model it fronts, so
`openai-chat` is the right wire format regardless of the
underlying vendor.
### LM Studio (local)
```toml
[[endpoints]]
name = "lmstudio"
base_url = "http://localhost:1234/v1"
wire_api = "openai-chat"
default_model = "auto"
```
LM Studio's "auto" model id picks whatever's loaded. Same shape
works for Ollama in compat mode (`http://localhost:11434/v1`) and
vLLM.
### Multiple cortex deployments
```toml
[[endpoints]]
name = "lan"
base_url = "http://hanzalova.internal:31313/v1"
wire_api = "openai-chat"
default_model = "Qwen/Qwen3.6-27B"
[[endpoints]]
name = "cloud"
base_url = "https://cortex.example.com/v1"
wire_api = "openai-chat"
api_key_env = "CLOUD_CORTEX_KEY"
default_model = "Qwen/Qwen3-VL-8B"
```
Use the `endpoint:model` selector to switch between them mid-session.
## Zed setup
`~/.config/zed/settings.json`:
```jsonc
{
"agent_servers": {
"helexa": {
"command": "helexa-acp"
}
}
}
```
Optional environment overrides for the binary:
```jsonc
{
"agent_servers": {
"helexa": {
"command": "helexa-acp",
"env": {
"HELEXA_ACP_LOG_FILE": "/tmp/helexa-acp.log",
"RUST_LOG": "helexa_acp=debug"
}
}
}
}
```
`HELEXA_ACP_LOG_FILE` is the one you actually want — Zed doesn't
surface the agent's stderr, so without that env var debug output is
invisible. Point it at a file you can `tail -f`.
After restarting Zed: ⌘+? (or wherever your "Open Agent Panel"
binding is) → select "helexa" → the model dropdown populates from
your config → start prompting.
## Modes
Three session modes ship; the user picks via Zed's mode dropdown
on the agent panel.
| Mode | Reads | Writes | Bash | Permission prompts |
|------|-------|--------|------|--------------------|
| **Default** | ✓ | with prompt | with prompt | per call |
| **Bypass Permissions** | ✓ | ✓ | ✓ | never |
| **Plan** | ✓ | only into plan dir | disabled | never (plan-dir writes auto-allow) |
### Default
Reads are always allowed (`read_file`, `list_dir` are
unrestricted). Writes and shell commands prompt the user before
running. The intended baseline for any session where the agent
might do something you'd rather review first.
### Bypass Permissions
Auto-allow every tool call. Use for agentic loops you trust — bulk
edits across many files, scripted workflows, prepared session
templates. Never for code the agent hasn't seen before.
### Plan
The "draft an implementation plan before you write code" mode.
Available tools:
- `read_file`, `list_dir`: unrestricted (read the codebase).
- `write_file`, `edit_file`: allowed *only* under
`$XDG_DATA_HOME/helexa-acp/plans/<project-id>/`. Any path
outside that returns "plan mode: writes are restricted to …"
back to the model so it self-corrects.
- `bash`: disabled outright. Returns "plan mode: shell execution
is disabled" if attempted.
When the plan is complete, the model presents a 3-option menu:
1. **Bypass Permissions** — implement the plan now, no prompts.
2. **Default** — implement now with per-tool prompts.
3. **Plan** (stay here) — refine the plan with more guidance.
Switch the mode dropdown to your preference and reply to proceed.
## Tools
Five tools, defined in `src/tools.rs`:
| Tool | Args | Gated in Default? |
|------|------|-------------------|
| `read_file` | `path`, `line?`, `limit?` | no |
| `list_dir` | `path` | no |
| `write_file` | `path`, `content` | yes |
| `edit_file` | `path`, `old_text`, `new_text` | yes |
| `bash` | `command`, `cwd?` | yes |
### Path handling
`~`, `~/`, `$HOME`, and `$HOME/` are expanded server-side before
the path reaches ACP or local fs. Lets the model emit
`~/git/repo/file.rs` and have it Just Work.
`read_file` first tries the editor's filesystem (ACP's
`fs/read_text_file` — respects open buffers, workspace overlays,
etc.). If that fails — typically because the path is outside Zed's
workspace boundary — it falls back to `std::fs::read_to_string`.
This lets the agent pull in shared material like
`~/git/architecture/generic.md` from a different project's
session.
The fallback is logged at warn level so you can see when it kicks
in.
### Tool dispatch
Tool descriptions reach the model through a Qwen3 Hermes-format
`# Tools` block injected into the system prompt — cortex/neuron
pass the OpenAI `tools` request field through to the encoder
unread, so we work the model into emitting `<tool_call>{json}</tool_call>`
markers it then parses out of the content stream. This applies to
the helexa wire format; OpenAI / Anthropic endpoints with native
tool support would use their own paths once they're wired in.
The parser is tolerant: malformed JSON (trailing braces, missing
`name`, name nested in `arguments`) gets a repair pass; if that
fails the call surfaces as a "Malformed tool call" card in Zed and
the model gets a synthetic error result so it can self-correct.
## Session resume
helexa-acp persists every session to
`$XDG_DATA_HOME/helexa-acp/sessions/<id>.json`. Zed's `session/list`
RPC asks helexa-acp to enumerate them on workspace open;
`session/load` rehydrates and replays the transcript as
`session/update` notifications so the agent panel renders the
prior conversation.
Behaviour:
- Persisted per-round, so a mid-turn agent stall (long bash, wedged
ACP roundtrip) doesn't lose earlier rounds.
- Survives editor restart and the helexa-acp binary upgrading
between versions.
- Project-scoped: only sessions whose `cwd` matches the workspace
are listed.
To wipe history: `rm -rf $XDG_DATA_HOME/helexa-acp/sessions/`.
## Context compaction
When an endpoint sets `context_window`, helexa-acp projects the
rolling history into a token budget before each request — old
`ToolResult` content (read_file payloads are the worst offenders)
gets elided to one-line markers, preserving `tool_call_id` pairing
so the wire schema stays valid.
System prompts, user turns, and the most recent ~4 messages are
never elided. The full history stays on disk; compaction is a
per-request projection, not a destructive edit.
Set `context_window = 32768` for a 32 K Qwen3, `131072` for a
modern Claude, etc. With `max_tokens` also set, the budget is
`context_window - max_tokens - 512_safety`.
## Troubleshooting
### "default endpoint 'helexa' has no usable provider — check config"
The named default endpoint failed to construct. Usually:
- `api_key_env` references a variable that isn't set in the env
Zed launched helexa-acp with.
- The TOML's `wire_api` is misspelled (only `openai-chat`,
`openai-responses`, `anthropic-messages` are accepted).
Test by running `helexa-acp` directly from a shell — startup
errors land on stderr.
### Model dropdown is empty
Each provider's `list_models` failed at startup. Look at
`HELEXA_ACP_LOG_FILE` for "list_models failed; this endpoint's
models won't appear in the picker". Likely the endpoint URL is
wrong, the API key is invalid, or the upstream `/v1/models`
endpoint isn't responding.
The agent still works against `default_model` even when the
dropdown is empty — list-models is for picking, not routing.
### "prompt_too_long" / agent stalls mid-conversation
You hit the model's context window. Set `context_window` on the
endpoint and helexa-acp will compact before sending. The log line
`context compaction applied` confirms it's running; if it fires
but the upstream still rejects, the compaction heuristic
under-counted and the budget needs tuning down.
### Reading files outside the workspace returns "not found"
Zed's `fs/read_text_file` is workspace-scoped. helexa-acp falls
back to local `std::fs` automatically when that fails — look for
`fs/read_text_file failed; falling back to local std::fs` in the
log. If even local read fails, the file genuinely doesn't exist
or the user process lacks permissions.
### Tool calls render as text instead of structured cards
The model is emitting `<tool_call>` markers that the parser can't
decode. Two common causes:
1. The system prompt isn't reaching the model (cortex/neuron's
tool-block injection didn't fire). Confirm with
`RUST_LOG=helexa_acp=debug` and look at the outgoing
`POST /chat/completions` body.
2. The model itself is too small / undertrained to follow the
Hermes format reliably. helexa-acp has shape-based name
inference and JSON repair, but there's a floor below which
nothing helps.
### Plan-mode writes refused even inside the plan dir
The path comparison is byte-for-byte. If the model emits a path
with `~` and the plan_dir has the expanded form, expansion runs
*before* the comparison — but resolved-vs-symlinked-path
mismatches can still bite. The error message names the attempted
path and the expected prefix so you can compare directly.
## Architecture
Source layout under `crates/helexa-acp/src/`:
| File | Responsibility |
|------|----------------|
| `main.rs` | tokio + Stdio transport. Builds providers, hands off to `agent::Agent` |
| `config.rs` | TOML + env-fallback config, endpoint resolver |
| `agent.rs` | ACP handlers (initialize, session/new, session/prompt, session/cancel, session/set_mode, session/set_model, session/load, session/list), prompt loop with tool-call recursion |
| `session.rs` | Per-session state map (Arc<RwLock<HashMap<…>>>) |
| `store.rs` | On-disk session persistence, plan-dir resolution |
| `prompt.rs` | System-prompt assembly, plan-mode addendum |
| `tools.rs` | Tool schemas + shape-based name inference |
| `tool_runner.rs` | Dispatch a single tool call through ACP client RPCs; permission gate |
| `qwen3.rs` | Qwen3 Hermes tool-format parser (`<tool_call>` / `<think>` markers) |
| `compaction.rs` | Token-budget compaction for the rolling history |
| `path_util.rs` | `~` / `$HOME` expansion shared across every path-taking tool |
| `provider/openai_chat.rs` | OpenAI chat completions provider |
| `provider/openai_responses.rs` | OpenAI Responses API provider |
| `provider/anthropic_messages.rs` | Anthropic Messages API provider |
### Adding a new wire format
1. New file under `src/provider/` implementing the `Provider`
trait (encoder + SSE decoder).
2. Add a `WireApi` variant in `config.rs`.
3. Wire it into `build_provider` in `main.rs`.
4. Done — every other module is wire-format-agnostic.
### Concurrency
- `Arc<RwLock<HashMap<SessionId, Arc<Mutex<SessionState>>>>>`
per-session mutex so concurrent requests across sessions don't
contend; the map's RwLock is read-mostly.
- Every tool call dispatched serially within a session (parallel
dispatch would require Zed to handle interleaved permission
prompts).
- Provider streams are back-pressured by the consumer (bounded
mpsc channels).
### Self-contained
The crate has no workspace-internal dependencies (no
`cortex-core`, no `cortex-gateway`). Migration to a dedicated
GitHub repo for cross-platform CI / cargo-dist binaries is
Cargo.toml-only.
## Status
- Stages 16 shipped: scaffold, agent loop, tools, modes, session
resume, image input, model picker, three wire formats.
- Stage 8 (RPM + multi-platform CI) tracked in the canonical plan;
Linux x86_64 RPM ships today via the cortex monorepo's Gitea
Actions.
## Contributing
Repository: https://git.lair.cafe/helexa/cortex (`crates/helexa-acp/`).
Issues / PRs welcome. The canonical staged plan is in
`~/.claude/plans/plan-the-per-device-worker-abstract-micali.md` on
the maintainer's machine; the substages 3a3e and 6a/6b that the
canonical plan didn't anticipate are documented in commit messages.
CI: `cargo fmt --check --all`, `cargo clippy --workspace -- -D
warnings`, `cargo test --workspace` must all pass before merge.

File diff suppressed because it is too large Load Diff

View File

@@ -1,425 +0,0 @@
//! Rolling-conversation compaction for small-context local models.
//!
//! The tool-call loop in [`crate::agent`] grows the message vec it
//! sends upstream every round. On a frontier model that's fine; on a
//! 32 K Qwen3 the first few `read_file` results can push the prompt
//! past the model's context window, at which point cortex/neuron
//! refuses with `prompt_too_long` and the whole turn dies. Long-form
//! local agents are unusable without something here.
//!
//! Strategy (intentionally simple — no LLM-summarization round-trip,
//! no tokenizer dependency):
//!
//! 1. **Protect** the things the model cannot reason without:
//! - The system prompt (idx 0).
//! - Every `Role::User` turn (the user's intent — irreplaceable).
//! - The last [`KEEP_TAIL`] messages (most recent rounds stay
//! verbatim so the model can keep working on what it just
//! observed).
//! 2. **Elide** older `Role::Assistant` prose and older `Role::Tool`
//! result content. The structure stays — `tool_call_id`s, tool
//! names, and argument JSON survive intact — so OpenAI's strict
//! `tool_calls` ↔ `tool` pairing schema remains satisfied. Only
//! the *payload* shrinks to a one-line marker.
//! 3. Walk oldest→newest, recomputing the budget after each elision.
//! Stop as soon as we fit; we don't compact more than necessary.
//! 4. If we still exceed budget after eliding everything we're
//! allowed to, return what we have. The upstream will surface a
//! `prompt_too_long` error and the user can intervene; that's
//! better than silently dropping content the model needs.
//!
//! Token estimation uses a `chars / 3.5` heuristic — conservative
//! (over-estimates tokens slightly) so we compact a touch early
//! rather than a touch late.
use crate::provider::{Message, MessageContent, MessagePart, Role};
/// Most-recent N messages that are never elided. Roughly "the
/// current tool round in flight" — assistant turn that called the
/// tools + each tool result + a bit of slack.
const KEEP_TAIL: usize = 4;
/// Below this content size we don't bother eliding — the savings
/// don't outweigh the loss of detail. Roughly 6080 tokens.
const ELIDE_MIN_CHARS: usize = 256;
/// Roughly tokens-per-character for English + code mixed in. The
/// actual per-tokenizer ratio varies (GPT-4o ≈ 4 chars/token on
/// English prose, ≈ 3 chars/token on code-heavy text). We pick a
/// value on the conservative end so the budget check fires *before*
/// the upstream tokenizer says no.
const CHARS_PER_TOKEN: f32 = 3.5;
/// Per-message envelope overhead (role + JSON framing). Comes out
/// to a few tokens; tiny but it adds up across long histories.
const ENVELOPE_TOKENS: usize = 8;
/// Rough per-image token cost used by the budget estimator. Real
/// vision tokenizers vary widely (2561024 tokens for typical
/// resolutions on Qwen3-VL, OpenAI's `low`/`high` detail toggles
/// pick between ~85 and ~1000+). 512 is a defensible middle that
/// keeps compaction from treating images as free.
const IMAGE_TOKENS_APPROX: usize = 512;
/// Stats reported back from [`compact_to_budget`] for the caller to
/// log. The numbers are estimates (see [`estimate_tokens`]), so
/// don't compare them to upstream-reported token counts as if they
/// were exact.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CompactionStats {
/// Estimated tokens in the input messages.
pub original_tokens: usize,
/// Estimated tokens after compaction. Equal to `original_tokens`
/// when no compaction was needed.
pub final_tokens: usize,
/// Number of messages whose content was elided. Zero is the
/// hot path (nothing to do).
pub elided_messages: usize,
}
impl CompactionStats {
fn unchanged(tokens: usize) -> Self {
Self {
original_tokens: tokens,
final_tokens: tokens,
elided_messages: 0,
}
}
}
/// Approximate token count for one message. Sums the textual
/// payload's chars, divides by [`CHARS_PER_TOKEN`], and adds an
/// envelope constant. Cheap (no allocation) so safe to call once per
/// message per round.
pub fn estimate_tokens(msg: &Message) -> usize {
let chars = match &msg.content {
MessageContent::Text { text } => text.len(),
MessageContent::MultiPart { parts } => parts
.iter()
.map(|p| match p {
MessagePart::Text { text } => text.len(),
// Each image is one block in the context window; the
// upstream tokenizer handles the real cost (and it
// varies wildly by model — Qwen3-VL uses ~256-1024
// tokens per image depending on size). Take a
// middle estimate so the budget tracker doesn't
// pretend images are free.
MessagePart::Image(_) => IMAGE_TOKENS_APPROX * CHARS_PER_TOKEN as usize,
})
.sum(),
MessageContent::ToolCalls { text, calls } => {
let txt = text.as_deref().map(|s| s.len()).unwrap_or(0);
let calls_size: usize = calls
.iter()
.map(|c| c.name.len() + c.arguments.len() + c.id.len())
.sum();
txt + calls_size
}
MessageContent::ToolResult {
tool_call_id,
content,
} => tool_call_id.len() + content.len(),
};
((chars as f32 / CHARS_PER_TOKEN) as usize) + ENVELOPE_TOKENS
}
/// Sum of [`estimate_tokens`] across all messages.
pub fn total_tokens(messages: &[Message]) -> usize {
messages.iter().map(estimate_tokens).sum()
}
/// Project `messages` into a vec whose estimated token count fits in
/// `budget` tokens. Returns the projection plus stats about what
/// was done. When the input already fits, the projection is a clone
/// of the input and stats report zero elisions.
///
/// See module docs for the strategy and protected set.
pub fn compact_to_budget(messages: &[Message], budget: usize) -> (Vec<Message>, CompactionStats) {
let original = total_tokens(messages);
if original <= budget {
return (messages.to_vec(), CompactionStats::unchanged(original));
}
let mut out = messages.to_vec();
let len = out.len();
let tail_start = len.saturating_sub(KEEP_TAIL);
let mut elided = 0usize;
// Two passes. First pass: ToolResult contents (largest savings
// per elision — read_file payloads land here). Second pass: long
// Assistant prose. We don't interleave because eliding a long
// assistant turn before a really old read_file would do less
// good per elision; oldest-first ordering is enforced *within*
// each pass instead.
for pass in 0..2 {
for i in 1..tail_start {
if matches!(out[i].role, Role::User) {
continue;
}
let target_pass_2 = matches!(
&out[i].content,
MessageContent::Text { .. } | MessageContent::ToolCalls { .. }
);
let target_pass_1 = matches!(&out[i].content, MessageContent::ToolResult { .. });
let in_pass = (pass == 0 && target_pass_1) || (pass == 1 && target_pass_2);
if !in_pass {
continue;
}
if elide_in_place(&mut out[i]) {
elided += 1;
if total_tokens(&out) <= budget {
let final_tokens = total_tokens(&out);
return (
out,
CompactionStats {
original_tokens: original,
final_tokens,
elided_messages: elided,
},
);
}
}
}
}
let final_tokens = total_tokens(&out);
(
out,
CompactionStats {
original_tokens: original,
final_tokens,
elided_messages: elided,
},
)
}
/// Shrink one message's payload while keeping its structural role
/// (so tool_call_id pairing survives). Returns `true` when the
/// message changed.
///
/// - `ToolResult.content` → `(elided: N bytes of tool result)`
/// - `ToolCalls.text` → `(elided: N bytes of assistant prose)`
/// - `Text` (assistant) → `(elided: N bytes of assistant prose)`
///
/// Already-tiny payloads are skipped — eliding a 50-byte string
/// would *grow* it once the marker is in place.
fn elide_in_place(msg: &mut Message) -> bool {
match &mut msg.content {
MessageContent::ToolResult { content, .. } => {
if content.len() < ELIDE_MIN_CHARS {
return false;
}
*content = format!("(elided: {} bytes of tool result)", content.len());
true
}
MessageContent::ToolCalls { text, .. } => match text {
Some(t) if t.len() >= ELIDE_MIN_CHARS => {
*text = Some(format!("(elided: {} bytes of assistant prose)", t.len()));
true
}
_ => false,
},
MessageContent::Text { text } => {
if text.len() < ELIDE_MIN_CHARS {
return false;
}
*text = format!("(elided: {} bytes of assistant prose)", text.len());
true
}
MessageContent::MultiPart { .. } => {
// MultiPart messages today only exist as User turns,
// and User turns are protected by the role check in
// `compact_to_budget` — so this branch is unreachable
// for current call sites. Returning false keeps the
// unreachable path benign if a future stage starts
// emitting MultiPart on other roles.
false
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::provider::ToolCall;
fn sys(text: &str) -> Message {
Message {
role: Role::System,
content: MessageContent::Text { text: text.into() },
}
}
fn user(text: &str) -> Message {
Message {
role: Role::User,
content: MessageContent::Text { text: text.into() },
}
}
fn assistant_text(text: &str) -> Message {
Message {
role: Role::Assistant,
content: MessageContent::Text { text: text.into() },
}
}
fn assistant_calls(text: Option<&str>, name: &str, args: &str, id: &str) -> Message {
Message {
role: Role::Assistant,
content: MessageContent::ToolCalls {
text: text.map(|s| s.to_string()),
calls: vec![ToolCall {
id: id.into(),
name: name.into(),
arguments: args.into(),
}],
},
}
}
fn tool_result(id: &str, body: &str) -> Message {
Message {
role: Role::Tool,
content: MessageContent::ToolResult {
tool_call_id: id.into(),
content: body.into(),
},
}
}
#[test]
fn under_budget_is_a_no_op_clone() {
let msgs = vec![sys("you are an agent"), user("hi"), assistant_text("hello")];
let (out, stats) = compact_to_budget(&msgs, 10_000);
assert_eq!(stats.elided_messages, 0);
assert_eq!(stats.original_tokens, stats.final_tokens);
assert_eq!(out.len(), msgs.len());
// Strings unchanged.
match &out[2].content {
MessageContent::Text { text } => assert_eq!(text, "hello"),
other => panic!("expected Text, got {other:?}"),
}
}
#[test]
fn elides_old_tool_result_before_old_assistant_prose() {
// History: sys, user, assistant_calls, big_tool_result,
// assistant_with_big_text, user, assistant_calls,
// small_tool_result.
// KEEP_TAIL=4 protects the last four; the big tool result
// sits in the prunable range and should go first because
// pass 0 (tool results) runs before pass 1 (prose).
let big_result = "X".repeat(4096);
let big_prose = "Y".repeat(2048);
let msgs = vec![
sys("preamble"),
user("first ask"),
assistant_calls(None, "read_file", r#"{"path":"/a"}"#, "c0"),
tool_result("c0", &big_result),
assistant_text(&big_prose),
user("follow up"),
assistant_calls(None, "read_file", r#"{"path":"/b"}"#, "c1"),
tool_result("c1", "short result body"),
];
let before = total_tokens(&msgs);
// Force compaction by setting budget well below current.
let budget = before / 2;
let (out, stats) = compact_to_budget(&msgs, budget);
assert!(
stats.elided_messages >= 1,
"expected at least one elision, got {stats:?}"
);
// The big tool result must be elided (oldest fat target).
match &out[3].content {
MessageContent::ToolResult { content, .. } => {
assert!(
content.starts_with("(elided:"),
"tool result not elided: {content:?}"
);
}
other => panic!("expected ToolResult, got {other:?}"),
}
// Last four messages must be untouched.
assert!(matches!(
&out[out.len() - 1].content,
MessageContent::ToolResult { content, .. } if content == "short result body"
));
}
#[test]
fn never_elides_system_or_user_turns() {
let big_user = "U".repeat(8192);
let msgs = vec![sys("preamble"), user(&big_user), assistant_text("ok")];
let budget = 10; // way below — forces all possible elision
let (out, _stats) = compact_to_budget(&msgs, budget);
// System unchanged.
match &out[0].content {
MessageContent::Text { text } => assert_eq!(text, "preamble"),
other => panic!("expected Text, got {other:?}"),
}
// User unchanged even though it's huge.
match &out[1].content {
MessageContent::Text { text } => assert_eq!(text.len(), big_user.len()),
other => panic!("expected Text, got {other:?}"),
}
}
#[test]
fn preserves_tool_call_id_pairing_after_elision() {
// OpenAI strict mode rejects a tool-result whose tool_call_id
// doesn't match a preceding assistant tool_call. Elision
// must not break that linkage.
let big = "Z".repeat(4096);
let msgs = vec![
sys("preamble"),
user("first"),
assistant_calls(None, "read_file", r#"{"path":"/a"}"#, "call_42"),
tool_result("call_42", &big),
// Tail messages.
user("next"),
assistant_calls(None, "read_file", r#"{"path":"/b"}"#, "call_43"),
tool_result("call_43", "ok"),
assistant_text("done"),
];
let budget = total_tokens(&msgs) / 3;
let (out, _stats) = compact_to_budget(&msgs, budget);
// The assistant call and its result both carry call_42.
let call_id = match &out[2].content {
MessageContent::ToolCalls { calls, .. } => calls[0].id.clone(),
other => panic!("expected ToolCalls, got {other:?}"),
};
match &out[3].content {
MessageContent::ToolResult { tool_call_id, .. } => {
assert_eq!(tool_call_id, &call_id, "pairing broken");
}
other => panic!("expected ToolResult, got {other:?}"),
}
}
#[test]
fn estimate_tokens_grows_with_content() {
let small = sys("hi");
let large = sys(&"x".repeat(10_000));
assert!(estimate_tokens(&large) > estimate_tokens(&small) * 100);
}
#[test]
fn elide_in_place_skips_short_content() {
let mut m = tool_result("c0", "tiny");
assert!(!elide_in_place(&mut m));
match m.content {
MessageContent::ToolResult { content, .. } => assert_eq!(content, "tiny"),
other => panic!("expected ToolResult, got {other:?}"),
}
}
#[test]
fn returns_best_effort_when_budget_unmeetable() {
// Single huge user message that cannot be elided. Budget 10.
// We don't error — we return what we have and let upstream
// refuse the prompt with its own error.
let big_user = "U".repeat(100_000);
let msgs = vec![sys("preamble"), user(&big_user)];
let (out, stats) = compact_to_budget(&msgs, 10);
assert_eq!(out.len(), msgs.len());
assert!(stats.final_tokens > 10, "still over budget by design");
}
}

View File

@@ -1,424 +0,0 @@
//! Configuration for the helexa-acp bridge.
//!
//! Loaded from `$XDG_CONFIG_HOME/helexa-acp/config.toml` (or
//! `~/.config/helexa-acp/config.toml` as a fallback). If no config file
//! exists, falls back to building a single anonymous endpoint from env
//! vars — that keeps "just point at one cortex" frictionless without
//! requiring a config file on disk.
//!
//! The design goal is "the missing ACP binary for users with multiple
//! API endpoints (possibly on a private LAN, possibly mixing wire
//! types)". Hence: every endpoint is named, has its own wire API, and
//! has its own default model. The agent's selected model id can be
//! prefixed `endpoint:model` to route across endpoints; a bare
//! `model` falls through to the configured `default_endpoint`.
//!
//! ### Example TOML
//!
//! ```toml
//! default_endpoint = "helexa"
//!
//! [[endpoints]]
//! name = "helexa"
//! base_url = "http://hanzalova.internal:31313/v1"
//! wire_api = "openai-chat"
//! default_model = "helexa/large"
//!
//! [[endpoints]]
//! name = "openrouter"
//! base_url = "https://openrouter.ai/api/v1"
//! wire_api = "openai-chat"
//! api_key_env = "OPENROUTER_API_KEY"
//! default_model = "anthropic/claude-opus-4"
//!
//! [[endpoints]]
//! name = "lmstudio"
//! base_url = "http://localhost:1234/v1"
//! wire_api = "openai-chat"
//! default_model = "auto"
//! ```
use anyhow::{Context, anyhow};
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use url::Url;
const DEFAULT_BASE_URL: &str = "http://hanzalova.internal:31313/v1";
const DEFAULT_MODEL: &str = "helexa/large";
const DEFAULT_ENDPOINT_NAME: &str = "default";
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
/// Name of the endpoint used when a request doesn't pick one
/// explicitly. Must reference an entry in `endpoints`. Defaults to
/// the first endpoint declared if unset.
#[serde(default)]
pub default_endpoint: Option<String>,
/// Per-endpoint configuration. At least one entry is required.
#[serde(default)]
pub endpoints: Vec<EndpointConfig>,
/// Optional path to a system-prompt file. When unset, the built-in
/// default prompt from `prompt.rs` is used.
#[serde(default)]
pub system_prompt_path: Option<PathBuf>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EndpointConfig {
/// Short identifier used in `endpoint:model` routing and in logs.
pub name: String,
/// Base URL of the OpenAI-compatible API. Must include the `/v1`
/// (or equivalent) suffix — paths like `chat/completions` and
/// `models` are joined onto this.
pub base_url: Url,
/// Wire protocol the endpoint speaks. Phase 1 supports
/// [`WireApi::OpenAiChat`] only; `openai-responses` and
/// `anthropic-messages` land later behind their own providers.
#[serde(default)]
pub wire_api: WireApi,
/// Model to use when the client hasn't picked one via
/// `session/set_model`.
#[serde(default)]
pub default_model: Option<String>,
/// Static API key to send as `Authorization: Bearer …`. Prefer
/// `api_key_env` for anything sensitive — keys in plain TOML are a
/// liability.
#[serde(default)]
pub api_key: Option<String>,
/// Env var name to read for the API key. Resolved at startup so a
/// missing env var yields a clear error rather than silent
/// unauthenticated calls.
#[serde(default)]
pub api_key_env: Option<String>,
/// Cap on the model's output tokens per turn. `None` lets the
/// upstream pick its own default (cortex/neuron's default is
/// often small enough to trip Zed's "Output Limit Reached" on
/// long responses). Set to e.g. `32768` to let the model
/// produce longer turns. Goes into the OpenAI `max_tokens`
/// request field.
#[serde(default)]
pub max_tokens: Option<u64>,
/// Model context window in tokens (prompt + response). When set,
/// the agent compacts conversation history before each completion
/// so the prompt fits within `context_window - max_tokens - safety`
/// tokens — long sessions on small-context local models (Qwen3 at
/// 32 K) survive past the first few tool-call rounds rather than
/// dying with `prompt_too_long`. `None` disables compaction.
#[serde(default)]
pub context_window: Option<usize>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum WireApi {
/// `POST {base}/chat/completions` returning OpenAI-format SSE.
/// Compatible with cortex, LM Studio, Ollama (compat mode),
/// OpenRouter, OpenAI itself.
#[default]
#[serde(rename = "openai-chat")]
OpenAiChat,
/// `POST {base}/responses` — OpenAI's newer Responses API. Not
/// implemented yet; the variant is reserved so endpoint configs
/// can be authored ahead of provider support.
#[serde(rename = "openai-responses")]
OpenAiResponses,
/// `POST {base}/messages` — Anthropic format. Reserved.
#[serde(rename = "anthropic-messages")]
AnthropicMessages,
}
impl EndpointConfig {
/// Resolve the API key from `api_key` (literal) or `api_key_env`
/// (env-var lookup). Returns `Ok(None)` when neither is set;
/// `Err` when `api_key_env` references a missing variable.
pub fn resolve_api_key(&self) -> anyhow::Result<Option<String>> {
if let Some(literal) = &self.api_key {
return Ok(Some(literal.clone()));
}
if let Some(var) = &self.api_key_env {
return Ok(Some(std::env::var(var).with_context(|| {
format!(
"endpoint '{}' references missing env var {}",
self.name, var
)
})?));
}
Ok(None)
}
/// `{base_url}/chat/completions`.
pub fn chat_completions_url(&self) -> Url {
join_segments(&self.base_url, &["chat", "completions"])
}
/// `{base_url}/responses` — OpenAI Responses API endpoint.
pub fn responses_url(&self) -> Url {
join_segments(&self.base_url, &["responses"])
}
/// `{base_url}/models`. Called from `Provider::list_models`, which
/// Stage 4 wires into the model-picker dropdown; until then it's
/// reachable code with no in-tree callers.
#[allow(dead_code)]
pub fn models_url(&self) -> Url {
join_segments(&self.base_url, &["models"])
}
}
impl Config {
/// Load from TOML at the standard config path, or build from env
/// vars if no file exists. Env-fallback yields a single endpoint
/// named `"default"`.
pub fn load() -> anyhow::Result<Self> {
let path = config_path();
if let Some(path) = &path
&& path.exists()
{
return Self::from_file(path);
}
Self::from_env()
}
/// Single-endpoint config constructed from `HELEXA_ACP_BASE_URL`,
/// `HELEXA_ACP_MODEL`, `HELEXA_ACP_API_KEY`,
/// `HELEXA_ACP_SYSTEM_PROMPT_PATH`, `HELEXA_ACP_MAX_TOKENS`.
pub fn from_env() -> anyhow::Result<Self> {
let base_url = std::env::var("HELEXA_ACP_BASE_URL")
.ok()
.unwrap_or_else(|| DEFAULT_BASE_URL.into());
let base_url = Url::parse(&base_url)
.with_context(|| format!("HELEXA_ACP_BASE_URL is not a valid URL ({base_url})"))?;
let default_model = std::env::var("HELEXA_ACP_MODEL")
.ok()
.unwrap_or_else(|| DEFAULT_MODEL.into());
let api_key = std::env::var("HELEXA_ACP_API_KEY")
.ok()
.filter(|s| !s.is_empty());
let system_prompt_path = std::env::var("HELEXA_ACP_SYSTEM_PROMPT_PATH")
.ok()
.filter(|s| !s.is_empty())
.map(PathBuf::from);
let max_tokens = std::env::var("HELEXA_ACP_MAX_TOKENS")
.ok()
.filter(|s| !s.is_empty())
.map(|s| {
s.parse::<u64>().with_context(|| {
format!("HELEXA_ACP_MAX_TOKENS is not a positive integer ({s})")
})
})
.transpose()?;
let context_window = std::env::var("HELEXA_ACP_CONTEXT_WINDOW")
.ok()
.filter(|s| !s.is_empty())
.map(|s| {
s.parse::<usize>().with_context(|| {
format!("HELEXA_ACP_CONTEXT_WINDOW is not a positive integer ({s})")
})
})
.transpose()?;
Ok(Self {
default_endpoint: Some(DEFAULT_ENDPOINT_NAME.into()),
endpoints: vec![EndpointConfig {
name: DEFAULT_ENDPOINT_NAME.into(),
base_url,
wire_api: WireApi::OpenAiChat,
default_model: Some(default_model),
api_key,
api_key_env: None,
max_tokens,
context_window,
}],
system_prompt_path,
})
}
pub fn from_file(path: &Path) -> anyhow::Result<Self> {
let text = std::fs::read_to_string(path)
.with_context(|| format!("read config {}", path.display()))?;
let mut cfg: Self =
toml::from_str(&text).with_context(|| format!("parse config {}", path.display()))?;
cfg.validate()?;
Ok(cfg)
}
fn validate(&mut self) -> anyhow::Result<()> {
if self.endpoints.is_empty() {
return Err(anyhow!("config has no [[endpoints]] entries"));
}
for (i, ep) in self.endpoints.iter().enumerate() {
if ep.name.is_empty() {
return Err(anyhow!("endpoints[{i}] has empty name"));
}
if ep.name.contains(':') {
return Err(anyhow!(
"endpoints[{i}].name '{}' contains ':' which would clash \
with the endpoint:model selector syntax",
ep.name
));
}
}
// Pick a default endpoint if none was named.
if self.default_endpoint.is_none() {
self.default_endpoint = Some(self.endpoints[0].name.clone());
}
let default_name = self.default_endpoint.as_deref().unwrap();
if !self.endpoints.iter().any(|e| e.name == default_name) {
return Err(anyhow!(
"default_endpoint '{default_name}' is not declared in [[endpoints]]"
));
}
Ok(())
}
/// Look up an endpoint by name. Returns `None` if not configured.
pub fn endpoint(&self, name: &str) -> Option<&EndpointConfig> {
self.endpoints.iter().find(|e| e.name == name)
}
/// The default endpoint (guaranteed to exist after `validate`).
pub fn default_endpoint(&self) -> &EndpointConfig {
let name = self
.default_endpoint
.as_deref()
.expect("default_endpoint set by validate");
self.endpoint(name)
.expect("default_endpoint resolves after validate")
}
}
/// Parse an ACP-side `model` field into (endpoint name, raw model id).
///
/// `helexa:helexa/large` → (`Some("helexa")`, `"helexa/large"`).
/// `helexa/large` → (`None`, `"helexa/large"`).
///
/// The split happens at the FIRST colon. Model ids commonly contain
/// `/` (HuggingFace style) but rarely `:`; if a model id ever does, the
/// user can quote-prefix with the default endpoint name.
pub fn parse_model_selector(input: &str) -> (Option<&str>, &str) {
match input.split_once(':') {
Some((endpoint, model)) if !endpoint.is_empty() && !model.is_empty() => {
(Some(endpoint), model)
}
_ => (None, input),
}
}
fn config_path() -> Option<PathBuf> {
if let Ok(override_path) = std::env::var("HELEXA_ACP_CONFIG_PATH") {
return Some(PathBuf::from(override_path));
}
let xdg = std::env::var("XDG_CONFIG_HOME")
.ok()
.filter(|s| !s.is_empty());
let base = xdg.map(PathBuf::from).or_else(|| {
std::env::var("HOME")
.ok()
.map(|h| PathBuf::from(h).join(".config"))
})?;
Some(base.join("helexa-acp").join("config.toml"))
}
fn join_segments(base: &Url, segments: &[&str]) -> Url {
let mut out = base.clone();
if let Ok(mut path) = out.path_segments_mut() {
path.pop_if_empty().extend(segments.iter().copied());
}
out
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn url_join_handles_trailing_slash() {
let ep = EndpointConfig {
name: "x".into(),
base_url: Url::parse("http://h.internal:31313/v1").unwrap(),
wire_api: WireApi::OpenAiChat,
default_model: None,
api_key: None,
api_key_env: None,
max_tokens: None,
context_window: None,
};
assert_eq!(
ep.chat_completions_url().as_str(),
"http://h.internal:31313/v1/chat/completions"
);
assert_eq!(
ep.models_url().as_str(),
"http://h.internal:31313/v1/models"
);
}
#[test]
fn parses_model_selector() {
assert_eq!(
parse_model_selector("helexa:helexa/large"),
(Some("helexa"), "helexa/large")
);
assert_eq!(parse_model_selector("helexa/large"), (None, "helexa/large"));
assert_eq!(parse_model_selector("gpt-5"), (None, "gpt-5"));
// Edge case: a leading colon → no endpoint.
assert_eq!(parse_model_selector(":gpt-5"), (None, ":gpt-5"));
}
#[test]
fn env_fallback_builds_single_endpoint() {
// Don't actually set env vars (would race with other tests);
// just confirm the default path constructs cleanly.
unsafe {
std::env::remove_var("HELEXA_ACP_BASE_URL");
std::env::remove_var("HELEXA_ACP_MODEL");
std::env::remove_var("HELEXA_ACP_API_KEY");
}
let cfg = Config::from_env().unwrap();
assert_eq!(cfg.endpoints.len(), 1);
assert_eq!(cfg.endpoints[0].name, "default");
assert_eq!(cfg.endpoints[0].base_url.as_str(), DEFAULT_BASE_URL);
assert_eq!(
cfg.endpoints[0].default_model.as_deref(),
Some(DEFAULT_MODEL)
);
}
#[test]
fn toml_parses_multi_endpoint() {
let toml_text = r#"
default_endpoint = "helexa"
[[endpoints]]
name = "helexa"
base_url = "http://hanzalova.internal:31313/v1"
default_model = "helexa/large"
[[endpoints]]
name = "openrouter"
base_url = "https://openrouter.ai/api/v1"
wire_api = "openai-chat"
api_key_env = "OPENROUTER_API_KEY"
default_model = "anthropic/claude-opus-4"
"#;
let mut cfg: Config = toml::from_str(toml_text).unwrap();
cfg.validate().unwrap();
assert_eq!(cfg.endpoints.len(), 2);
assert_eq!(cfg.default_endpoint().name, "helexa");
assert_eq!(cfg.endpoints[0].wire_api, WireApi::OpenAiChat);
assert_eq!(
cfg.endpoints[1].api_key_env.as_deref(),
Some("OPENROUTER_API_KEY")
);
}
#[test]
fn validate_rejects_colon_in_endpoint_name() {
let toml_text = r#"
[[endpoints]]
name = "bad:name"
base_url = "http://x/v1"
"#;
let mut cfg: Config = toml::from_str(toml_text).unwrap();
let err = cfg.validate().unwrap_err();
assert!(format!("{err}").contains("clash"));
}
}

View File

@@ -1,145 +0,0 @@
//! helexa-acp — Agent Client Protocol bridge for multi-endpoint LLM
//! setups (helexa, LM Studio, Ollama, OpenRouter, OpenAI, Anthropic,
//! …) with a clean per-endpoint wire-format selector.
//!
//! Speaks ACP over stdio to an editor client (Zed today). Every
//! configured endpoint produces a wire-format-specific
//! [`provider::Provider`] implementation; the agent loop in
//! [`agent::Agent`] is provider-agnostic, so adding e.g. an Anthropic
//! /v1/messages provider doesn't touch `agent.rs`.
//!
//! Config: `$XDG_CONFIG_HOME/helexa-acp/config.toml` for the multi-
//! endpoint case; env vars (`HELEXA_ACP_BASE_URL`, etc.) for the
//! single-endpoint case when no config file exists.
use agent_client_protocol::{Result, Stdio};
use std::sync::Arc;
mod agent;
mod compaction;
mod config;
mod path_util;
mod prompt;
mod provider;
mod qwen3;
mod session;
mod store;
mod tool_runner;
mod tools;
use agent::Agent;
use config::{Config, EndpointConfig, WireApi};
use provider::{
Provider, anthropic_messages::AnthropicMessagesProvider, openai_chat::OpenAIChatProvider,
openai_responses::OpenAIResponsesProvider,
};
/// Set up tracing. Logs go to stderr by default — stdout is
/// reserved for the JSON-RPC stream. Setting `HELEXA_ACP_LOG_FILE`
/// to an absolute path appends logs to that file instead, which is
/// the practical way to capture debug output when the agent runs
/// under an editor (Zed, etc.) that doesn't surface stderr.
///
/// `RUST_LOG` still controls levels (e.g. `helexa_acp=debug`).
/// ANSI colours are auto-stripped when writing to a file so the log
/// is plain text.
fn init_tracing() {
let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info"));
let log_file = std::env::var("HELEXA_ACP_LOG_FILE")
.ok()
.filter(|s| !s.is_empty());
match log_file {
Some(path) => match std::fs::OpenOptions::new()
.create(true)
.append(true)
.open(&path)
{
Ok(file) => {
tracing_subscriber::fmt()
.with_writer(std::sync::Mutex::new(file))
.with_env_filter(env_filter)
.with_ansi(false)
.init();
}
Err(e) => {
// Fall back to stderr and shout. We don't want a
// typo'd log path to silence the agent entirely.
tracing_subscriber::fmt()
.with_writer(std::io::stderr)
.with_env_filter(env_filter)
.init();
tracing::warn!(
path = %path,
error = %e,
"HELEXA_ACP_LOG_FILE could not be opened; using stderr"
);
}
},
None => {
tracing_subscriber::fmt()
.with_writer(std::io::stderr)
.with_env_filter(env_filter)
.init();
}
}
}
/// Build a provider for `endpoint` according to its declared
/// `wire_api`. Future wire types (OpenAI Responses, Anthropic
/// /v1/messages, Ollama native) slot in here without changing the
/// caller.
fn build_provider(endpoint: EndpointConfig) -> anyhow::Result<Arc<dyn Provider>> {
match endpoint.wire_api {
WireApi::OpenAiChat => Ok(Arc::new(OpenAIChatProvider::new(endpoint)?)),
WireApi::OpenAiResponses => Ok(Arc::new(OpenAIResponsesProvider::new(endpoint)?)),
WireApi::AnthropicMessages => Ok(Arc::new(AnthropicMessagesProvider::new(endpoint)?)),
}
}
#[tokio::main]
async fn main() -> Result<()> {
init_tracing();
let cfg = Config::load()
.map_err(|e| agent_client_protocol::util::internal_error(format!("config: {e:#}")))?;
tracing::info!(
endpoints = cfg.endpoints.len(),
default_endpoint = %cfg.default_endpoint().name,
default_model = ?cfg.default_endpoint().default_model,
"helexa-acp starting"
);
// Build a provider for each configured endpoint up-front. Cheap —
// just sets up a reqwest::Client and resolves the API key — and
// surfaces config mistakes (missing API key env var, unsupported
// wire_api) before the editor even sends an initialize request.
let mut providers: Vec<Arc<dyn Provider>> = Vec::with_capacity(cfg.endpoints.len());
for endpoint in &cfg.endpoints {
match build_provider(endpoint.clone()) {
Ok(p) => {
tracing::info!(
endpoint = %endpoint.name,
base_url = %endpoint.base_url,
wire_api = ?endpoint.wire_api,
"registered provider"
);
providers.push(p);
}
Err(e) => {
tracing::warn!(
endpoint = %endpoint.name,
error = %format!("{e:#}"),
"skipping endpoint with invalid config"
);
}
}
}
let agent = Agent::new(&cfg, providers)
.await
.map_err(|e| agent_client_protocol::util::internal_error(format!("agent: {e:#}")))?;
agent.serve(Stdio::new()).await
}

View File

@@ -1,192 +0,0 @@
//! Path expansion shared across every tool that takes a path.
//!
//! Models often emit shell-style paths like `~/git/repo/file.rs` or
//! `$HOME/notes.md`. ACP's `fs/read_text_file` and friends — and our
//! own local `std::fs` reads — both want a real absolute path; the
//! `~` / `$HOME` forms reach them as literal strings and the open
//! fails. The tool schemas already document "absolute path" but in
//! practice the model slips up often enough that handling it
//! server-side is the difference between "works" and "the agent is
//! brittle".
//!
//! Scope is deliberately small:
//!
//! - `~` and `~/` (current user only — `~user` lookups would require
//! pulling in passwd parsing).
//! - `$HOME` and `$HOME/`.
//!
//! Any other shell variable (`$PWD`, `${HOME}`, …) passes through
//! unchanged. The shell already expands them inside `bash` tool
//! commands; for the file-tool argument fields, we deliberately
//! limit the set so the behaviour is predictable.
//!
//! Falls back to the input path verbatim when `HOME` is unset
//! (stripped-down container env). That preserves the "no surprise
//! mutations" rule — never invent a path the caller didn't ask for.
use std::path::{Path, PathBuf};
/// Process-global lock for tests that mutate `HOME`. Anyone in the
/// crate touching `HOME` must hold this for the duration of the
/// read-modify-restore window — otherwise concurrent `cargo test`
/// workers race and flake.
///
/// Only built into the test binaries. Production code never mutates
/// env vars.
#[cfg(test)]
pub(crate) static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
/// Expand `~`, `~/`, `$HOME`, and `$HOME/` prefixes against the
/// current user's home directory. All other inputs pass through
/// unchanged.
///
/// Returns the input verbatim if `HOME` isn't set in the env.
pub fn expand_path(input: &Path) -> PathBuf {
let Some(s) = input.to_str() else {
return input.to_path_buf();
};
let Ok(home) = std::env::var("HOME") else {
return input.to_path_buf();
};
let home = PathBuf::from(home);
if s == "~" || s == "$HOME" {
return home;
}
if let Some(rest) = s.strip_prefix("~/") {
return home.join(rest);
}
if let Some(rest) = s.strip_prefix("$HOME/") {
return home.join(rest);
}
input.to_path_buf()
}
#[cfg(test)]
mod tests {
use super::*;
/// Set HOME for the duration of the test. Tests using this run
/// serially under the crate-wide [`ENV_LOCK`] because env
/// mutation isn't thread-safe — `cargo test` parallel workers
/// would race without it.
fn with_home<F: FnOnce()>(home: &str, body: F) {
let _g = ENV_LOCK.lock().unwrap();
let prior = std::env::var("HOME").ok();
// SAFETY: tests touch process-global env. The mutex
// serialises access; sub-threads in other test modules
// touching HOME aren't expected (none in this crate).
unsafe {
std::env::set_var("HOME", home);
}
body();
unsafe {
match prior {
Some(p) => std::env::set_var("HOME", p),
None => std::env::remove_var("HOME"),
}
}
}
#[test]
fn expands_tilde_slash() {
with_home("/home/me", || {
assert_eq!(
expand_path(Path::new("~/git/repo/file.rs")),
PathBuf::from("/home/me/git/repo/file.rs")
);
});
}
#[test]
fn expands_bare_tilde() {
with_home("/home/me", || {
assert_eq!(expand_path(Path::new("~")), PathBuf::from("/home/me"));
});
}
#[test]
fn expands_dollar_home_slash() {
with_home("/home/me", || {
assert_eq!(
expand_path(Path::new("$HOME/notes.md")),
PathBuf::from("/home/me/notes.md")
);
});
}
#[test]
fn expands_bare_dollar_home() {
with_home("/home/me", || {
assert_eq!(expand_path(Path::new("$HOME")), PathBuf::from("/home/me"));
});
}
#[test]
fn absolute_path_passes_through() {
with_home("/home/me", || {
assert_eq!(
expand_path(Path::new("/etc/hostname")),
PathBuf::from("/etc/hostname")
);
});
}
#[test]
fn relative_path_passes_through() {
with_home("/home/me", || {
assert_eq!(
expand_path(Path::new("src/main.rs")),
PathBuf::from("src/main.rs")
);
});
}
#[test]
fn tilde_user_form_not_expanded() {
// ~other is shell sugar for /home/other and would require
// passwd parsing to resolve. Out of scope — pass it
// through and let the open fail with a clear error.
with_home("/home/me", || {
assert_eq!(
expand_path(Path::new("~other/x")),
PathBuf::from("~other/x")
);
});
}
#[test]
fn no_home_env_passes_through() {
// Share the same crate-wide lock as `with_home` — otherwise
// a parallel test setting HOME races this clear-and-assert
// window.
let _g = ENV_LOCK.lock().unwrap();
let prior = std::env::var("HOME").ok();
// SAFETY: serialised by LOCK above.
unsafe {
std::env::remove_var("HOME");
}
assert_eq!(
expand_path(Path::new("~/git/repo")),
PathBuf::from("~/git/repo")
);
unsafe {
if let Some(p) = prior {
std::env::set_var("HOME", p);
}
}
}
#[test]
fn dollar_other_var_not_expanded() {
with_home("/home/me", || {
assert_eq!(
expand_path(Path::new("$PWD/file")),
PathBuf::from("$PWD/file")
);
assert_eq!(
expand_path(Path::new("${HOME}/file")),
PathBuf::from("${HOME}/file")
);
});
}
}

View File

@@ -1,274 +0,0 @@
//! System prompt assembly.
//!
//! The system message has two parts:
//!
//! 1. A short human-readable preamble (working directory, style
//! instructions). Either the built-in [`DEFAULT_PROMPT`] or a
//! user-supplied file at `HELEXA_ACP_SYSTEM_PROMPT_PATH` /
//! `system_prompt_path`. `{cwd}` is substituted in both.
//! 2. A `# Tools` block in Qwen3 Hermes format (see [`crate::qwen3`])
//! describing the available functions. This is what makes the
//! model actually call them — neuron/cortex don't honour the
//! OpenAI `tools` API field, so the tool list has to live in the
//! prompt itself.
use agent_client_protocol::schema::SessionModeId;
use anyhow::Context;
use std::path::Path;
use crate::provider::ToolSpec;
use crate::qwen3;
use crate::session::MODE_PLAN;
const DEFAULT_PROMPT: &str = "\
You are helexa-acp, a coding assistant working inside an editor.
Working directory: {cwd}
Use the tools described below whenever the user's request involves
looking at or modifying files, or running commands. Do not ask the
user to paste file contents you could read yourself. All file paths
must be absolute. Writes and shell commands may prompt the user for
permission depending on the session mode.
Be concise; the user is reading your output in an editor pane.";
/// Build the system prompt for a session.
///
/// - `cwd`: session working directory (substituted for `{cwd}` in
/// the preamble — both the default and any user-supplied template).
/// - `override_path`: path to a user-supplied template, already
/// resolved by [`crate::config::Config`]. The `# Tools` block is
/// appended *after* the user's template so a custom preamble
/// still gets the tool descriptions the model needs.
/// - `tools`: the tools to advertise. Empty list → no `# Tools`
/// block is appended at all.
/// - `mode`: current session mode. When the mode is [`MODE_PLAN`]
/// a plan-mode addendum describing the restrictions and the
/// completion menu is appended *after* the `# Tools` block so it
/// is the last thing the model reads before user input.
/// - `plan_dir`: resolved plan directory for the cwd. Only consulted
/// when `mode == MODE_PLAN`. `None` means the plan directory could
/// not be resolved (no `HOME` / `XDG_DATA_HOME`) — the addendum
/// still renders but with a placeholder so the model knows to
/// surface the error to the user rather than guess a path.
pub fn build_system_prompt(
cwd: &Path,
override_path: Option<&Path>,
tools: &[ToolSpec],
mode: &SessionModeId,
plan_dir: Option<&Path>,
) -> anyhow::Result<String> {
let template = match override_path {
Some(path) => std::fs::read_to_string(path)
.with_context(|| format!("read system prompt from {}", path.display()))?,
None => DEFAULT_PROMPT.to_string(),
};
let mut prompt = template.replace("{cwd}", &cwd.display().to_string());
prompt.push_str(&qwen3::render_tool_block(tools));
if mode.0.as_ref() == MODE_PLAN {
prompt.push_str(&render_plan_mode_block(plan_dir));
}
Ok(prompt)
}
/// Plan-mode instruction block. Tells the model:
///
/// 1. Where it may write — only inside `plan_dir`.
/// 2. What it may *not* do — bash is disabled; writes outside
/// `plan_dir` are refused by the runtime.
/// 3. How to finish — emit the 3-option menu so the user can
/// switch modes and either kick off implementation (with or
/// without permission prompts) or keep iterating on the plan.
fn render_plan_mode_block(plan_dir: Option<&Path>) -> String {
let plan_path = plan_dir
.map(|p| p.display().to_string())
.unwrap_or_else(|| "<plan directory could not be resolved — tell the user>".to_string());
format!(
"\n\n# Plan mode\n\
\n\
You are in **plan mode**. Your task is to draft a written\n\
implementation plan for the user; you must NOT modify any\n\
project files or run shell commands.\n\
\n\
Rules in plan mode:\n\
\n\
- `read_file` and `list_dir` are unrestricted — use them to\n\
explore the codebase as needed.\n\
- `write_file` and `edit_file` are allowed ONLY under the\n\
plan directory: `{plan_path}`. The runtime will refuse any\n\
write outside it.\n\
- `bash` is disabled. Do not call it.\n\
\n\
Write the plan as one or more Markdown files under\n\
`{plan_path}`. Use descriptive filenames\n\
(`01-overview.md`, `02-data-model.md`, etc.). It is fine to\n\
iterate — overwrite the file when you refine a section.\n\
\n\
When the plan is complete, do NOT begin implementation.\n\
Instead, end your turn with this menu, verbatim, so the\n\
user can choose how to proceed:\n\
\n\
---\n\
**Plan complete.** To proceed, switch the session mode in\n\
the agent dropdown and send a follow-up message:\n\
\n\
1. **Bypass Permissions** — implement the plan now, skipping\n\
per-tool permission prompts.\n\
2. **Default** — implement the plan now, prompting before\n\
each write or shell command.\n\
3. **Plan** (stay here) — refine the plan; reply with the\n\
change you want and I will revise it.\n\
---\n"
)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::session::{MODE_DEFAULT, MODE_PLAN};
use std::io::Write;
fn default_mode() -> SessionModeId {
SessionModeId::new(MODE_DEFAULT)
}
fn plan_mode() -> SessionModeId {
SessionModeId::new(MODE_PLAN)
}
#[test]
fn default_prompt_substitutes_cwd() {
let prompt =
build_system_prompt(Path::new("/home/me/proj"), None, &[], &default_mode(), None)
.unwrap();
assert!(
prompt.contains("/home/me/proj"),
"cwd not interpolated: {prompt}"
);
assert!(prompt.contains("helexa-acp"));
assert!(
!prompt.contains("{cwd}"),
"left-over placeholder in default prompt"
);
// With no tools, the # Tools block is absent.
assert!(!prompt.contains("# Tools"));
// Default mode does not get the plan-mode addendum.
assert!(!prompt.contains("# Plan mode"));
}
#[test]
fn tools_are_appended_in_hermes_format() {
let spec = ToolSpec {
name: "read_file".into(),
description: "Read a file.".into(),
parameters: serde_json::json!({"type":"object","properties":{}, "required":[]}),
};
let prompt =
build_system_prompt(Path::new("/x"), None, &[spec], &default_mode(), None).unwrap();
assert!(prompt.contains("# Tools"));
assert!(prompt.contains("<tools>"));
assert!(prompt.contains("\"name\":\"read_file\""));
assert!(prompt.contains("<tool_call>"));
}
#[test]
fn override_path_is_read_and_templated() {
let mut tmp = tempfile_in_target("prompt.txt");
tmp.write_all(b"custom prompt for {cwd} only").unwrap();
tmp.flush().unwrap();
let path = tmp.path().to_path_buf();
drop(tmp);
let prompt = build_system_prompt(
Path::new("/etc"),
Some(path.as_path()),
&[],
&default_mode(),
None,
)
.expect("read override");
assert_eq!(prompt, "custom prompt for /etc only");
let _ = std::fs::remove_file(&path);
}
#[test]
fn missing_override_path_errors() {
let err = build_system_prompt(
Path::new("/tmp"),
Some(Path::new("/definitely/not/a/real/path")),
&[],
&default_mode(),
None,
)
.unwrap_err();
assert!(format!("{err:#}").contains("read system prompt"));
}
#[test]
fn plan_mode_addendum_includes_plan_dir_and_menu() {
let plan_dir = Path::new("/home/me/.local/share/helexa-acp/plans/proj-deadbeef");
let prompt = build_system_prompt(
Path::new("/home/me/proj"),
None,
&[],
&plan_mode(),
Some(plan_dir),
)
.unwrap();
assert!(prompt.contains("# Plan mode"));
assert!(
prompt.contains(plan_dir.to_str().unwrap()),
"plan dir not interpolated: {prompt}"
);
// The 3-option menu must be present so the model emits it verbatim.
assert!(prompt.contains("Bypass Permissions"));
assert!(prompt.contains("**Default**"));
assert!(prompt.contains("3. **Plan**"));
// Bash disabled instruction must be present.
assert!(prompt.contains("`bash` is disabled"));
}
#[test]
fn plan_mode_addendum_handles_unresolved_plan_dir() {
let prompt =
build_system_prompt(Path::new("/home/me/proj"), None, &[], &plan_mode(), None).unwrap();
assert!(prompt.contains("# Plan mode"));
assert!(prompt.contains("could not be resolved"));
}
/// Tiny temp-file helper that doesn't pull in the `tempfile` crate.
/// Writes under `target/` so it's cleaned up by `cargo clean`.
fn tempfile_in_target(name: &str) -> TempHandle {
let base = std::env::var("CARGO_TARGET_TMPDIR")
.ok()
.map(std::path::PathBuf::from)
.unwrap_or_else(std::env::temp_dir);
let _ = std::fs::create_dir_all(&base);
let pid = std::process::id();
let path = base.join(format!("helexa-acp-{pid}-{name}"));
let file = std::fs::File::create(&path).expect("create temp file");
TempHandle { file, path }
}
struct TempHandle {
file: std::fs::File,
path: std::path::PathBuf,
}
impl TempHandle {
fn path(&self) -> &Path {
&self.path
}
}
impl Write for TempHandle {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.file.write(buf)
}
fn flush(&mut self) -> std::io::Result<()> {
self.file.flush()
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,230 +0,0 @@
//! Provider trait — the seam between the ACP-side agent loop and
//! whatever wire protocol an endpoint actually speaks.
//!
//! Every concrete provider (OpenAI chat completions, OpenAI Responses,
//! Anthropic /v1/messages, Ollama native, …) implements
//! [`Provider`]. The agent constructs a [`CompletionRequest`] using
//! provider-agnostic types and consumes a stream of
//! [`CompletionEvent`]s — neither end knows which wire format is on
//! the other side of the trait.
//!
//! Day-1 provider: [`openai_chat::OpenAIChatProvider`]. Day-N
//! providers slot in without touching `agent.rs`.
use async_trait::async_trait;
use futures::stream::BoxStream;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio_util::sync::CancellationToken;
pub mod anthropic_messages;
pub mod openai_chat;
pub mod openai_responses;
/// Provider-agnostic LLM endpoint. Implementations translate between
/// [`CompletionRequest`] / [`CompletionEvent`] and whatever wire
/// format their endpoint speaks.
#[async_trait]
pub trait Provider: Send + Sync {
/// Endpoint name as configured by the user (e.g. `"helexa"`,
/// `"openrouter"`). Used in logs and in the `endpoint:model`
/// selector.
fn name(&self) -> &str;
/// List models available at this endpoint. Used to build the
/// model-picker dropdown in editor clients (Stage 4). Should
/// return quickly (cache if necessary).
#[allow(dead_code)]
async fn list_models(&self) -> anyhow::Result<Vec<ModelInfo>>;
/// Run a chat completion. Returns a stream of provider-agnostic
/// events. The stream stops when the upstream finishes, when
/// `cancel` is fired, or when the stream is dropped.
async fn complete(
&self,
request: CompletionRequest,
cancel: CancellationToken,
) -> anyhow::Result<BoxStream<'static, anyhow::Result<CompletionEvent>>>;
}
/// One model exposed by a provider. Constructed by `list_models` —
/// Stage 4 is when the agent loop starts consuming it for the
/// model-picker dropdown.
#[allow(dead_code)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
pub id: String,
/// Human-friendly name, if the endpoint exposes one. Otherwise
/// `id` is used as the display name.
#[serde(default)]
pub display_name: Option<String>,
}
/// Inputs to a completion. Provider-agnostic — concrete providers
/// translate this into their wire format.
#[derive(Debug, Clone)]
pub struct CompletionRequest {
/// Endpoint-local model id (without the `endpoint:` prefix).
pub model: String,
pub messages: Vec<Message>,
/// Tools the model is allowed to call. Empty list means no tool
/// support advertised.
pub tools: Vec<ToolSpec>,
pub temperature: Option<f64>,
pub top_p: Option<f64>,
pub max_tokens: Option<u64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
pub role: Role,
pub content: MessageContent,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Role {
System,
User,
Assistant,
/// Tool result message. Provider impls turn this into whatever
/// shape the upstream wire format wants (OpenAI uses
/// `role: "tool"` + `tool_call_id`; Anthropic uses content blocks).
/// Stage 3 (tools) constructs this; Stage 2 never does.
Tool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MessageContent {
/// Plain text turn (system / user / assistant). Struct variant
/// rather than newtype so the persisted JSON has an explicit
/// `text` field — that lets us use internal tagging on the
/// enum, which is incompatible with newtype-of-primitive
/// variants.
Text { text: String },
/// Mixed text + image user turn. Stage 5 introduces this when
/// Zed sends an `ImageContent` block alongside the user's prompt.
/// Providers that don't support vision should down-convert by
/// dropping image parts and concatenating text parts.
MultiPart { parts: Vec<MessagePart> },
/// Assistant turn that called one or more tools. Stage 3 starts
/// constructing this when the provider stream yields a
/// `ToolCallStart` / `ToolCallArgsDelta` sequence.
ToolCalls {
/// Optional text the assistant said alongside the tool calls.
text: Option<String>,
calls: Vec<ToolCall>,
},
/// Tool result. `tool_call_id` matches the assistant's call id.
/// Stage 3 constructs this after the tool runner finishes.
ToolResult {
tool_call_id: String,
content: String,
},
}
/// One part of a [`MessageContent::MultiPart`] message.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MessagePart {
Text { text: String },
Image(ImageData),
}
/// Inline image attachment. `data` is base64-encoded raw image
/// bytes; the encoder constructs an `image_url` data URI from it
/// at request time. `uri` carries any pointer the client supplied
/// (e.g. `file:///tmp/x.png`) — we keep it on the message for
/// debugging / future providers but the OpenAI encoder ignores it
/// when `data` is present (data wins, since it round-trips through
/// every wire format).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageData {
pub mime_type: String,
/// Base64-encoded image bytes (no `data:` prefix, no padding
/// stripped — exactly what `ImageContent.data` carried).
pub data: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub uri: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
/// Provider-assigned id that ties the call to its result. The
/// Qwen3 wire format we use today doesn't carry this on the
/// model side (calls and results are matched positionally inside
/// a turn), so the field looks unused in the prod build — but it
/// flows through to `MessageContent::ToolResult.tool_call_id` for
/// history bookkeeping and a future strict-OpenAI backend will
/// consume it directly.
#[allow(dead_code)]
pub id: String,
pub name: String,
/// JSON-encoded arguments. Kept as a string because providers
/// stream argument bytes incrementally and only validate at the
/// end; the agent decodes once the call is complete.
pub arguments: String,
}
#[derive(Debug, Clone)]
pub struct ToolSpec {
pub name: String,
pub description: String,
/// JSON Schema of the arguments object.
pub parameters: Value,
}
/// Events emitted by a provider during a streaming completion.
#[derive(Debug, Clone)]
pub enum CompletionEvent {
/// Incremental visible text from the assistant.
TextDelta(String),
/// Incremental "reasoning" / thought text, if the model emits one
/// (e.g. Qwen3 with `<think>` tags surfaced as a separate stream,
/// or OpenAI reasoning models).
ReasoningDelta(String),
/// A new tool call has started. Stage 2 ignores the payload; the
/// agent loop in Stage 3 reads `index` to correlate with
/// [`Self::ToolCallArgsDelta`], `id` for the eventual tool-result
/// turn, and `name` to dispatch the runner.
#[allow(dead_code)]
ToolCallStart {
index: usize,
id: String,
name: String,
},
/// More argument bytes for a tool call already announced via
/// [`Self::ToolCallStart`]. Stage 2 ignores; Stage 3 accumulates
/// the bytes by `index` until the call's arguments are complete.
#[allow(dead_code)]
ToolCallArgsDelta { index: usize, args_delta: String },
/// A `<tool_call>` block whose JSON couldn't be parsed even with
/// the qwen3 module's repair attempts. The agent surfaces this
/// as a Failed `SessionUpdate::ToolCall` card with the raw body
/// visible (so the editor renders structured failure UI rather
/// than dumping the body inline in the message pane), and feeds
/// a synthetic tool-error message back into history so the
/// model can self-correct on the next round.
MalformedToolCall { raw: String },
/// Stream finished. Carries the upstream `finish_reason` if it
/// gave one (`"stop"`, `"length"`, `"tool_calls"`, …).
Finish { reason: Option<String> },
/// Final usage stats, if the provider supplied them. Stage 2
/// matches the variant to drop it; Stage 6b (token metrics) is
/// when the payload starts being read.
#[allow(dead_code)]
Usage(UsageStats),
}
/// Token accounting reported by the provider at the end of a stream.
/// Stage 2 doesn't surface usage anywhere — the stable `PromptResponse`
/// has no usage field, and the unstable variant is gated. Stage 6b
/// turns these on with Prometheus metrics.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, Default)]
pub struct UsageStats {
pub prompt_tokens: u64,
pub completion_tokens: u64,
pub total_tokens: u64,
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,987 +0,0 @@
//! OpenAI Responses API (`POST /v1/responses`) provider.
//!
//! Mirror image of [`super::openai_chat`]: same `Provider` trait
//! impl, same back-pressured SSE decoder, but speaking OpenAI's
//! newer Responses surface instead of chat completions.
//!
//! Differences from the chat provider, all contained in this file:
//!
//! - **Request encoding**: history flattens into an `input` array
//! of typed items (`message`, `function_call`, `function_call_output`)
//! plus a top-level `instructions` field for the system prompt.
//! Multi-part user content stays in the same `[{type:"input_text"},
//! {type:"input_image"}]` shape neuron's `request_to_chat` already
//! accepts.
//! - **Streaming decoder**: events are named (`response.created`,
//! `response.output_text.delta`, `response.completed`, …) carried
//! on the SSE `event:` line. The chat path's `[DONE]` terminator
//! doesn't apply; the stream ends after `response.completed`.
//! - **Tool calls** plumb through the `response.output_item.added`
//! (item type `function_call`) → `response.function_call_arguments.delta`
//! → `response.function_call_arguments.done` event sequence. The
//! neuron candle harness doesn't synthesize these yet (tracked as
//! issue #6), but the decoder is wired so the day the upstream
//! does, downstream `CompletionEvent::ToolCall*` plumbing just
//! works.
//!
//! Tool-name handling: the model knows its tool descriptions via
//! the [`crate::qwen3`] system-prompt block exactly the way the chat
//! provider does. We don't echo them in the request body because
//! neuron currently ignores `tools` on /v1/responses (same as on
//! /v1/chat/completions). Once neuron honours request-side tool
//! definitions, both providers add them in the same place.
use async_trait::async_trait;
use eventsource_stream::Eventsource;
use futures::{Stream, StreamExt, stream::BoxStream};
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use std::collections::HashMap;
use tokio_util::sync::CancellationToken;
use super::{
CompletionEvent, CompletionRequest, Message, MessageContent, MessagePart, ModelInfo, Provider,
Role, UsageStats,
};
use crate::config::EndpointConfig;
pub struct OpenAIResponsesProvider {
endpoint: EndpointConfig,
#[allow(dead_code)] // Read in `complete()`'s HTTP path; tests don't stand up a server.
api_key: Option<String>,
#[allow(dead_code)]
http: reqwest::Client,
}
impl OpenAIResponsesProvider {
pub fn new(endpoint: EndpointConfig) -> anyhow::Result<Self> {
let api_key = endpoint.resolve_api_key()?;
let http = reqwest::Client::builder()
// Same generous timeout as the chat provider: cortex may
// need to cold-load a model before serving the first
// chunk, which can be tens of seconds. Cancellation
// handles early termination, not timeout.
.timeout(std::time::Duration::from_secs(600))
.build()?;
Ok(Self {
endpoint,
api_key,
http,
})
}
}
#[async_trait]
impl Provider for OpenAIResponsesProvider {
fn name(&self) -> &str {
&self.endpoint.name
}
async fn list_models(&self) -> anyhow::Result<Vec<ModelInfo>> {
let mut req = self.http.get(self.endpoint.models_url());
if let Some(key) = &self.api_key {
req = req.bearer_auth(key);
}
let resp = req
.send()
.await
.map_err(|e| anyhow::anyhow!("{} list_models: {e}", self.endpoint.name))?;
let status = resp.status();
if !status.is_success() {
let body = resp.text().await.unwrap_or_default();
anyhow::bail!(
"{} list_models returned {}: {}",
self.endpoint.name,
status,
body
);
}
let body: WireModelsResponse = resp.json().await?;
Ok(body
.data
.into_iter()
.map(|m| ModelInfo {
id: m.id,
display_name: None,
})
.collect())
}
async fn complete(
&self,
request: CompletionRequest,
cancel: CancellationToken,
) -> anyhow::Result<BoxStream<'static, anyhow::Result<CompletionEvent>>> {
let body = encode_request(&request);
tracing::debug!(
endpoint = %self.endpoint.name,
url = %self.endpoint.responses_url(),
body = %serde_json::to_string(&body).unwrap_or_else(|_| "<unserializable>".into()),
"POST /responses"
);
let mut req = self.http.post(self.endpoint.responses_url()).json(&body);
if let Some(key) = &self.api_key {
req = req.bearer_auth(key);
}
let resp = req
.send()
.await
.map_err(|e| anyhow::anyhow!("{} responses send: {e}", self.endpoint.name))?;
let status = resp.status();
if !status.is_success() {
let body = resp.text().await.unwrap_or_default();
anyhow::bail!(
"{} responses returned {}: {}",
self.endpoint.name,
status,
body
);
}
let sse = resp.bytes_stream().eventsource();
let stream = decode_stream(sse, cancel);
Ok(Box::pin(stream))
}
}
// ── Request encoding ─────────────────────────────────────────────────
fn encode_request(req: &CompletionRequest) -> Value {
// Pull the system messages out of history into a single
// `instructions` string — the Responses API expects them there,
// not inline as an `input` item. Multiple system messages
// concatenate with blank lines so we don't lose ordering.
let mut instructions: Vec<String> = Vec::new();
let mut input_items: Vec<Value> = Vec::new();
for msg in &req.messages {
if msg.role == Role::System
&& let MessageContent::Text { text } = &msg.content
{
instructions.push(text.clone());
continue;
}
if let Some(item) = encode_message_as_input_item(msg) {
input_items.push(item);
}
}
let mut body = json!({
"model": req.model,
"input": input_items,
"stream": true,
});
if let Value::Object(map) = &mut body {
if !instructions.is_empty() {
map.insert(
"instructions".into(),
Value::String(instructions.join("\n\n")),
);
}
if let Some(t) = req.temperature {
map.insert("temperature".into(), json!(t));
}
if let Some(p) = req.top_p {
map.insert("top_p".into(), json!(p));
}
if let Some(m) = req.max_tokens {
// Responses calls it `max_output_tokens`; preserve the
// semantic (response cap) when we translate.
map.insert("max_output_tokens".into(), json!(m));
}
}
body
}
fn encode_message_as_input_item(msg: &Message) -> Option<Value> {
match (msg.role, &msg.content) {
(Role::System, _) => None, // handled out-of-band as `instructions`
(Role::User, MessageContent::Text { text }) => Some(json!({
"type": "message",
"role": "user",
"content": text,
})),
(Role::User, MessageContent::MultiPart { parts }) => Some(json!({
"type": "message",
"role": "user",
"content": encode_user_parts(parts),
})),
(Role::Assistant, MessageContent::Text { text }) => Some(json!({
"type": "message",
"role": "assistant",
"content": [{
"type": "output_text",
"text": text,
"annotations": [],
}],
})),
(Role::Assistant, MessageContent::ToolCalls { text, calls }) => {
// Assistant turns that called tools become a sequence of
// items: an optional `message` (any prose alongside the
// call) followed by one `function_call` per call. Mirrors
// OpenAI Responses' "each item is one structural slot"
// shape.
//
// We can't return multiple items from one call site, so
// we encode this by side-stuffing additional items into a
// single composite value and have the caller flatten —
// but that complicates the API. Easier: build the array
// ourselves in the caller path. For now, emit just the
// function_calls (the assistant's prose lives in the next
// turn's chat history anyway because the model isn't
// looking back at its own previous narration). If the
// text is non-empty AND we have calls, we lose the text;
// qwen3 rarely emits prose alongside tool calls so this
// is a deliberate simplification — revisit if it bites.
let _ = text;
// Take the first call only for the moment; multi-call
// turns would need the caller-flattening above.
let call = calls.first()?;
Some(json!({
"type": "function_call",
"call_id": call.id,
"name": call.name,
"arguments": call.arguments,
}))
}
(
Role::Tool,
MessageContent::ToolResult {
tool_call_id,
content,
},
) => Some(json!({
"type": "function_call_output",
"call_id": tool_call_id,
"output": content,
})),
(role, content) => {
tracing::warn!(
?role,
?content,
"openai_responses: unexpected (role, content) shape"
);
None
}
}
}
fn encode_user_parts(parts: &[MessagePart]) -> Value {
let items: Vec<Value> = parts
.iter()
.map(|p| match p {
MessagePart::Text { text } => json!({"type": "input_text", "text": text}),
MessagePart::Image(img) => json!({
"type": "input_image",
"image_url": format!("data:{};base64,{}", img.mime_type, img.data),
}),
})
.collect();
Value::Array(items)
}
// ── Wire types ──────────────────────────────────────────────────────
#[allow(dead_code)] // fields read only when list_models runs against a real endpoint
#[derive(Debug, Deserialize)]
struct WireModelsResponse {
data: Vec<WireModelObject>,
}
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct WireModelObject {
id: String,
}
// SSE event payload shapes. We only model the fields we care about;
// `#[serde(default)]` + `Option` everywhere else lets the upstream
// add optional fields without breaking deserialise.
#[derive(Debug, Deserialize, Serialize)]
struct OutputItemAddedEvent {
#[serde(default)]
output_index: u32,
item: OutputItem,
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum OutputItem {
Message {
#[serde(default)]
id: Option<String>,
},
FunctionCall {
#[serde(default)]
id: Option<String>,
#[serde(default)]
call_id: Option<String>,
#[serde(default)]
name: Option<String>,
/// Some upstreams populate `arguments` already on the
/// `output_item.added` event for a fully-buffered tool call
/// (i.e. when the model finalised the call before the SSE
/// flush). Capture it so we can emit a single args delta.
#[serde(default)]
arguments: Option<String>,
},
/// `reasoning`, `web_search_call`, etc. We capture-and-ignore
/// any item we don't model; the decoder still emits the
/// outer events correctly.
#[serde(other)]
Unknown,
}
#[derive(Debug, Deserialize, Serialize)]
struct OutputTextDeltaEvent {
#[serde(default)]
item_id: Option<String>,
#[serde(default)]
output_index: u32,
#[serde(default)]
delta: String,
}
#[derive(Debug, Deserialize, Serialize)]
struct FunctionCallArgumentsDeltaEvent {
#[serde(default)]
item_id: Option<String>,
#[serde(default)]
output_index: u32,
#[serde(default)]
delta: String,
}
#[derive(Debug, Deserialize, Serialize)]
struct ResponseCompletedEvent {
response: ResponseShell,
}
#[derive(Debug, Deserialize, Serialize)]
struct ResponseShell {
#[serde(default)]
status: Option<String>,
#[serde(default)]
usage: Option<WireUsage>,
}
#[derive(Debug, Deserialize, Serialize)]
struct WireUsage {
#[serde(default)]
input_tokens: u64,
#[serde(default)]
output_tokens: u64,
#[serde(default)]
total_tokens: u64,
}
// ── Streaming decoder ───────────────────────────────────────────────
/// Translate the named-event Responses SSE into the provider-agnostic
/// [`CompletionEvent`] stream the agent loop expects. The decoder
/// holds per-stream state — output_index → tool-call-index plus
/// the next available tool-call slot — so it can fire
/// `ToolCallStart` exactly once per item.
fn decode_stream<S>(
sse: S,
cancel: CancellationToken,
) -> impl Stream<Item = anyhow::Result<CompletionEvent>>
where
S: Stream<
Item = Result<
eventsource_stream::Event,
eventsource_stream::EventStreamError<reqwest::Error>,
>,
> + Send
+ 'static,
{
async_stream::stream! {
let mut sse = Box::pin(sse);
// Maps an output_index that's a function_call to the tool-call
// slot we hand downstream. Lets us correlate later
// `function_call_arguments.delta` events back to the index
// we already announced on `output_item.added`.
let mut tool_index_by_output: HashMap<u32, usize> = HashMap::new();
let mut next_tool_index: usize = 0;
loop {
tokio::select! {
biased;
_ = cancel.cancelled() => {
tracing::debug!("openai_responses: cancellation requested, ending stream");
break;
}
next = sse.next() => {
let Some(event) = next else { break };
let event = match event {
Ok(e) => e,
Err(e) => {
yield Err(anyhow::anyhow!("SSE transport: {e}"));
break;
}
};
// Event name lives on `event.event`; data is JSON.
let event_name = event.event.as_str();
let data = event.data.as_str();
match event_name {
"response.output_text.delta" => {
match serde_json::from_str::<OutputTextDeltaEvent>(data) {
Ok(d) if !d.delta.is_empty() => {
yield Ok(CompletionEvent::TextDelta(d.delta));
}
Ok(_) => {}
Err(e) => {
tracing::warn!(
error = %e,
raw = %data,
"openai_responses: failed to parse output_text.delta; skipping"
);
}
}
}
"response.output_item.added" => {
match serde_json::from_str::<OutputItemAddedEvent>(data) {
Ok(ev) => {
if let OutputItem::FunctionCall {
id,
call_id,
name,
arguments,
} = ev.item
{
let idx = next_tool_index;
next_tool_index += 1;
tool_index_by_output.insert(ev.output_index, idx);
// Prefer the user-facing
// `call_id` (what gets paired
// with tool results) over the
// internal item `id` when
// both are present. Falls
// back to a synthetic id so
// history bookkeeping never
// breaks.
let final_id = call_id
.or(id)
.unwrap_or_else(|| format!("call_{idx}"));
let final_name = name.unwrap_or_default();
yield Ok(CompletionEvent::ToolCallStart {
index: idx,
id: final_id,
name: final_name,
});
// Some upstreams attach the
// fully-buffered arguments on
// the `output_item.added`
// event itself (rare; happens
// when the model finalised
// before the SSE flush).
// Emit as a single args
// delta if present.
if let Some(args) = arguments
&& !args.is_empty()
{
yield Ok(CompletionEvent::ToolCallArgsDelta {
index: idx,
args_delta: args,
});
}
}
}
Err(e) => {
tracing::warn!(
error = %e,
raw = %data,
"openai_responses: failed to parse output_item.added; skipping"
);
}
}
}
"response.function_call_arguments.delta" => {
match serde_json::from_str::<FunctionCallArgumentsDeltaEvent>(data) {
Ok(ev) => {
let Some(&idx) = tool_index_by_output.get(&ev.output_index)
else {
// Args delta for an item we
// never saw an `output_item.added`
// for. Could happen if the
// upstream reordered events;
// log + skip.
tracing::warn!(
output_index = ev.output_index,
"openai_responses: function_call_arguments.delta for unknown output_index"
);
continue;
};
if !ev.delta.is_empty() {
yield Ok(CompletionEvent::ToolCallArgsDelta {
index: idx,
args_delta: ev.delta,
});
}
}
Err(e) => {
tracing::warn!(
error = %e,
raw = %data,
"openai_responses: failed to parse function_call_arguments.delta; skipping"
);
}
}
}
"response.completed" => {
// Final event. Pull usage + status off
// the response shell. Status maps:
// "completed" → no special handling
// (caller treats as EndTurn),
// "incomplete" → length stop.
let (reason, usage) =
match serde_json::from_str::<ResponseCompletedEvent>(data) {
Ok(ev) => {
let reason = match ev.response.status.as_deref() {
Some("incomplete") => Some("length".to_string()),
_ => Some("stop".to_string()),
};
let usage = ev.response.usage.map(|u| UsageStats {
prompt_tokens: u.input_tokens,
completion_tokens: u.output_tokens,
total_tokens: u.total_tokens,
});
(reason, usage)
}
Err(e) => {
tracing::warn!(
error = %e,
raw = %data,
"openai_responses: failed to parse response.completed; ending stream with EndTurn"
);
(Some("stop".to_string()), None)
}
};
if let Some(u) = usage {
yield Ok(CompletionEvent::Usage(u));
}
yield Ok(CompletionEvent::Finish { reason });
break;
}
// Bookkeeping events we don't need to surface:
// response.created, response.in_progress,
// response.content_part.added/.done,
// response.output_text.done,
// response.output_item.done,
// response.function_call_arguments.done,
// response.reasoning_*. Logged at debug for
// wire-tracing.
other => {
tracing::trace!(
event = other,
"openai_responses: bookkeeping event"
);
}
}
}
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::provider::ToolCall;
use crate::provider::{ImageData, MessagePart};
use futures::stream;
use url::Url;
fn ep() -> EndpointConfig {
EndpointConfig {
name: "test".into(),
base_url: Url::parse("http://localhost:9999/v1").unwrap(),
wire_api: crate::config::WireApi::OpenAiResponses,
default_model: None,
api_key: None,
api_key_env: None,
max_tokens: None,
context_window: None,
}
}
// ── encode_request ──────────────────────────────────────────────
#[test]
fn system_messages_collapse_to_instructions() {
let req = CompletionRequest {
model: "m".into(),
messages: vec![
Message {
role: Role::System,
content: MessageContent::Text {
text: "you are helpful".into(),
},
},
Message {
role: Role::User,
content: MessageContent::Text { text: "hi".into() },
},
],
tools: vec![],
temperature: Some(0.7),
top_p: None,
max_tokens: Some(256),
};
let body = encode_request(&req);
assert_eq!(body["model"], "m");
assert_eq!(body["instructions"], "you are helpful");
assert_eq!(body["stream"], true);
assert_eq!(body["max_output_tokens"], 256);
assert_eq!(body["temperature"], 0.7);
let input = body["input"].as_array().unwrap();
// System message NOT echoed in input — it's only in
// instructions.
assert_eq!(input.len(), 1);
assert_eq!(input[0]["type"], "message");
assert_eq!(input[0]["role"], "user");
assert_eq!(input[0]["content"], "hi");
}
#[test]
fn multiple_system_messages_concatenate() {
let req = CompletionRequest {
model: "m".into(),
messages: vec![
Message {
role: Role::System,
content: MessageContent::Text {
text: "first".into(),
},
},
Message {
role: Role::System,
content: MessageContent::Text {
text: "second".into(),
},
},
Message {
role: Role::User,
content: MessageContent::Text { text: "hi".into() },
},
],
tools: vec![],
temperature: None,
top_p: None,
max_tokens: None,
};
let body = encode_request(&req);
assert_eq!(body["instructions"], "first\n\nsecond");
}
#[test]
fn user_multipart_becomes_input_parts_array() {
let req = CompletionRequest {
model: "vl".into(),
messages: vec![Message {
role: Role::User,
content: MessageContent::MultiPart {
parts: vec![
MessagePart::Text {
text: "what's in this?".into(),
},
MessagePart::Image(ImageData {
mime_type: "image/png".into(),
data: "AAA=".into(),
uri: None,
}),
],
},
}],
tools: vec![],
temperature: None,
top_p: None,
max_tokens: None,
};
let body = encode_request(&req);
let content = &body["input"][0]["content"].as_array().unwrap().clone();
assert_eq!(content.len(), 2);
assert_eq!(content[0]["type"], "input_text");
assert_eq!(content[0]["text"], "what's in this?");
assert_eq!(content[1]["type"], "input_image");
assert_eq!(content[1]["image_url"], "data:image/png;base64,AAA=");
}
#[test]
fn assistant_text_becomes_output_text_content_part() {
let req = CompletionRequest {
model: "m".into(),
messages: vec![
Message {
role: Role::User,
content: MessageContent::Text { text: "hi".into() },
},
Message {
role: Role::Assistant,
content: MessageContent::Text {
text: "hello there".into(),
},
},
Message {
role: Role::User,
content: MessageContent::Text {
text: "more".into(),
},
},
],
tools: vec![],
temperature: None,
top_p: None,
max_tokens: None,
};
let body = encode_request(&req);
let input = body["input"].as_array().unwrap();
assert_eq!(input.len(), 3);
assert_eq!(input[1]["type"], "message");
assert_eq!(input[1]["role"], "assistant");
assert_eq!(input[1]["content"][0]["type"], "output_text");
assert_eq!(input[1]["content"][0]["text"], "hello there");
}
#[test]
fn tool_calls_and_results_round_trip_via_function_call_items() {
let req = CompletionRequest {
model: "m".into(),
messages: vec![
Message {
role: Role::Assistant,
content: MessageContent::ToolCalls {
text: None,
calls: vec![ToolCall {
id: "call_42".into(),
name: "read_file".into(),
arguments: r#"{"path":"/etc/hostname"}"#.into(),
}],
},
},
Message {
role: Role::Tool,
content: MessageContent::ToolResult {
tool_call_id: "call_42".into(),
content: "host".into(),
},
},
],
tools: vec![],
temperature: None,
top_p: None,
max_tokens: None,
};
let body = encode_request(&req);
let input = body["input"].as_array().unwrap();
assert_eq!(input.len(), 2);
assert_eq!(input[0]["type"], "function_call");
assert_eq!(input[0]["call_id"], "call_42");
assert_eq!(input[0]["name"], "read_file");
assert_eq!(input[0]["arguments"], r#"{"path":"/etc/hostname"}"#);
assert_eq!(input[1]["type"], "function_call_output");
assert_eq!(input[1]["call_id"], "call_42");
assert_eq!(input[1]["output"], "host");
}
// ── decode_stream ───────────────────────────────────────────────
fn sse_event(name: &str, data: &str) -> eventsource_stream::Event {
eventsource_stream::Event {
id: String::new(),
retry: None,
event: name.into(),
data: data.into(),
}
}
async fn collect_events(
items: Vec<eventsource_stream::Event>,
) -> Vec<anyhow::Result<CompletionEvent>> {
let sse = stream::iter(
items
.into_iter()
.map(Ok::<_, eventsource_stream::EventStreamError<reqwest::Error>>),
);
let decoded = decode_stream(sse, CancellationToken::new());
decoded.collect().await
}
#[tokio::test]
async fn decodes_text_then_finish() {
let events = collect_events(vec![
sse_event("response.created", "{}"),
sse_event(
"response.output_text.delta",
r#"{"item_id":"msg_1","output_index":0,"delta":"hel"}"#,
),
sse_event(
"response.output_text.delta",
r#"{"item_id":"msg_1","output_index":0,"delta":"lo"}"#,
),
sse_event(
"response.completed",
r#"{"response":{"status":"completed","usage":{"input_tokens":3,"output_tokens":2,"total_tokens":5}}}"#,
),
])
.await;
let events: Vec<CompletionEvent> = events.into_iter().map(|r| r.unwrap()).collect();
let mut iter = events.into_iter();
assert!(matches!(iter.next(), Some(CompletionEvent::TextDelta(t)) if t == "hel"));
assert!(matches!(iter.next(), Some(CompletionEvent::TextDelta(t)) if t == "lo"));
assert!(matches!(iter.next(), Some(CompletionEvent::Usage(u)) if u.total_tokens == 5));
assert!(matches!(
iter.next(),
Some(CompletionEvent::Finish { reason: Some(r) }) if r == "stop"
));
assert!(iter.next().is_none());
}
#[tokio::test]
async fn empty_delta_is_dropped() {
let events = collect_events(vec![
sse_event(
"response.output_text.delta",
r#"{"item_id":"m","output_index":0,"delta":""}"#,
),
sse_event(
"response.completed",
r#"{"response":{"status":"completed"}}"#,
),
])
.await;
let mut completion_events = events.into_iter().map(|r| r.unwrap());
// First event MUST be the Finish — the empty delta dropped.
assert!(matches!(
completion_events.next(),
Some(CompletionEvent::Finish { .. })
));
}
#[tokio::test]
async fn incomplete_status_maps_to_length_finish_reason() {
let events = collect_events(vec![sse_event(
"response.completed",
r#"{"response":{"status":"incomplete"}}"#,
)])
.await;
let events: Vec<CompletionEvent> = events.into_iter().map(|r| r.unwrap()).collect();
assert!(matches!(
events.last(),
Some(CompletionEvent::Finish { reason: Some(r) }) if r == "length"
));
}
#[tokio::test]
async fn function_call_items_emit_toolcall_events() {
let events = collect_events(vec![
sse_event(
"response.output_item.added",
r#"{"output_index":0,"item":{"type":"function_call","id":"item_1","call_id":"call_xyz","name":"read_file"}}"#,
),
sse_event(
"response.function_call_arguments.delta",
r#"{"item_id":"item_1","output_index":0,"delta":"{\"path"}"#,
),
sse_event(
"response.function_call_arguments.delta",
r#"{"item_id":"item_1","output_index":0,"delta":"\":\"/etc/hostname\"}"}"#,
),
sse_event("response.completed", r#"{"response":{"status":"completed"}}"#),
])
.await;
let events: Vec<CompletionEvent> = events.into_iter().map(|r| r.unwrap()).collect();
let mut iter = events.into_iter();
assert!(matches!(
iter.next(),
Some(CompletionEvent::ToolCallStart { index: 0, ref id, ref name })
if id == "call_xyz" && name == "read_file"
));
assert!(matches!(
iter.next(),
Some(CompletionEvent::ToolCallArgsDelta { index: 0, ref args_delta })
if args_delta == r#"{"path"#
));
assert!(matches!(
iter.next(),
Some(CompletionEvent::ToolCallArgsDelta { index: 0, ref args_delta })
if args_delta == r#"":"/etc/hostname"}"#
));
assert!(matches!(iter.next(), Some(CompletionEvent::Finish { .. })));
}
#[tokio::test]
async fn function_call_added_with_inline_arguments_emits_single_args_delta() {
// Some upstreams (rare) include the fully-buffered arguments
// on the `output_item.added` event when the model finalised
// the call before SSE flush. Verify both ToolCallStart and a
// single args delta fire.
let events = collect_events(vec![
sse_event(
"response.output_item.added",
r#"{"output_index":0,"item":{"type":"function_call","call_id":"call_a","name":"f","arguments":"{\"x\":1}"}}"#,
),
sse_event("response.completed", r#"{"response":{"status":"completed"}}"#),
])
.await;
let events: Vec<CompletionEvent> = events.into_iter().map(|r| r.unwrap()).collect();
let mut iter = events.into_iter();
assert!(matches!(
iter.next(),
Some(CompletionEvent::ToolCallStart { .. })
));
assert!(matches!(
iter.next(),
Some(CompletionEvent::ToolCallArgsDelta { index: 0, ref args_delta })
if args_delta == r#"{"x":1}"#
));
assert!(matches!(iter.next(), Some(CompletionEvent::Finish { .. })));
}
#[tokio::test]
async fn cancellation_ends_stream_promptly() {
// Hand the decoder an empty stream + a triggered cancellation
// token; it should terminate without yielding anything.
let sse = stream::iter(Vec::<
Result<eventsource_stream::Event, eventsource_stream::EventStreamError<reqwest::Error>>,
>::new());
let cancel = CancellationToken::new();
cancel.cancel();
let decoded = decode_stream(sse, cancel);
let events: Vec<_> = decoded.collect().await;
assert!(events.is_empty());
}
#[tokio::test]
async fn malformed_event_payload_is_skipped() {
let events = collect_events(vec![
sse_event("response.output_text.delta", "{not valid json"),
sse_event(
"response.output_text.delta",
r#"{"item_id":"m","output_index":0,"delta":"ok"}"#,
),
sse_event(
"response.completed",
r#"{"response":{"status":"completed"}}"#,
),
])
.await;
let events: Vec<CompletionEvent> = events.into_iter().map(|r| r.unwrap()).collect();
// First text delta dropped; second one fires.
assert!(
events
.iter()
.any(|e| matches!(e, CompletionEvent::TextDelta(t) if t == "ok"))
);
// No errors yielded (parse failures are warn-and-skip).
assert!(
events
.iter()
.all(|e| !matches!(e, CompletionEvent::Finish { reason: None }))
);
}
#[test]
fn provider_construction_is_cheap() {
let _ = OpenAIResponsesProvider::new(ep()).unwrap();
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,188 +0,0 @@
//! Per-session state for the ACP agent loop.
//!
//! Concurrency:
//!
//! - [`SessionStore`] is an `Arc<RwLock<HashMap<SessionId, …>>>`. The map
//! itself is read-mostly: it changes only on `session/new` and never
//! shrinks during Stage 2, so an `RwLock` keeps concurrent reads
//! contention-free.
//! - Each session is wrapped in its own `Arc<Mutex<SessionState>>`. Holding
//! one session's lock doesn't block requests against any other session,
//! which matters once a client opens multiple sessions in parallel.
//!
//! All operations hold a lock only long enough to copy out (or mutate) the
//! state they need — never across an `await` that drives the upstream
//! provider stream.
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use agent_client_protocol::schema::{SessionId, SessionModeId};
use tokio::sync::{Mutex, RwLock};
use tokio_util::sync::CancellationToken;
use crate::provider::Message;
/// Mode id advertised as the gated default. Writes / bash prompt for
/// permission via `session/request_permission`.
pub const MODE_DEFAULT: &str = "default";
/// Mode id advertised as "auto-allow everything". Matches the
/// favorite name (`bypassPermissions`) Zed clients tend to reference.
pub const MODE_BYPASS: &str = "bypassPermissions";
/// Mode id for read-and-plan-only operation. The model may read files
/// and list directories freely, may write *only* into the per-project
/// plan directory under `$XDG_DATA_HOME/helexa-acp/plans/<project-id>/`,
/// and cannot run shell commands. Designed for "draft the
/// implementation plan, then I'll review and let you execute" flows.
pub const MODE_PLAN: &str = "plan";
/// State carried for a single ACP session.
///
/// Mutated under `Mutex<SessionState>`; never share a clone across
/// tasks expecting to see the same `cancel` token — clone the token
/// explicitly when handing it to the streaming task.
#[derive(Debug)]
pub struct SessionState {
/// Conversation history in chronological order (user / assistant
/// turns). The system prompt is *not* stored here — it's built
/// fresh per request so any cwd / config changes take effect.
pub history: Vec<Message>,
/// Working directory the client opened the session against. Used
/// by [`crate::prompt::build_system_prompt`] and (Stage 3) by
/// filesystem tools.
pub cwd: PathBuf,
/// Currently-selected model id. Format is either a bare model id
/// (resolved against the default endpoint) or `endpoint:model`.
/// Mutated by `session/set_model` in Stage 4; Stage 2 sets it
/// once at session creation and never changes it.
pub model_id: String,
/// Cancellation handle for the in-flight prompt, if any. A fresh
/// token is installed at the start of every `session/prompt`
/// request; `session/cancel` fires this one. Between prompts the
/// token is "spent" — firing it does nothing — which is fine,
/// `session/cancel` is a no-op when there's nothing to cancel.
pub cancel: CancellationToken,
/// Permission gating mode. Stage 3 advertises two ids in
/// `NewSessionResponse.modes`: [`MODE_DEFAULT`] (writes / bash
/// prompt the user) and [`MODE_BYPASS`] (auto-allow). Mutated by
/// `session/set_mode`.
pub mode_id: SessionModeId,
}
impl SessionState {
pub fn new(cwd: PathBuf, model_id: String) -> Self {
Self {
history: Vec::new(),
cwd,
model_id,
cancel: CancellationToken::new(),
mode_id: SessionModeId::new(MODE_DEFAULT),
}
}
}
/// Concurrent map of live sessions.
///
/// Cloning is cheap (`Arc` bump). Pass clones into every handler that
/// needs session access; never hold a clone across an `.await` that
/// could outlive the request.
pub type SessionStore = Arc<RwLock<HashMap<SessionId, Arc<Mutex<SessionState>>>>>;
/// Fresh, empty session store.
pub fn new_store() -> SessionStore {
Arc::new(RwLock::new(HashMap::new()))
}
/// Look up a session by id. Returns `None` if no such session is registered.
pub async fn get(store: &SessionStore, id: &SessionId) -> Option<Arc<Mutex<SessionState>>> {
store.read().await.get(id).cloned()
}
/// Register a fresh session. Overwrites any prior entry with the same id
/// (which should never happen — ids are uniquely generated by the agent).
pub async fn insert(store: &SessionStore, id: SessionId, state: SessionState) {
store.write().await.insert(id, Arc::new(Mutex::new(state)));
}
#[cfg(test)]
mod tests {
use super::*;
use crate::provider::{MessageContent, Role};
fn id(s: &str) -> SessionId {
SessionId::new(s)
}
#[tokio::test]
async fn insert_then_get_round_trip() {
let store = new_store();
let state = SessionState::new(PathBuf::from("/tmp"), "m".into());
insert(&store, id("s1"), state).await;
let got = get(&store, &id("s1")).await.expect("session present");
let locked = got.lock().await;
assert_eq!(locked.cwd, PathBuf::from("/tmp"));
assert_eq!(locked.model_id, "m");
assert!(locked.history.is_empty());
}
#[tokio::test]
async fn missing_session_is_none() {
let store = new_store();
assert!(get(&store, &id("nope")).await.is_none());
}
#[tokio::test]
async fn history_is_per_session() {
let store = new_store();
insert(
&store,
id("a"),
SessionState::new(PathBuf::from("/a"), "m".into()),
)
.await;
insert(
&store,
id("b"),
SessionState::new(PathBuf::from("/b"), "m".into()),
)
.await;
// Appending to a's history must not affect b's.
get(&store, &id("a"))
.await
.unwrap()
.lock()
.await
.history
.push(Message {
role: Role::User,
content: MessageContent::Text {
text: "hello".into(),
},
});
assert_eq!(
get(&store, &id("a"))
.await
.unwrap()
.lock()
.await
.history
.len(),
1
);
assert_eq!(
get(&store, &id("b"))
.await
.unwrap()
.lock()
.await
.history
.len(),
0
);
}
}

View File

@@ -1,462 +0,0 @@
//! On-disk session persistence for `session/load` support.
//!
//! Storage layout:
//!
//! ```text
//! $XDG_DATA_HOME/helexa-acp/sessions/{session_id}.json
//! ```
//!
//! (Fallback to `~/.local/share/helexa-acp/sessions/` when
//! `$XDG_DATA_HOME` is unset.) One JSON file per session. Writes
//! happen at the end of every `session/prompt` round through
//! [`save`], using tempfile-plus-rename so a crash mid-write can't
//! corrupt the store. Reads happen on `session/load` via [`load`].
//!
//! No compaction, no rotation: files accumulate until the user
//! cleans them up. That's deliberate — disk is cheap, and the
//! resume-on-restart workflow matters more than tidiness. The
//! [`SESSIONS_DIRNAME`] subdirectory is created lazily on first
//! save so an unprivileged install path never errors at startup.
use std::path::PathBuf;
use std::time::SystemTime;
use agent_client_protocol::schema::SessionId;
use serde::{Deserialize, Serialize};
use crate::provider::Message;
const APP_DIRNAME: &str = "helexa-acp";
const SESSIONS_DIRNAME: &str = "sessions";
const PLANS_DIRNAME: &str = "plans";
/// The shape persisted to disk for one session. Only what we can't
/// rebuild from the running config goes in here: the conversation
/// history, the mode toggle, the model id, and the cwd-at-creation.
///
/// `created_at` / `updated_at` are seconds-since-epoch — cheap to
/// compare, no third-party time crate, and stable across runs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersistedSession {
pub session_id: String,
pub cwd: PathBuf,
pub model_id: String,
pub mode_id: String,
pub history: Vec<Message>,
pub created_at: u64,
pub updated_at: u64,
}
/// Resolve the directory that holds session JSON files. Honors
/// `$XDG_DATA_HOME`; falls back to `~/.local/share/helexa-acp/sessions/`.
/// Returns `None` if neither is resolvable (no `HOME` set — possible
/// in stripped-down container environments).
pub fn sessions_dir() -> Option<PathBuf> {
let base = std::env::var("XDG_DATA_HOME")
.ok()
.filter(|s| !s.is_empty())
.map(PathBuf::from)
.or_else(|| {
std::env::var("HOME")
.ok()
.map(|h| PathBuf::from(h).join(".local").join("share"))
})?;
Some(base.join(APP_DIRNAME).join(SESSIONS_DIRNAME))
}
/// Atomic save into the default sessions directory.
pub fn save(session: &PersistedSession) -> anyhow::Result<()> {
let dir = sessions_dir()
.ok_or_else(|| anyhow::anyhow!("can't resolve XDG_DATA_HOME or HOME for session store"))?;
save_to_dir(&dir, session)
}
/// Load from the default sessions directory.
pub fn load(session_id: &SessionId) -> anyhow::Result<PersistedSession> {
let dir = sessions_dir()
.ok_or_else(|| anyhow::anyhow!("can't resolve XDG_DATA_HOME or HOME for session store"))?;
load_from_dir(&dir, session_id)
}
/// Atomic save into an explicit directory. Writes to
/// `{id}.json.tmp` then renames over `{id}.json`. Creates the
/// target directory if it doesn't exist. Split from [`save`] so
/// unit tests can target a per-test scratch dir without mutating
/// process-global env vars.
pub fn save_to_dir(dir: &std::path::Path, session: &PersistedSession) -> anyhow::Result<()> {
std::fs::create_dir_all(dir).map_err(|e| anyhow::anyhow!("create {}: {e}", dir.display()))?;
let safe = sanitize_id(&session.session_id);
let final_path = dir.join(format!("{safe}.json"));
let tmp_path = dir.join(format!("{safe}.json.tmp"));
let json = serde_json::to_string_pretty(session)?;
std::fs::write(&tmp_path, json)
.map_err(|e| anyhow::anyhow!("write {}: {e}", tmp_path.display()))?;
std::fs::rename(&tmp_path, &final_path)
.map_err(|e| anyhow::anyhow!("rename → {}: {e}", final_path.display()))?;
Ok(())
}
/// Load from an explicit directory. Returns a friendly error
/// message when the session id has no file on disk so the caller
/// can map it to a clean ACP error response.
pub fn load_from_dir(
dir: &std::path::Path,
session_id: &SessionId,
) -> anyhow::Result<PersistedSession> {
let safe = sanitize_id(session_id.0.as_ref());
let path = dir.join(format!("{safe}.json"));
let bytes = std::fs::read(&path).map_err(|e| {
if e.kind() == std::io::ErrorKind::NotFound {
anyhow::anyhow!("no persisted session at {}", path.display())
} else {
anyhow::anyhow!("read {}: {e}", path.display())
}
})?;
let session: PersistedSession = serde_json::from_slice(&bytes)
.map_err(|e| anyhow::anyhow!("parse {}: {e}", path.display()))?;
Ok(session)
}
/// List all persisted sessions, optionally filtered by `cwd`. Used
/// by the `session/list` handler so a client (Zed) can find the
/// session that belongs to the workspace it's reopening.
///
/// `filter_cwd = None` returns every session on disk. `Some(path)`
/// returns only sessions whose persisted `cwd` is exactly equal.
///
/// Files that fail to parse are skipped with a warning rather than
/// aborting the whole list — one corrupt session shouldn't make
/// the resume picker unusable.
pub fn list(filter_cwd: Option<&std::path::Path>) -> anyhow::Result<Vec<PersistedSession>> {
let dir = sessions_dir()
.ok_or_else(|| anyhow::anyhow!("can't resolve XDG_DATA_HOME or HOME for session store"))?;
list_in_dir(&dir, filter_cwd)
}
/// Explicit-dir variant for tests, mirroring [`save_to_dir`] /
/// [`load_from_dir`].
pub fn list_in_dir(
dir: &std::path::Path,
filter_cwd: Option<&std::path::Path>,
) -> anyhow::Result<Vec<PersistedSession>> {
let read = match std::fs::read_dir(dir) {
Ok(r) => r,
Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(Vec::new()),
Err(e) => return Err(anyhow::anyhow!("read_dir {}: {e}", dir.display())),
};
let mut out = Vec::new();
for entry in read.flatten() {
let path = entry.path();
if path.extension().and_then(|s| s.to_str()) != Some("json") {
continue;
}
match std::fs::read(&path).and_then(|bytes| {
serde_json::from_slice::<PersistedSession>(&bytes).map_err(std::io::Error::other)
}) {
Ok(session) => {
if let Some(want) = filter_cwd
&& session.cwd != want
{
continue;
}
out.push(session);
}
Err(e) => {
tracing::warn!(
path = %path.display(),
error = %e,
"store: skipping unparseable session file"
);
}
}
}
// Most-recent first by updated_at.
out.sort_by_key(|s| std::cmp::Reverse(s.updated_at));
Ok(out)
}
/// Seconds-since-epoch, saturating to 0 if the system clock is
/// behind epoch (which shouldn't happen but the type system
/// requires a fallible read).
pub fn now_secs() -> u64 {
SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0)
}
/// Root directory for plan-mode artefacts. Mirrors [`sessions_dir`]
/// but under `…/helexa-acp/plans/` so plans and conversation
/// transcripts are siblings, not nested.
pub fn plans_root() -> Option<PathBuf> {
sessions_dir().and_then(|s| s.parent().map(|p| p.join(PLANS_DIRNAME)))
}
/// Per-project plan directory:
/// `$XDG_DATA_HOME/helexa-acp/plans/<project-id>/`. The id derives
/// from the session's cwd so plans for the same project survive
/// across cwd-changes (a `/home/foo/git/bar` ↔ symlinked
/// `/srv/checkout/bar` would technically diverge, accepted as a
/// won't-fix corner case).
pub fn plan_dir_for(cwd: &std::path::Path) -> Option<PathBuf> {
plans_root().map(|root| root.join(project_id_for(cwd)))
}
/// Deterministic, human-readable project identifier. Format:
/// `<basename>-<8-hex>` where the 8-hex suffix is FNV-1a of the
/// full path. Basename keeps the path skim-readable when poking
/// around `$XDG_DATA_HOME` by hand; the hash suffix disambiguates
/// repos that share a final path component (e.g. multiple
/// `/.../checkout/beat` checkouts).
///
/// FNV-1a rather than `std::collections::hash::DefaultHasher`
/// because the latter (SipHash) reseeds per process, so it'd give
/// us a different project_id on every run.
pub fn project_id_for(cwd: &std::path::Path) -> String {
let basename = cwd
.file_name()
.and_then(|s| s.to_str())
.unwrap_or("unknown");
let sanitised: String = basename
.chars()
.map(|c| {
if c.is_ascii_alphanumeric() || c == '-' || c == '_' {
c
} else {
'_'
}
})
.collect();
let hash = fnv1a_32(cwd.to_string_lossy().as_bytes());
format!("{sanitised}-{hash:08x}")
}
/// FNV-1a (32-bit). Deterministic, no third-party crate. Used for
/// project ids only — not cryptographic.
fn fnv1a_32(bytes: &[u8]) -> u32 {
let mut h: u32 = 0x811c_9dc5;
for b in bytes {
h ^= u32::from(*b);
h = h.wrapping_mul(0x0100_0193);
}
h
}
/// Format seconds-since-epoch as an ISO 8601 / RFC 3339 string
/// (`YYYY-MM-DDTHH:MM:SSZ`) for `SessionInfo.updated_at`. Returns
/// `None` for values outside the representable range, in which
/// case the caller should omit the field.
pub fn unix_to_iso8601(secs: u64) -> Option<String> {
use chrono::TimeZone;
let dt = chrono::Utc.timestamp_opt(secs as i64, 0).single()?;
Some(dt.to_rfc3339_opts(chrono::SecondsFormat::Secs, true))
}
/// Strip anything that isn't a safe filename character so a
/// mischievous (or just unconventional) session id can't escape
/// the sessions directory.
fn sanitize_id(id: &str) -> String {
id.chars()
.map(|c| {
if c.is_ascii_alphanumeric() || c == '-' || c == '_' {
c
} else {
'_'
}
})
.collect()
}
#[cfg(test)]
mod tests {
use super::*;
use crate::provider::{MessageContent, Role};
/// Unique scratch dir per test invocation. We use this dir
/// directly with the `*_to_dir` / `*_from_dir` functions so
/// the tests never mutate `$XDG_DATA_HOME` — that env var
/// would race across the parallel test harness.
fn unique_dir() -> PathBuf {
let base = std::env::var("CARGO_TARGET_TMPDIR")
.ok()
.map(PathBuf::from)
.unwrap_or_else(std::env::temp_dir);
let pid = std::process::id();
let nanos = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.map(|d| d.subsec_nanos())
.unwrap_or(0);
let dir = base.join(format!("helexa-acp-store-test-{pid}-{nanos}"));
std::fs::create_dir_all(&dir).expect("create test dir");
dir
}
fn sample(id: &str) -> PersistedSession {
PersistedSession {
session_id: id.into(),
cwd: PathBuf::from("/home/me/proj"),
model_id: "Qwen/Qwen3.6-27B".into(),
mode_id: "default".into(),
history: vec![
Message {
role: Role::User,
content: MessageContent::Text {
text: "hello".into(),
},
},
Message {
role: Role::Assistant,
content: MessageContent::Text { text: "hi".into() },
},
],
created_at: 1_700_000_000,
updated_at: 1_700_000_001,
}
}
#[test]
fn round_trip_save_then_load() {
let dir = unique_dir();
save_to_dir(&dir, &sample("hxa-1")).expect("save");
let loaded = load_from_dir(&dir, &SessionId::new("hxa-1")).expect("load");
assert_eq!(loaded.session_id, "hxa-1");
assert_eq!(loaded.cwd, PathBuf::from("/home/me/proj"));
assert_eq!(loaded.history.len(), 2);
let _ = std::fs::remove_dir_all(&dir);
}
#[test]
fn load_missing_session_errors_with_not_found_message() {
let dir = unique_dir();
let err = load_from_dir(&dir, &SessionId::new("nope")).unwrap_err();
let msg = format!("{err}");
assert!(
msg.contains("no persisted session"),
"want NotFound, got: {msg}"
);
let _ = std::fs::remove_dir_all(&dir);
}
#[test]
fn save_overwrites_existing_atomically() {
let dir = unique_dir();
save_to_dir(&dir, &sample("hxa-1")).expect("save");
let mut updated = sample("hxa-1");
updated.history.push(Message {
role: Role::User,
content: MessageContent::Text {
text: "third turn".into(),
},
});
updated.updated_at = 1_700_000_500;
save_to_dir(&dir, &updated).expect("re-save");
let loaded = load_from_dir(&dir, &SessionId::new("hxa-1")).expect("load");
assert_eq!(loaded.history.len(), 3);
assert_eq!(loaded.updated_at, 1_700_000_500);
let _ = std::fs::remove_dir_all(&dir);
}
#[test]
fn save_then_load_preserves_tool_calls_and_results() {
use crate::provider::ToolCall;
let dir = unique_dir();
let mut session = sample("hxa-2");
session.history.push(Message {
role: Role::Assistant,
content: MessageContent::ToolCalls {
text: Some("calling".into()),
calls: vec![ToolCall {
id: "call_0".into(),
name: "read_file".into(),
arguments: r#"{"path":"/etc/hostname"}"#.into(),
}],
},
});
session.history.push(Message {
role: Role::Tool,
content: MessageContent::ToolResult {
tool_call_id: "call_0".into(),
content: "host".into(),
},
});
save_to_dir(&dir, &session).expect("save");
let loaded = load_from_dir(&dir, &SessionId::new("hxa-2")).expect("load");
assert_eq!(loaded.history.len(), 4);
match &loaded.history[2].content {
MessageContent::ToolCalls { calls, .. } => {
assert_eq!(calls[0].name, "read_file");
}
other => panic!("expected ToolCalls, got {other:?}"),
}
let _ = std::fs::remove_dir_all(&dir);
}
#[test]
fn list_filters_by_cwd_and_sorts_recent_first() {
let dir = unique_dir();
let mut a = sample("a");
a.cwd = PathBuf::from("/home/me/proj-x");
a.updated_at = 1_700_000_010;
let mut b = sample("b");
b.cwd = PathBuf::from("/home/me/proj-x");
b.updated_at = 1_700_000_020;
let mut c = sample("c");
c.cwd = PathBuf::from("/home/me/elsewhere");
c.updated_at = 1_700_000_030;
save_to_dir(&dir, &a).unwrap();
save_to_dir(&dir, &b).unwrap();
save_to_dir(&dir, &c).unwrap();
let proj_x = PathBuf::from("/home/me/proj-x");
let list = list_in_dir(&dir, Some(&proj_x)).unwrap();
let ids: Vec<&str> = list.iter().map(|s| s.session_id.as_str()).collect();
// Filtered to proj-x; b before a because b is more recent.
assert_eq!(ids, vec!["b", "a"]);
let all = list_in_dir(&dir, None).unwrap();
assert_eq!(all.len(), 3);
// Global list still sorted recent-first across all cwds.
assert_eq!(all[0].session_id, "c");
let _ = std::fs::remove_dir_all(&dir);
}
#[test]
fn list_returns_empty_for_missing_dir() {
let dir = unique_dir().join("does-not-exist");
let list = list_in_dir(&dir, None).unwrap();
assert!(list.is_empty());
}
#[test]
fn list_skips_unparseable_files() {
let dir = unique_dir();
save_to_dir(&dir, &sample("good")).unwrap();
std::fs::write(dir.join("garbage.json"), b"{not valid json").unwrap();
let list = list_in_dir(&dir, None).unwrap();
// Garbage skipped; good survives.
assert_eq!(list.len(), 1);
assert_eq!(list[0].session_id, "good");
let _ = std::fs::remove_dir_all(&dir);
}
#[test]
fn iso8601_formats_unix_seconds() {
// 2024-01-01T00:00:00Z is 1704067200 unix seconds.
assert_eq!(
unix_to_iso8601(1_704_067_200),
Some("2024-01-01T00:00:00Z".into())
);
assert_eq!(unix_to_iso8601(0), Some("1970-01-01T00:00:00Z".into()));
}
#[test]
fn sanitize_id_rejects_path_traversal() {
// `../../etc/passwd` — 6 non-alnum chars before "etc"
// (`.`, `.`, `/`, `.`, `.`, `/`), one between, none
// after, none before nothing. Every disallowed char
// collapses to `_`.
assert_eq!(sanitize_id("../../etc/passwd"), "______etc_passwd");
assert_eq!(sanitize_id("ok-name_42"), "ok-name_42");
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,300 +0,0 @@
//! Tool schemas sent to the upstream model on every completion.
//!
//! These are the OpenAI-function-style declarations the LLM sees in
//! `CompletionRequest.tools`; the runtime dispatch happens in
//! [`crate::tool_runner`]. Keeping declarations and execution in
//! separate modules makes it easy to add a tool without touching the
//! runner, and vice versa.
//!
//! Stage 3 ships five: filesystem read / write / edit, directory
//! listing, and `bash`. Image generation, web fetch, MCP-derived
//! tools, etc. are out of scope here.
use serde_json::json;
use crate::provider::ToolSpec;
pub const READ_FILE: &str = "read_file";
pub const WRITE_FILE: &str = "write_file";
pub const EDIT_FILE: &str = "edit_file";
pub const LIST_DIR: &str = "list_dir";
pub const BASH: &str = "bash";
/// Build the static tool list passed to the model on every prompt.
/// Cheap — the JSON Schema fragments are constructed each call but
/// the bodies are small constants. If this ever shows up in a
/// profile we can `OnceLock` the Vec.
pub fn all_tools() -> Vec<ToolSpec> {
vec![
ToolSpec {
name: READ_FILE.to_string(),
description: "Read the contents of a text file. Returns the file's text.".to_string(),
parameters: json!({
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Absolute path to the file."
},
"line": {
"type": "integer",
"description": "Optional 1-based line number to start reading from.",
"minimum": 1
},
"limit": {
"type": "integer",
"description": "Optional maximum number of lines to read.",
"minimum": 1
}
},
"required": ["path"],
"additionalProperties": false
}),
},
ToolSpec {
name: WRITE_FILE.to_string(),
description: "Write text content to a file, replacing any existing contents. \
Creates the file (and parent directories) if needed."
.to_string(),
parameters: json!({
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Absolute path to the file."
},
"content": {
"type": "string",
"description": "Full new contents of the file."
}
},
"required": ["path", "content"],
"additionalProperties": false
}),
},
ToolSpec {
name: EDIT_FILE.to_string(),
description: "Replace one exact substring in a file with another. \
Fails if `old_text` does not appear in the file, or appears more than once. \
Use multiple edit_file calls for multiple edits."
.to_string(),
parameters: json!({
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Absolute path to the file."
},
"old_text": {
"type": "string",
"description": "Exact text fragment to replace. Must be unique within the file."
},
"new_text": {
"type": "string",
"description": "Replacement text."
}
},
"required": ["path", "old_text", "new_text"],
"additionalProperties": false
}),
},
ToolSpec {
name: LIST_DIR.to_string(),
description:
"List the entries of a directory. Returns names and a (f|d|l) kind per entry."
.to_string(),
parameters: json!({
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Absolute path to the directory."
}
},
"required": ["path"],
"additionalProperties": false
}),
},
ToolSpec {
name: BASH.to_string(),
description: "Run a shell command via `sh -c`. \
Returns combined stdout+stderr and the exit status. \
The command runs in the session's working directory unless `cwd` is given."
.to_string(),
parameters: json!({
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "Shell command line, evaluated by `sh -c`."
},
"cwd": {
"type": "string",
"description": "Optional absolute path to run the command from."
}
},
"required": ["command"],
"additionalProperties": false
}),
},
]
}
/// Try to infer which tool was intended from the shape of an
/// `arguments` object alone. Used by the agent when the model
/// emits a `<tool_call>` whose JSON has the right arguments but a
/// missing or invalid top-level `name` field — a recurring
/// Qwen3.6-27B failure mode.
///
/// Returns `Some(name)` only when the argument keys uniquely match
/// exactly one tool in the catalogue. Ambiguous shapes (`{path}`
/// alone could be either [`READ_FILE`] or [`LIST_DIR`]) return
/// `None` so the caller surfaces a Failed-card and lets the model
/// retry rather than guessing wrong.
///
/// Inference table (key set → tool):
///
/// | Keys | Tool |
/// |---------------------------------------|--------------|
/// | `{command}` or `{command, cwd}` | `bash` |
/// | `{path, content}` | `write_file` |
/// | `{path, old_text, new_text}` | `edit_file` |
/// | `{path}` / `{path, line}` / `{path, line, limit}` | *ambiguous* — None |
/// | (anything else) | None |
pub fn infer_tool_name(arguments: &serde_json::Value) -> Option<&'static str> {
let obj = arguments.as_object()?;
let keys: std::collections::HashSet<&str> = obj.keys().map(|s| s.as_str()).collect();
// `command` is unique to bash. Allow the optional `cwd` arg
// alongside but nothing else (any unrecognised keys → bail and
// let the model retry rather than misroute).
if keys.contains("command") && keys.iter().all(|k| matches!(*k, "command" | "cwd")) {
return Some(BASH);
}
// `content` is unique to write_file.
if keys.contains("content") && keys.contains("path") && keys.len() == 2 {
return Some(WRITE_FILE);
}
// `old_text` + `new_text` are unique to edit_file.
if keys.contains("old_text")
&& keys.contains("new_text")
&& keys.contains("path")
&& keys.len() == 3
{
return Some(EDIT_FILE);
}
// `{path}` / `{path, line}` / `{path, line, limit}` overlap
// between read_file (file contents) and list_dir (directory
// contents). No safe inference — refuse.
None
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn all_tools_has_five_named_entries() {
let tools = all_tools();
let names: Vec<&str> = tools.iter().map(|t| t.name.as_str()).collect();
assert_eq!(
names,
vec![READ_FILE, WRITE_FILE, EDIT_FILE, LIST_DIR, BASH]
);
}
#[test]
fn infer_bash_from_command_only() {
let args = serde_json::json!({"command": "ls /tmp"});
assert_eq!(infer_tool_name(&args), Some(BASH));
}
#[test]
fn infer_bash_from_command_and_cwd() {
let args = serde_json::json!({"command": "ls", "cwd": "/tmp"});
assert_eq!(infer_tool_name(&args), Some(BASH));
}
#[test]
fn infer_bash_from_mkdir_like_real_failure() {
// Lifted verbatim from the agent failure that motivated
// this helper (helexa-acp.log @ 10:03:11).
let args = serde_json::json!({
"command": "mkdir -p /home/grenade/git/beat/beat/doc/plan/{01-discovery,02-segmentation,03-description,04-summary,05-output}"
});
assert_eq!(infer_tool_name(&args), Some(BASH));
}
#[test]
fn infer_write_file() {
let args = serde_json::json!({"path": "/tmp/x", "content": "hi"});
assert_eq!(infer_tool_name(&args), Some(WRITE_FILE));
}
#[test]
fn infer_edit_file() {
let args = serde_json::json!({
"path": "/tmp/x", "old_text": "a", "new_text": "b"
});
assert_eq!(infer_tool_name(&args), Some(EDIT_FILE));
}
#[test]
fn refuse_ambiguous_path_only() {
let args = serde_json::json!({"path": "/tmp/x"});
assert_eq!(infer_tool_name(&args), None);
}
#[test]
fn refuse_ambiguous_path_with_optionals() {
// read_file accepts these optionals; list_dir doesn't —
// but Qwen wouldn't reliably emit them either, so we
// can't use their presence to disambiguate. Refuse.
let args = serde_json::json!({"path": "/tmp/x", "line": 1, "limit": 50});
assert_eq!(infer_tool_name(&args), None);
}
#[test]
fn refuse_command_with_extra_unknown_keys() {
// Defence in depth: an unrecognised key alongside
// `command` means we don't really know what tool the
// model wanted; refuse rather than guess.
let args = serde_json::json!({"command": "ls", "extra": "?"});
assert_eq!(infer_tool_name(&args), None);
}
#[test]
fn refuse_empty_args() {
let args = serde_json::json!({});
assert_eq!(infer_tool_name(&args), None);
}
#[test]
fn refuse_non_object_args() {
let args = serde_json::json!("not an object");
assert_eq!(infer_tool_name(&args), None);
}
#[test]
fn every_tool_has_an_object_parameter_schema() {
for tool in all_tools() {
let ty = tool.parameters.get("type").and_then(|v| v.as_str());
assert_eq!(
ty,
Some("object"),
"tool {} parameters.type must be \"object\"",
tool.name
);
assert!(
tool.parameters.get("properties").is_some(),
"tool {} missing properties",
tool.name
);
assert!(
tool.parameters.get("required").is_some(),
"tool {} missing required list",
tool.name
);
}
}
}

View File

@@ -12,36 +12,6 @@ path = "src/lib.rs"
name = "neuron"
path = "src/main.rs"
[features]
default = []
# Enables CUDA acceleration in candle and the cudarc/nccl bindings the
# TP worker pool uses. Without this feature, candle compiles for CPU
# only, Device::new_cuda calls fall back to CPU, and TP Init/sanity
# requests return Error{kind="cuda_feature_not_enabled"}.
cuda = [
"candle-core/cuda",
"candle-core/nccl",
"candle-nn/cuda",
"candle-transformers/cuda",
"dep:cudarc",
"dep:half",
"dep:cudaforge",
]
# Use cuDNN for convolution / attention kernels. Requires CUDA.
cudnn = [
"cuda",
"candle-core/cudnn",
"candle-nn/cudnn",
"candle-transformers/cudnn",
]
# FlashAttention kernels. Requires CUDA.
flash-attn = [
"cuda",
"candle-transformers/flash-attn",
]
# Reserved for GPU-only integration tests in later stages.
cuda-integration = ["cuda"]
[dependencies]
cortex-core.workspace = true
tokio.workspace = true
@@ -54,54 +24,9 @@ tracing-subscriber.workspace = true
anyhow.workspace = true
async-trait.workspace = true
clap.workspace = true
thiserror.workspace = true
futures.workspace = true
tokio-stream.workspace = true
figment.workspace = true
toml.workspace = true
# candle for in-process inference. CUDA support is gated behind the
# crate's `cuda` feature (default off) so the workspace builds on
# non-CUDA hosts and CI runners.
candle-core = "0.10.2"
candle-nn = "0.10.2"
candle-transformers = "0.10.2"
# Direct dep on cudarc (matching candle's transitive version) so the
# TP worker pool can call cudarc::nccl::{Comm, Id} directly. Gated on
# the `cuda` feature; same toolchain requirement as candle's CUDA path.
cudarc = { version = "0.19", optional = true, default-features = false, features = ["nccl", "cuda-version-from-build-system"] }
# Used by the AllReduce CustomOp1 to type-dispatch on bf16/f16 candle
# storages. Matches candle-core's pinned major version to avoid double-
# compiling the `half` crate at conflicting versions.
half = { version = "2.5", optional = true }
tokenizers = { version = "0.22", default-features = false, features = ["onig"] }
hf-hub = { version = "0.4", features = ["tokio"] }
# Jinja-compatible template renderer for the model's
# `tokenizer_config.json::chat_template`. Hugging Face's chat
# templates use a strict subset of Jinja2 that minijinja supports
# out of the box. ~80KB compiled; pure Rust, no async surface.
# Features: `builtins` for the `is defined` / `default` filters HF
# templates use; `json` for `tojson` (some Qwen3 templates emit
# tool definitions via tojson); `serde` so we can hand it a
# serde_json::Value as the context.
minijinja = { version = "2", features = ["builtins", "json", "serde"] }
# Direct dep on `safetensors` (re-exported by candle but its `TensorView`
# / `slice::IndexOp` types are public-but-not-re-exported). Used by the
# tp `fused_load` module to read per-rank slices of fused QKV tensors
# without materialising the full tensor on device.
safetensors = "0.7"
[dev-dependencies]
tokio = { workspace = true, features = ["test-util"] }
reqwest.workspace = true
tempfile = "3"
[build-dependencies]
# Used by `build.rs` to compile `src/cuda/*.cu` into `libneuroncuda.a`
# under the `cuda` feature. Matches mistralrs's upstream build setup
# (their `mistralrs-core/build.rs` uses the same constructor).
cudaforge = { version = "0.1", optional = true }
[package.metadata.docs.rs]
# Skip the CUDA path on docs.rs (it lacks nvcc).
no-default-features = true

View File

@@ -1,66 +0,0 @@
//! Build script: compile the CUDA kernels in `src/cuda/*.cu` into a
//! static library and link it under the `cuda` feature.
//!
//! Patterned on `EricLBuehler/mistral.rs::mistralrs-core/build.rs` —
//! same `cudaforge::KernelBuilder` invocation, same NVCC flag set.
fn main() {
#[cfg(feature = "cuda")]
{
use std::path::PathBuf;
println!("cargo:rerun-if-changed=build.rs");
println!("cargo:rerun-if-changed=src/cuda/");
let build_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap());
let mut builder = cudaforge::KernelBuilder::new()
.source_glob("src/cuda/*.cu")
.out_dir(&build_dir)
.arg("-std=c++17")
.arg("-O3")
.arg("-U__CUDA_NO_HALF_OPERATORS__")
.arg("-U__CUDA_NO_HALF_CONVERSIONS__")
.arg("-U__CUDA_NO_HALF2_OPERATORS__")
.arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
.arg("--expt-relaxed-constexpr")
.arg("--expt-extended-lambda")
.arg("--use_fast_math")
.arg("--compiler-options")
.arg("-fPIC");
// sm_<80 doesn't have bf16 intrinsics for WMMA — gate the
// bf16-only kernels off in that case. (Mirrors upstream.)
if let Some(compute_cap) = builder.get_compute_cap()
&& compute_cap < 80
{
builder = builder.arg("-DNO_BF16_KERNEL");
}
let target = std::env::var("TARGET").unwrap();
let out_file = if target.contains("msvc") {
build_dir.join("neuroncuda.lib")
} else {
build_dir.join("libneuroncuda.a")
};
builder
.build_lib(out_file)
.expect("neuron cuda build failed");
println!("cargo:rustc-link-search={}", build_dir.display());
println!("cargo:rustc-link-lib=neuroncuda");
println!("cargo:rustc-link-lib=dylib=cudart");
if target.contains("msvc") {
// No extra runtime library needed.
} else if target.contains("apple")
|| target.contains("freebsd")
|| target.contains("openbsd")
{
println!("cargo:rustc-link-lib=dylib=c++");
} else if target.contains("android") {
println!("cargo:rustc-link-lib=dylib=c++_shared");
} else {
println!("cargo:rustc-link-lib=dylib=stdc++");
}
}
}

View File

@@ -1,93 +0,0 @@
//! Activation-time pre-warm progress tracking.
//!
//! Wraps the [`ActivationStatus`] snapshot in an async RwLock so the
//! background pre-warm task can update it per-model while the
//! `/health` handler reads coherent snapshots. The tracker exists
//! because `default_models` loading moved from synchronous-before-bind
//! to background-after-bind on 2026-05-26: the listener is up
//! immediately, but `/health` now needs to tell callers which of the
//! configured defaults are still warming.
use cortex_core::discovery::{ActivationState, ActivationStatus, PreWarmFailure};
use cortex_core::harness::ModelSpec;
use tokio::sync::RwLock;
/// Shared, async-safe handle to the daemon's activation progress.
///
/// Construct once in `main` with the configured `default_models` so
/// the initial `pending` list matches the spec; clone the `Arc` into
/// the `NeuronState` for HTTP handlers and into the spawned pre-warm
/// task for updates.
pub struct ActivationTracker {
inner: RwLock<ActivationStatus>,
}
impl ActivationTracker {
/// Build a tracker primed with one entry per spec. An empty spec
/// list yields a `Ready` tracker — no point reporting PreWarming
/// when there's nothing queued.
pub fn new(default_models: &[ModelSpec]) -> Self {
let pending: Vec<String> = default_models.iter().map(|s| s.model_id.clone()).collect();
let state = if pending.is_empty() {
ActivationState::Ready
} else {
ActivationState::PreWarming
};
Self {
inner: RwLock::new(ActivationStatus {
state,
pending,
in_progress: None,
completed: vec![],
failed: vec![],
}),
}
}
/// Mark a model as in-progress: remove it from `pending`, set as
/// `in_progress`. Called immediately before `registry.load_model`.
pub async fn start_loading(&self, model_id: &str) {
let mut s = self.inner.write().await;
s.pending.retain(|m| m != model_id);
s.in_progress = Some(model_id.to_string());
}
/// Mark a model as completed: clear `in_progress` (if it matches),
/// append to `completed`.
pub async fn complete_loading(&self, model_id: &str) {
let mut s = self.inner.write().await;
if s.in_progress.as_deref() == Some(model_id) {
s.in_progress = None;
}
s.completed.push(model_id.to_string());
}
/// Mark a model as failed: clear `in_progress` (if it matches),
/// append a `PreWarmFailure` carrying the rendered error chain.
pub async fn fail_loading(&self, model_id: &str, error: &str) {
let mut s = self.inner.write().await;
if s.in_progress.as_deref() == Some(model_id) {
s.in_progress = None;
}
s.failed.push(PreWarmFailure {
model_id: model_id.to_string(),
error: error.to_string(),
});
}
/// Flip the high-level `state` to `Ready` once the pre-warm task
/// is done iterating. Pending should be empty by this point; if a
/// caller bails early it's a stuck activation and the operator
/// will see entries in `pending` even with `state=ready` — that's
/// a useful diagnostic, not an inconsistency to scrub.
pub async fn mark_ready(&self) {
let mut s = self.inner.write().await;
s.state = ActivationState::Ready;
s.in_progress = None;
}
/// Cheap clone of the current state for the `/health` handler.
pub async fn snapshot(&self) -> ActivationStatus {
self.inner.read().await.clone()
}
}

View File

@@ -1,41 +1,23 @@
//! HTTP API handlers for the neuron daemon.
use crate::activation::ActivationTracker;
use crate::harness::HarnessRegistry;
use crate::harness::candle::{CandleHarness, InferenceError};
use crate::harness::preflight::PreflightError;
use crate::health::HealthCache;
use crate::wire::{openai_chat, openai_responses};
use axum::Router;
use axum::extract::{Path, State};
use axum::http::StatusCode;
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::{IntoResponse, Json};
use axum::routing::{get, post};
use cortex_core::discovery::{DiscoveryResponse, HealthResponse};
use cortex_core::harness::ModelSpec;
use cortex_core::openai::{ChatCompletionRequest, MessageContent};
use cortex_core::responses::{ResponsesRequest, ResponsesUsage};
use futures::stream::{self, StreamExt};
use serde_json::{Value, json};
use std::convert::Infallible;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;
use tokio_stream::wrappers::ReceiverStream;
/// Shared state for the neuron HTTP server.
pub struct NeuronState {
pub discovery: DiscoveryResponse,
pub health_cache: Arc<HealthCache>,
pub registry: RwLock<HarnessRegistry>,
/// Typed handle to the candle harness for inference routes. Cached at
/// startup so `/v1/chat/completions` doesn't have to hold the registry
/// read lock or perform dyn-Trait dispatch per request.
pub candle: Option<Arc<CandleHarness>>,
/// Activation-time pre-warm progress. Updated by the background
/// `load_default_models` task, read by the `/health` handler.
pub activation: Arc<ActivationTracker>,
}
/// Build the neuron API router.
@@ -47,8 +29,6 @@ pub fn neuron_routes() -> Router<Arc<NeuronState>> {
.route("/models/load", post(load_model))
.route("/models/unload", post(unload_model))
.route("/models/{model_id}/endpoint", get(model_endpoint))
.route("/v1/chat/completions", post(chat_completions))
.route("/v1/responses", post(responses))
}
async fn discovery_handler(State(state): State<Arc<NeuronState>>) -> Json<DiscoveryResponse> {
@@ -56,13 +36,7 @@ async fn discovery_handler(State(state): State<Arc<NeuronState>>) -> Json<Discov
}
async fn health_handler(State(state): State<Arc<NeuronState>>) -> Json<HealthResponse> {
// HealthCache owns the uptime + per-device readings; the activation
// tracker owns the pre-warm progress. We compose the response here
// so the cache stays a thin runtime-state cache and doesn't need to
// know about activation lifecycle.
let mut snapshot = state.health_cache.snapshot().await;
snapshot.activation = state.activation.snapshot().await;
Json(snapshot)
Json(state.health_cache.snapshot().await)
}
async fn list_models(State(state): State<Arc<NeuronState>>) -> impl IntoResponse {
@@ -71,7 +45,7 @@ async fn list_models(State(state): State<Arc<NeuronState>>) -> impl IntoResponse
Ok(models) => Json(json!(models)).into_response(),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("{e:#}")})),
Json(json!({"error": e.to_string()})),
)
.into_response(),
}
@@ -84,52 +58,11 @@ async fn load_model(
let registry = state.registry.read().await;
match registry.load_model(&spec).await {
Ok(()) => Json(json!({"status": "loaded"})).into_response(),
Err(e) => {
// If the underlying failure is a structured preflight
// rejection, surface it as 422 Unprocessable Entity with
// the typed JSON body. The kind/model_id/suggestion/etc.
// fields let cortex (and operators reading the response
// directly) act on the failure without parsing free text.
if let Some(pf) = e.downcast_ref::<PreflightError>() {
tracing::warn!(
model = %spec.model_id,
reason = preflight_kind(pf),
detail = %pf,
"load_model rejected by preflight"
);
return (
StatusCode::UNPROCESSABLE_ENTITY,
Json(json!({ "error": pf })),
)
.into_response();
}
// Log the full anyhow chain server-side so journalctl shows
// the underlying failure (hf-hub timeout, permission denied,
// disk full, etc.) without needing to inspect the HTTP
// response body separately.
tracing::warn!(
model = %spec.model_id,
error = %format!("{e:#}"),
"load_model failed"
);
(
StatusCode::BAD_REQUEST,
Json(json!({"error": format!("{e:#}")})),
)
.into_response()
}
}
}
/// Short kebab-case tag for a preflight failure, used as a structured
/// log field for journalctl-side filtering. Mirrors the same helper in
/// `startup.rs`; duplicated to keep the module surfaces independent.
fn preflight_kind(err: &PreflightError) -> &'static str {
match err {
PreflightError::RepoFetchFailed { .. } => "repo_fetch_failed",
PreflightError::EmptyRepo { .. } => "empty_repo",
PreflightError::TpRequiresSafetensors { .. } => "tp_requires_safetensors",
PreflightError::QuantNotFound { .. } => "quant_not_found",
Err(e) => (
StatusCode::BAD_REQUEST,
Json(json!({"error": e.to_string()})),
)
.into_response(),
}
}
@@ -151,11 +84,7 @@ async fn unload_model(
let registry = state.registry.read().await;
match registry.unload_model(&model_id).await {
Ok(()) => Json(json!({"status": "unloaded"})).into_response(),
Err(e) => (
StatusCode::NOT_FOUND,
Json(json!({"error": format!("{e:#}")})),
)
.into_response(),
Err(e) => (StatusCode::NOT_FOUND, Json(json!({"error": e.to_string()}))).into_response(),
}
}
@@ -173,311 +102,3 @@ async fn model_endpoint(
.into_response(),
}
}
/// OpenAI-compatible chat completions. Dispatches to streaming SSE when
/// `stream: true` is set on the request; otherwise returns a single
/// `ChatCompletionResponse`.
async fn chat_completions(
State(state): State<Arc<NeuronState>>,
headers: axum::http::HeaderMap,
Json(req): Json<ChatCompletionRequest>,
) -> impl IntoResponse {
let Some(candle) = state.candle.as_ref().map(Arc::clone) else {
return (
StatusCode::SERVICE_UNAVAILABLE,
Json(json!({"error": "candle harness not enabled on this neuron"})),
)
.into_response();
};
// Reasoning-content opt-in. Off by default → naïve clients
// (Zed's commit-message generator, vanilla OpenAI clients)
// never see `<think>` blocks. On when the caller sends
// `x-include-thinking: true` (helexa-acp does this so its
// own ThinkParser keeps working unchanged).
let include_thinking = headers
.get("x-include-thinking")
.and_then(|v| v.to_str().ok())
.map(|s| matches!(s.trim().to_ascii_lowercase().as_str(), "1" | "true" | "yes"))
.unwrap_or(false);
let chat_config = openai_chat::ChatProjectionConfig {
include_thinking,
reasoning_markers: None, // filled in from the loaded model inside candle
};
if req.stream.unwrap_or(false) {
match candle.chat_completion_stream_with(req, chat_config).await {
Ok(rx) => {
// Each chunk → one SSE `data: {json}` line. After the
// channel closes, append the OpenAI [DONE] terminator.
let body_stream = ReceiverStream::new(rx).map(|chunk| {
let body = serde_json::to_string(&chunk).unwrap_or_default();
Ok::<_, Infallible>(Event::default().data(body))
});
let done_stream =
stream::once(async { Ok::<_, Infallible>(Event::default().data("[DONE]")) });
Sse::new(body_stream.chain(done_stream))
.keep_alive(KeepAlive::default())
.into_response()
}
Err(InferenceError::ModelNotLoaded(id)) => (
StatusCode::NOT_FOUND,
Json(json!({"error": format!("model '{id}' not loaded on this neuron")})),
)
.into_response(),
Err(InferenceError::PromptTooLong { prompt_len, max }) => (
StatusCode::BAD_REQUEST,
Json(json!({
"error": format!("prompt has {prompt_len} tokens but max is {max}"),
"code": "prompt_too_long",
"prompt_len": prompt_len,
"max": max,
})),
)
.into_response(),
Err(InferenceError::InsufficientVram {
free_mb,
required_mb,
}) => (
StatusCode::SERVICE_UNAVAILABLE,
Json(json!({
"error": format!(
"insufficient free VRAM: {free_mb} MiB free, need at least {required_mb} MiB"
),
"code": "insufficient_vram",
"free_mb": free_mb,
"required_mb": required_mb,
})),
)
.into_response(),
Err(InferenceError::Other(e)) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("{e:#}")})),
)
.into_response(),
}
} else {
match candle.chat_completion(req).await {
Ok(resp) => Json(resp).into_response(),
Err(InferenceError::ModelNotLoaded(id)) => (
StatusCode::NOT_FOUND,
Json(json!({"error": format!("model '{id}' not loaded on this neuron")})),
)
.into_response(),
Err(InferenceError::PromptTooLong { prompt_len, max }) => (
StatusCode::BAD_REQUEST,
Json(json!({
"error": format!("prompt has {prompt_len} tokens but max is {max}"),
"code": "prompt_too_long",
"prompt_len": prompt_len,
"max": max,
})),
)
.into_response(),
Err(InferenceError::InsufficientVram {
free_mb,
required_mb,
}) => (
StatusCode::SERVICE_UNAVAILABLE,
Json(json!({
"error": format!(
"insufficient free VRAM: {free_mb} MiB free, need at least {required_mb} MiB"
),
"code": "insufficient_vram",
"free_mb": free_mb,
"required_mb": required_mb,
})),
)
.into_response(),
Err(InferenceError::Other(e)) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("{e:#}")})),
)
.into_response(),
}
}
}
/// OpenAI Responses API (`POST /v1/responses`). Translates the
/// Responses-shaped request into a chat-completions one the candle
/// harness already understands, then re-projects the harness's
/// event stream into the Responses event family.
async fn responses(
State(state): State<Arc<NeuronState>>,
Json(req): Json<ResponsesRequest>,
) -> impl IntoResponse {
let Some(candle) = state.candle.as_ref().map(Arc::clone) else {
return (
StatusCode::SERVICE_UNAVAILABLE,
Json(json!({"error": "candle harness not enabled on this neuron"})),
)
.into_response();
};
let stream_requested = req.stream;
let model_id = req.model.clone();
let response_id = mint_response_id();
let message_item_id = mint_message_item_id();
// Translate Responses → chat completions. The only failure
// mode today is `previous_response_id` set, which we reject
// with 400 — stateful conversations need a persistence layer
// we haven't built.
let mut chat_req = match openai_responses::request_to_chat(req) {
Ok(r) => r,
Err(openai_responses::TranslateError::ChainedConversationNotSupported) => {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "previous_response_id is not supported on this neuron",
"code": "chained_conversation_not_supported"
})),
)
.into_response();
}
};
chat_req.stream = Some(stream_requested);
if stream_requested {
match candle
.responses_stream(chat_req, response_id, message_item_id)
.await
{
Ok(rx) => {
// Each ResponseStreamFrame → one SSE event carrying
// both an event name and JSON data. The Responses
// API doesn't use a `[DONE]` terminator — clients
// see the `response.completed` event as the end of
// the stream.
let body_stream = ReceiverStream::new(rx).map(|frame| {
let body = serde_json::to_string(&frame.data).unwrap_or_else(|_| "{}".into());
Ok::<_, Infallible>(Event::default().event(frame.event_name).data(body))
});
Sse::new(body_stream)
.keep_alive(KeepAlive::default())
.into_response()
}
Err(e) => inference_error_response(e),
}
} else {
// Non-streaming: drive the existing chat completion path
// and translate the result. We don't currently re-tokenise
// to compute usage; the harness returns it via the chat
// response and we pass it through.
match candle.chat_completion(chat_req).await {
Ok(chat_resp) => {
// Extract the assistant text (chat completions
// always emits one choice on the candle path).
let text = chat_resp
.choices
.first()
.map(|c| match &c.message.content {
MessageContent::Text(t) => t.clone(),
MessageContent::Parts(_) => {
// Candle output is always text today;
// a Parts response would be surprising.
// Empty-string fallback is safer than
// a panic.
String::new()
}
})
.unwrap_or_default();
let finish = chat_resp
.choices
.first()
.and_then(|c| c.finish_reason.as_deref())
.map(finish_reason_from_str)
.unwrap_or(crate::wire::FinishReason::Stop);
let usage = chat_resp.usage.as_ref().map(|u| ResponsesUsage {
input_tokens: u.prompt_tokens,
output_tokens: u.completion_tokens,
total_tokens: u.prompt_tokens + u.completion_tokens,
});
let meta = openai_responses::ResponseMeta {
response_id: mint_response_id(),
created_at: unix_now_secs(),
model_id,
message_item_id: mint_message_item_id(),
};
let _ = chat_resp; // make the borrow-checker happy if `text` consumed it
let resp = openai_responses::build_response(&meta, text, finish, usage);
Json(resp).into_response()
}
Err(e) => inference_error_response(e),
}
}
}
fn finish_reason_from_str(s: &str) -> crate::wire::FinishReason {
use crate::wire::FinishReason;
match s {
"length" => FinishReason::Length,
"tool_calls" => FinishReason::ToolCalls,
_ => FinishReason::Stop,
}
}
/// Centralised mapping from [`InferenceError`] to an HTTP response.
/// Lifted out so the chat-completions and responses handlers stay
/// readable and changes to error-code semantics happen in one spot.
fn inference_error_response(err: InferenceError) -> axum::response::Response {
match err {
InferenceError::ModelNotLoaded(id) => (
StatusCode::NOT_FOUND,
Json(json!({"error": format!("model '{id}' not loaded on this neuron")})),
)
.into_response(),
InferenceError::PromptTooLong { prompt_len, max } => (
StatusCode::BAD_REQUEST,
Json(json!({
"error": format!("prompt has {prompt_len} tokens but max is {max}"),
"code": "prompt_too_long",
"prompt_len": prompt_len,
"max": max,
})),
)
.into_response(),
InferenceError::InsufficientVram {
free_mb,
required_mb,
} => (
StatusCode::SERVICE_UNAVAILABLE,
Json(json!({
"error": format!(
"insufficient free VRAM: {free_mb} MiB free, need at least {required_mb} MiB"
),
"code": "insufficient_vram",
"free_mb": free_mb,
"required_mb": required_mb,
})),
)
.into_response(),
InferenceError::Other(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("{e:#}")})),
)
.into_response(),
}
}
fn mint_response_id() -> String {
format!("resp_{:x}", unix_subsec_nanos())
}
fn mint_message_item_id() -> String {
format!("msg_{:x}", unix_subsec_nanos())
}
fn unix_now_secs() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0)
}
fn unix_subsec_nanos() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_nanos() as u64)
.unwrap_or(0)
}

View File

@@ -1,12 +1,12 @@
//! Neuron configuration loaded from neuron.toml.
use cortex_core::harness::{HarnessConfig, ModelSpec};
use cortex_core::harness::HarnessConfig;
use figment::{
Figment,
providers::{Env, Format, Toml},
};
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use std::path::Path;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NeuronConfig {
@@ -14,35 +14,10 @@ pub struct NeuronConfig {
pub port: u16,
#[serde(default)]
pub harnesses: Vec<HarnessConfig>,
/// Per-harness configuration. Currently only `candle` is recognised.
#[serde(default)]
pub harness: HarnessSettings,
/// Models to auto-load when the neuron service activates. Each entry
/// is loaded sequentially before the HTTP listener binds. A failure
/// on any single entry logs a warning and proceeds — broken entries
/// don't prevent the rest of the fleet from starting.
#[serde(default)]
pub default_models: Vec<ModelSpec>,
}
/// Settings for individual harness implementations. Each harness owns
/// its own sub-table so users only configure the harnesses they enable.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct HarnessSettings {
#[serde(default)]
pub candle: CandleHarnessConfig,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CandleHarnessConfig {
/// HuggingFace cache directory for model weights.
/// When unset, defers to hf-hub's default (~/.cache/huggingface).
#[serde(default)]
pub hf_cache: Option<PathBuf>,
}
fn default_port() -> u16 {
13131
9090
}
impl NeuronConfig {
@@ -58,10 +33,8 @@ impl NeuronConfig {
impl Default for NeuronConfig {
fn default() -> Self {
Self {
port: 13131,
port: 9090,
harnesses: vec![],
harness: HarnessSettings::default(),
default_models: vec![],
}
}
}

View File

@@ -1,84 +0,0 @@
//! FFI declarations for the CUDA kernels in `gdn.cu`.
//!
//! Subset of `EricLBuehler/mistral.rs::mistralrs-core/src/cuda/ffi.rs`
//! covering only the Gated DeltaNet kernels we currently use. Other
//! kernels in the upstream file (MoE GEMM, top-k, Mamba selective
//! scan, etc.) would land here too as we absorb them.
//!
//! All function declarations are MIT-licensed from upstream and
//! unchanged apart from this header.
use std::ffi::c_void;
#[allow(dead_code)]
unsafe extern "C" {
// GDN (Gated Delta Net) kernels for qwen3_5 / Qwen3-Next.
pub(crate) fn gated_delta_rule_recurrence(
q: *const f32,
k: *const f32,
v: *const f32,
g: *const f32,
beta: *const f32,
state: *mut f32,
output: *mut f32,
bh: i32,
seq_len: i32,
k_dim: i32,
v_dim: i32,
stream: i64,
);
/// Chunked GDN recurrence for prefill (processes tokens in BT=64 chunks).
pub(crate) fn chunked_gated_delta_rule_recurrence(
q: *const f32,
k: *const f32,
v: *const f32,
g: *const f32,
beta: *const f32,
state: *mut f32,
output: *mut f32,
bh: i32,
seq_len: i32,
k_dim: i32,
v_dim: i32,
stream: i64,
);
pub(crate) fn causal_conv1d_update(
x: *const c_void,
weight: *const c_void,
conv_state: *mut c_void,
output: *mut c_void,
batch_size: i32,
conv_dim: i32,
kernel_size: i32,
dtype: i32,
stream: i64,
);
pub(crate) fn causal_conv1d_full(
x: *const c_void,
weight: *const c_void,
conv_state_out: *mut c_void,
output: *mut c_void,
batch_size: i32,
conv_dim: i32,
seq_len: i32,
kernel_size: i32,
dtype: i32,
stream: i64,
);
pub(crate) fn fused_gdn_gating(
b: *const c_void,
a: *const c_void,
a_log: *const f32,
dt_bias: *const f32,
beta_out: *mut c_void,
g_out: *mut c_void,
total_elements: i32,
num_heads: i32,
dtype: i32,
stream: i64,
);
}

View File

@@ -1,711 +0,0 @@
// Gated DeltaNet CUDA kernels for Qwen3-Next (`model_type = "qwen3_5"`).
//
// Ported verbatim from `EricLBuehler/mistral.rs` under MIT terms.
// Upstream path: mistralrs-core/src/cuda/gdn.cu. Local edits in this
// file are limited to this banner; the kernels are unchanged so a
// diff against upstream stays minimal.
//
// Five kernels exposed via `extern "C"` shims at the bottom:
// - gated_delta_rule_recurrence (per-token decode)
// - chunked_gated_delta_rule_recurrence (BT=64 chunked prefill)
// - causal_conv1d_update (single-token conv decode)
// - causal_conv1d_full (multi-token conv prefill)
// - fused_gdn_gating (beta = sigmoid(b);
// g = -exp(A_log) * softplus(a + dt_bias))
#include "cuda_bf16.h"
#include "cuda_fp16.h"
#include <cmath>
#include <cstdint>
#include <cuda_runtime.h>
// ============================================================================
// Kernel 1: gated_delta_rule_recurrence (optimized)
//
// V-tiled recurrence with compile-time K dimension for register residency.
// Grid: (ceil(V/BV), B*H), Block: (BV,). Each thread owns BK registers of
// state. Shared memory holds k_buf and q_buf (2*BK floats).
//
// Optimizations over naive version:
// - Template BK -> float s[BK] lives in true registers (1 cycle vs ~30)
// - #pragma unroll on all k-loops -> full ILP
// - Fused decay+kv_mem pass and fused state_update+output pass
// - __fmaf_rn intrinsics for guaranteed fused multiply-add
// - BV=64 threads -> 2 warps, 6 blocks/SM on Ampere
//
// q,k: [BH, S, K] v: [BH, S, V] g,beta: [BH, S]
// state: [BH, K, V] (in/out) output: [BH, S, V]
// ============================================================================
// Optimized kernel: BK known at compile time -> registers + full unrolling
template <int BK, int BV>
__global__ void gated_delta_rule_recurrence_kernel_tiled(
const float *__restrict__ q, // [BH, S, K]
const float *__restrict__ k, // [BH, S, K]
const float *__restrict__ v, // [BH, S, V]
const float *__restrict__ g, // [BH, S]
const float *__restrict__ beta, // [BH, S]
float *__restrict__ state, // [BH, K, V]
float *__restrict__ output, // [BH, S, V]
int seq_len, int v_dim) {
const int v_tile = blockIdx.x; // which V-tile
const int bh = blockIdx.y; // batch*head index
const int tid = threadIdx.x; // thread within tile [0, BV)
const int v_idx = v_tile * BV + tid; // global V index
if (v_idx >= v_dim)
return;
// Pointers for this (batch, head)
const float *q_bh = q + bh * seq_len * BK;
const float *k_bh = k + bh * seq_len * BK;
const float *v_bh = v + bh * seq_len * v_dim;
const float *g_bh = g + bh * seq_len;
const float *beta_bh = beta + bh * seq_len;
float *state_bh = state + bh * BK * v_dim;
float *out_bh = output + bh * seq_len * v_dim;
// Shared memory: k_buf[BK] + q_buf[BK]
__shared__ float k_buf[BK];
__shared__ float q_buf[BK];
// Load state column into registers — BK is compile-time, so this is
// a true register array (not spilled to local memory)
float s[BK];
#pragma unroll
for (int j = 0; j < BK; j++) {
s[j] = state_bh[j * v_dim + v_idx];
}
for (int t = 0; t < seq_len; t++) {
// Collaboratively load k_t into shared memory
// BK / BV loads per thread (e.g. 128/64 = 2)
#pragma unroll
for (int j = tid; j < BK; j += BV) {
k_buf[j] = k_bh[t * BK + j];
}
__syncthreads();
// Load scalars for this timestep
float decay = expf(g_bh[t]);
float beta_t = beta_bh[t];
float v_t = v_bh[t * v_dim + v_idx];
// Fused pass 1: decay state + compute kv_mem
float kv_mem = 0.0f;
#pragma unroll
for (int j = 0; j < BK; j++) {
s[j] *= decay;
kv_mem = __fmaf_rn(s[j], k_buf[j], kv_mem);
}
// Delta rule
float delta = (v_t - kv_mem) * beta_t;
// Collaboratively load q_t into shared memory
#pragma unroll
for (int j = tid; j < BK; j += BV) {
q_buf[j] = q_bh[t * BK + j];
}
__syncthreads();
// Fused pass 2: update state + compute output
float y_t = 0.0f;
#pragma unroll
for (int j = 0; j < BK; j++) {
s[j] = __fmaf_rn(k_buf[j], delta, s[j]);
y_t = __fmaf_rn(s[j], q_buf[j], y_t);
}
out_bh[t * v_dim + v_idx] = y_t;
__syncthreads();
}
// Write state back
#pragma unroll
for (int j = 0; j < BK; j++) {
state_bh[j * v_dim + v_idx] = s[j];
}
}
// Fallback kernel: runtime k_dim, still V-tiled for occupancy
template <int BV, int MAX_K>
__global__ void gated_delta_rule_recurrence_kernel_fallback(
const float *__restrict__ q, const float *__restrict__ k,
const float *__restrict__ v, const float *__restrict__ g,
const float *__restrict__ beta, float *__restrict__ state,
float *__restrict__ output, int seq_len, int k_dim, int v_dim) {
const int v_tile = blockIdx.x;
const int bh = blockIdx.y;
const int tid = threadIdx.x;
const int v_idx = v_tile * BV + tid;
if (v_idx >= v_dim)
return;
const float *q_bh = q + bh * seq_len * k_dim;
const float *k_bh = k + bh * seq_len * k_dim;
const float *v_bh = v + bh * seq_len * v_dim;
const float *g_bh = g + bh * seq_len;
const float *beta_bh = beta + bh * seq_len;
float *state_bh = state + bh * k_dim * v_dim;
float *out_bh = output + bh * seq_len * v_dim;
extern __shared__ float shared[];
float *k_buf = shared;
float *q_buf = shared + k_dim;
float s[MAX_K];
for (int j = 0; j < k_dim; j++) {
s[j] = state_bh[j * v_dim + v_idx];
}
for (int t = 0; t < seq_len; t++) {
for (int j = tid; j < k_dim; j += BV) {
k_buf[j] = k_bh[t * k_dim + j];
}
__syncthreads();
float decay = expf(g_bh[t]);
float beta_t = beta_bh[t];
float v_t = v_bh[t * v_dim + v_idx];
float kv_mem = 0.0f;
for (int j = 0; j < k_dim; j++) {
s[j] *= decay;
kv_mem = __fmaf_rn(s[j], k_buf[j], kv_mem);
}
float delta = (v_t - kv_mem) * beta_t;
for (int j = tid; j < k_dim; j += BV) {
q_buf[j] = q_bh[t * k_dim + j];
}
__syncthreads();
float y_t = 0.0f;
for (int j = 0; j < k_dim; j++) {
s[j] = __fmaf_rn(k_buf[j], delta, s[j]);
y_t = __fmaf_rn(s[j], q_buf[j], y_t);
}
out_bh[t * v_dim + v_idx] = y_t;
__syncthreads();
}
for (int j = 0; j < k_dim; j++) {
state_bh[j * v_dim + v_idx] = s[j];
}
}
extern "C" void gated_delta_rule_recurrence(const float *q, const float *k,
const float *v, const float *g,
const float *beta, float *state,
float *output, int bh, int seq_len,
int k_dim, int v_dim,
int64_t stream) {
const cudaStream_t custream = (cudaStream_t)stream;
if (k_dim == 128) {
// Fast path for Qwen3-Next (k_dim=128)
constexpr int BK = 128;
constexpr int BV = 64;
dim3 grid((v_dim + BV - 1) / BV, bh);
dim3 block(BV);
gated_delta_rule_recurrence_kernel_tiled<BK, BV>
<<<grid, block, 0, custream>>>(q, k, v, g, beta, state, output, seq_len,
v_dim);
} else if (k_dim == 64) {
// Fast path for models with k_dim=64
constexpr int BK = 64;
constexpr int BV = 64;
dim3 grid((v_dim + BV - 1) / BV, bh);
dim3 block(BV);
gated_delta_rule_recurrence_kernel_tiled<BK, BV>
<<<grid, block, 0, custream>>>(q, k, v, g, beta, state, output, seq_len,
v_dim);
} else {
// Fallback for other k_dim values (runtime loop, still V-tiled)
constexpr int BV = 64;
constexpr int MAX_K = 256;
dim3 grid((v_dim + BV - 1) / BV, bh);
dim3 block(BV);
size_t smem = 2 * k_dim * sizeof(float);
gated_delta_rule_recurrence_kernel_fallback<BV, MAX_K>
<<<grid, block, smem, custream>>>(q, k, v, g, beta, state, output,
seq_len, k_dim, v_dim);
}
}
// ============================================================================
// Kernel 1b: chunked_gated_delta_rule_recurrence (prefill optimization)
//
// Processes prefill tokens in BT-token chunks instead of one at a time.
// Within each chunk: parallel prefix sum of g, cooperative kk_dot computation,
// forward substitution (triangular solve), output computation, and state
// update.
//
// Same thread model as Kernel 1: one block per (v_tile, batch*head),
// one thread per V-column. Each thread owns BK registers of state.
//
// Shared memory holds:
// k_chunk[BT * BK] -- key vectors for current chunk
// kk_dot[BT * BT] -- dot(k[i], k[j]) lower-triangular matrix
// gcum[BT] -- cumulative sum of g within chunk
// beta_s[BT] -- beta values for chunk
// q_buf[BK] -- q vector (loaded one row at a time)
//
// q,k: [BH, S, K] v: [BH, S, V] g,beta: [BH, S]
// state: [BH, K, V] (in/out) output: [BH, S, V]
// ============================================================================
template <int BT, int BK, int BV>
__global__ void
chunked_gated_delta_rule_kernel(const float *__restrict__ q, // [BH, S, K]
const float *__restrict__ k, // [BH, S, K]
const float *__restrict__ v, // [BH, S, V]
const float *__restrict__ g, // [BH, S]
const float *__restrict__ beta, // [BH, S]
float *__restrict__ state, // [BH, K, V]
float *__restrict__ output, // [BH, S, V]
int seq_len, int v_dim) {
const int v_tile = blockIdx.x;
const int bh = blockIdx.y;
const int tid = threadIdx.x;
const int v_idx = v_tile * BV + tid;
if (v_idx >= v_dim)
return;
const int num_chunks = (seq_len + BT - 1) / BT;
// Pointers for this (batch, head)
const float *q_bh = q + bh * seq_len * BK;
const float *k_bh = k + bh * seq_len * BK;
const float *v_bh = v + bh * seq_len * v_dim;
const float *g_bh = g + bh * seq_len;
const float *beta_bh = beta + bh * seq_len;
float *state_bh = state + bh * BK * v_dim;
float *out_bh = output + bh * seq_len * v_dim;
// Dynamic shared memory layout
extern __shared__ float smem[];
float *k_chunk = smem; // [BT * BK]
float *kk_dot = smem + BT * BK; // [BT * BT]
float *gcum = smem + BT * BK + BT * BT; // [BT]
float *beta_s = gcum + BT; // [BT]
float *q_buf = beta_s + BT; // [BK]
// Load state column into registers
float s[BK];
#pragma unroll
for (int j = 0; j < BK; j++) {
s[j] = state_bh[j * v_dim + v_idx];
}
// Per-thread register array for corrected deltas
float delta[BT];
for (int c = 0; c < num_chunks; c++) {
const int chunk_start = c * BT;
const int chunk_len = min(BT, seq_len - chunk_start);
// === Phase 1: Cooperative load of k, beta, g into shared memory ===
for (int t = 0; t < chunk_len; t++) {
for (int j = tid; j < BK; j += BV) {
k_chunk[t * BK + j] = k_bh[(chunk_start + t) * BK + j];
}
}
if (tid < chunk_len) {
beta_s[tid] = beta_bh[chunk_start + tid];
gcum[tid] = g_bh[chunk_start + tid];
}
__syncthreads();
// === Phase 1b: Parallel prefix sum of g (Hillis-Steele) ===
for (int stride = 1; stride < BT; stride <<= 1) {
float prev = 0.0f;
if (tid < chunk_len && (int)tid >= stride)
prev = gcum[tid - stride];
__syncthreads();
if (tid < chunk_len && (int)tid >= stride)
gcum[tid] += prev;
__syncthreads();
}
// === Phase 2: Compute kk_dot[i][j] = dot(k[i], k[j]) for j < i ===
// Only lower-triangular entries needed (strictly lower)
for (int idx = tid; idx < chunk_len * chunk_len; idx += BV) {
int i = idx / chunk_len;
int j = idx % chunk_len;
if (j < i) {
float dot = 0.0f;
for (int d = 0; d < BK; d++) {
dot = __fmaf_rn(k_chunk[i * BK + d], k_chunk[j * BK + d], dot);
}
kk_dot[i * BT + j] = dot;
}
}
__syncthreads();
// === Phase 3: Forward substitution (per V-column, in registers) ===
// Computes corrected delta values via triangular solve
for (int i = 0; i < chunk_len; i++) {
float v_i = v_bh[(chunk_start + i) * v_dim + v_idx];
float decay_i = expf(gcum[i]);
float beta_i = beta_s[i];
// Inter-chunk contribution: state @ k[i] with decay
float kv_mem = 0.0f;
#pragma unroll
for (int d = 0; d < BK; d++) {
kv_mem = __fmaf_rn(s[d] * decay_i, k_chunk[i * BK + d], kv_mem);
}
float rhs = beta_i * (v_i - kv_mem);
// Subtract lower-triangular contributions (intra-chunk)
for (int j = 0; j < i; j++) {
float a_ij = beta_i * kk_dot[i * BT + j] * expf(gcum[i] - gcum[j]);
rhs -= a_ij * delta[j];
}
delta[i] = rhs;
}
// === Phase 4: Output computation (per V-column) ===
for (int i = 0; i < chunk_len; i++) {
// Cooperatively load q[i] into shared
for (int j = tid; j < BK; j += BV) {
q_buf[j] = q_bh[(chunk_start + i) * BK + j];
}
__syncthreads();
float decay_i = expf(gcum[i]);
// Inter-chunk contribution: q[i] @ (state * decay)
float o_val = 0.0f;
#pragma unroll
for (int d = 0; d < BK; d++) {
o_val = __fmaf_rn(q_buf[d], s[d] * decay_i, o_val);
}
// Intra-chunk contribution: sum_{j<=i} dot(q[i], k[j]) * delta[j] *
// exp(gcum[i] - gcum[j])
for (int j = 0; j <= i; j++) {
float qk_dot = 0.0f;
for (int d = 0; d < BK; d++) {
qk_dot = __fmaf_rn(q_buf[d], k_chunk[j * BK + d], qk_dot);
}
o_val += qk_dot * delta[j] * expf(gcum[i] - gcum[j]);
}
out_bh[(chunk_start + i) * v_dim + v_idx] = o_val;
__syncthreads();
}
// === Phase 5: State update for next chunk ===
float g_total = gcum[chunk_len - 1];
#pragma unroll
for (int d = 0; d < BK; d++) {
float s_new = s[d] * expf(g_total);
for (int t = 0; t < chunk_len; t++) {
s_new += k_chunk[t * BK + d] * delta[t] * expf(g_total - gcum[t]);
}
s[d] = s_new;
}
__syncthreads();
}
// Write final state back
#pragma unroll
for (int j = 0; j < BK; j++) {
state_bh[j * v_dim + v_idx] = s[j];
}
}
extern "C" void chunked_gated_delta_rule_recurrence(
const float *q, const float *k, const float *v, const float *g,
const float *beta, float *state, float *output, int bh, int seq_len,
int k_dim, int v_dim, int64_t stream) {
const cudaStream_t custream = (cudaStream_t)stream;
if (k_dim == 128) {
constexpr int BT = 64;
constexpr int BK = 128;
constexpr int BV = 64;
// Shared memory: BT*BK + BT*BT + BT + BT + BK floats
size_t smem = (BT * BK + BT * BT + 2 * BT + BK) * sizeof(float);
// Request extended shared memory
auto kernel = chunked_gated_delta_rule_kernel<BT, BK, BV>;
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
smem);
dim3 grid((v_dim + BV - 1) / BV, bh);
dim3 block(BV);
kernel<<<grid, block, smem, custream>>>(q, k, v, g, beta, state, output,
seq_len, v_dim);
} else if (k_dim == 64) {
constexpr int BT = 64;
constexpr int BK = 64;
constexpr int BV = 64;
size_t smem = (BT * BK + BT * BT + 2 * BT + BK) * sizeof(float);
auto kernel = chunked_gated_delta_rule_kernel<BT, BK, BV>;
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
smem);
dim3 grid((v_dim + BV - 1) / BV, bh);
dim3 block(BV);
kernel<<<grid, block, smem, custream>>>(q, k, v, g, beta, state, output,
seq_len, v_dim);
} else {
// Fallback: use the sequential kernel for unsupported k_dim
gated_delta_rule_recurrence(q, k, v, g, beta, state, output, bh, seq_len,
k_dim, v_dim, stream);
}
}
// ============================================================================
// Kernel 2a: causal_conv1d_update (decode path, single step)
//
// Each thread handles one channel: shift conv_state left by 1,
// insert new value, dot product with weight, apply SiLU.
//
// x: [B, conv_dim, 1] weight: [conv_dim, kernel_size]
// conv_state: [B, conv_dim, kernel_size] (in/out)
// output: [B, conv_dim, 1]
// ============================================================================
template <typename T>
__global__ void causal_conv1d_update_kernel(
const T *__restrict__ x, // [B, conv_dim, 1]
const T *__restrict__ weight, // [conv_dim, kernel_size]
T *__restrict__ conv_state, // [B, conv_dim, kernel_size]
T *__restrict__ output, // [B, conv_dim, 1]
int batch_size, int conv_dim, int kernel_size) {
const int ch = blockIdx.x * blockDim.x + threadIdx.x;
const int b = blockIdx.y;
if (ch >= conv_dim || b >= batch_size)
return;
// Pointer to this batch/channel's conv state
T *cs = conv_state + (b * conv_dim + ch) * kernel_size;
const T *w = weight + ch * kernel_size;
// Shift state left by 1
for (int i = 0; i < kernel_size - 1; i++) {
cs[i] = cs[i + 1];
}
// Insert new value
cs[kernel_size - 1] = x[b * conv_dim + ch];
// Dot product with weight
float acc = 0.0f;
for (int i = 0; i < kernel_size; i++) {
acc += (float)cs[i] * (float)w[i];
}
// SiLU activation: x * sigmoid(x)
float sig = 1.0f / (1.0f + expf(-acc));
float result = acc * sig;
output[b * conv_dim + ch] = (T)result;
}
extern "C" void causal_conv1d_update(const void *x, const void *weight,
void *conv_state, void *output,
int batch_size, int conv_dim,
int kernel_size, int dtype,
int64_t stream) {
const cudaStream_t custream = (cudaStream_t)stream;
dim3 block(256);
dim3 grid((conv_dim + 255) / 256, batch_size);
if (dtype == 0) {
// f16
causal_conv1d_update_kernel<__half><<<grid, block, 0, custream>>>(
(const __half *)x, (const __half *)weight, (__half *)conv_state,
(__half *)output, batch_size, conv_dim, kernel_size);
} else {
// bf16
causal_conv1d_update_kernel<__nv_bfloat16><<<grid, block, 0, custream>>>(
(const __nv_bfloat16 *)x, (const __nv_bfloat16 *)weight,
(__nv_bfloat16 *)conv_state, (__nv_bfloat16 *)output, batch_size,
conv_dim, kernel_size);
}
}
// ============================================================================
// Kernel 2b: causal_conv1d_full (prefill path)
//
// Each thread handles one (channel, position): causal window with
// zero-padding, dot product with weight, SiLU.
// A second pass writes the conv_state from the last kernel_size positions.
//
// x: [B, conv_dim, S] weight: [conv_dim, kernel_size]
// conv_state_out: [B, conv_dim, kernel_size] output: [B, conv_dim, S]
// ============================================================================
template <typename T>
__global__ void causal_conv1d_full_kernel(
const T *__restrict__ x, // [B, conv_dim, S]
const T *__restrict__ weight, // [conv_dim, kernel_size]
T *__restrict__ output, // [B, conv_dim, S]
int batch_size, int conv_dim, int seq_len, int kernel_size) {
const int ch = blockIdx.x * blockDim.x + threadIdx.x;
const int pos = blockIdx.y;
const int b = blockIdx.z;
if (ch >= conv_dim || pos >= seq_len || b >= batch_size)
return;
const T *x_bch = x + (b * conv_dim + ch) * seq_len;
const T *w = weight + ch * kernel_size;
// Causal convolution: sum over kernel_size window ending at pos
float acc = 0.0f;
for (int i = 0; i < kernel_size; i++) {
int src_pos = pos - (kernel_size - 1) + i;
float x_val = (src_pos >= 0) ? (float)x_bch[src_pos] : 0.0f;
acc += x_val * (float)w[i];
}
// SiLU
float sig = 1.0f / (1.0f + expf(-acc));
float result = acc * sig;
output[(b * conv_dim + ch) * seq_len + pos] = (T)result;
}
template <typename T>
__global__ void save_conv_state_kernel(
const T *__restrict__ x, // [B, conv_dim, S]
T *__restrict__ conv_state_out, // [B, conv_dim, kernel_size]
int batch_size, int conv_dim, int seq_len, int kernel_size) {
const int ch = blockIdx.x * blockDim.x + threadIdx.x;
const int b = blockIdx.y;
if (ch >= conv_dim || b >= batch_size)
return;
const T *x_bch = x + (b * conv_dim + ch) * seq_len;
T *cs = conv_state_out + (b * conv_dim + ch) * kernel_size;
// Save last kernel_size positions (zero-pad if seq_len < kernel_size)
int pad = kernel_size - seq_len;
for (int i = 0; i < kernel_size; i++) {
if (i < pad) {
cs[i] = (T)0.0f;
} else {
cs[i] = x_bch[seq_len - kernel_size + i];
}
}
}
extern "C" void causal_conv1d_full(const void *x, const void *weight,
void *conv_state_out, void *output,
int batch_size, int conv_dim, int seq_len,
int kernel_size, int dtype, int64_t stream) {
const cudaStream_t custream = (cudaStream_t)stream;
// Main convolution kernel
dim3 block(256);
dim3 grid((conv_dim + 255) / 256, seq_len, batch_size);
if (dtype == 0) {
causal_conv1d_full_kernel<__half><<<grid, block, 0, custream>>>(
(const __half *)x, (const __half *)weight, (__half *)output, batch_size,
conv_dim, seq_len, kernel_size);
// Save conv state
dim3 grid2((conv_dim + 255) / 256, batch_size);
save_conv_state_kernel<__half><<<grid2, block, 0, custream>>>(
(const __half *)x, (__half *)conv_state_out, batch_size, conv_dim,
seq_len, kernel_size);
} else {
causal_conv1d_full_kernel<__nv_bfloat16><<<grid, block, 0, custream>>>(
(const __nv_bfloat16 *)x, (const __nv_bfloat16 *)weight,
(__nv_bfloat16 *)output, batch_size, conv_dim, seq_len, kernel_size);
dim3 grid2((conv_dim + 255) / 256, batch_size);
save_conv_state_kernel<__nv_bfloat16><<<grid2, block, 0, custream>>>(
(const __nv_bfloat16 *)x, (__nv_bfloat16 *)conv_state_out, batch_size,
conv_dim, seq_len, kernel_size);
}
}
// ============================================================================
// Kernel 3: fused_gdn_gating
//
// Fuses: beta = sigmoid(b), g = -exp(a_log) * softplus(a + dt_bias)
// a_log and dt_bias are per-head (broadcast over batch*seq).
//
// b, a: [total] a_log, dt_bias: [num_heads]
// beta_out, g_out: [total]
// ============================================================================
template <typename T>
__global__ void
fused_gdn_gating_kernel(const T *__restrict__ b, // [total]
const T *__restrict__ a, // [total]
const float *__restrict__ a_log, // [num_heads]
const float *__restrict__ dt_bias, // [num_heads]
T *__restrict__ beta_out, // [total]
T *__restrict__ g_out, // [total]
int total_elements, int num_heads) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= total_elements)
return;
// Head index: elements are laid out as [..., num_heads]
int head_idx = idx % num_heads;
// beta = sigmoid(b)
float b_val = (float)b[idx];
float beta = 1.0f / (1.0f + expf(-b_val));
// g = -exp(a_log) * softplus(a + dt_bias)
float a_val = (float)a[idx];
float a_log_val = a_log[head_idx];
float dt_bias_val = dt_bias[head_idx];
float sp_input = a_val + dt_bias_val;
float softplus_val = logf(1.0f + expf(sp_input));
float g_val = -expf(a_log_val) * softplus_val;
beta_out[idx] = (T)beta;
g_out[idx] = (T)g_val;
}
extern "C" void fused_gdn_gating(const void *b, const void *a,
const float *a_log, const float *dt_bias,
void *beta_out, void *g_out,
int total_elements, int num_heads, int dtype,
int64_t stream) {
const cudaStream_t custream = (cudaStream_t)stream;
dim3 block(256);
dim3 grid((total_elements + 255) / 256);
if (dtype == 0) {
fused_gdn_gating_kernel<__half><<<grid, block, 0, custream>>>(
(const __half *)b, (const __half *)a, a_log, dt_bias,
(__half *)beta_out, (__half *)g_out, total_elements, num_heads);
} else {
fused_gdn_gating_kernel<__nv_bfloat16><<<grid, block, 0, custream>>>(
(const __nv_bfloat16 *)b, (const __nv_bfloat16 *)a, a_log, dt_bias,
(__nv_bfloat16 *)beta_out, (__nv_bfloat16 *)g_out, total_elements,
num_heads);
}
}

View File

@@ -1,486 +0,0 @@
//! Rust wrappers around the Gated DeltaNet CUDA kernels in `gdn.cu`.
//!
//! Ported verbatim from `EricLBuehler/mistral.rs` under MIT terms.
//! Upstream path: `mistralrs-core/src/cuda/gdn.rs`. The only edits in
//! this file are this header comment — the FFI path module name is
//! `crate::cuda::ffi`, identical to upstream's layout.
#![allow(clippy::cast_possible_truncation)]
use candle_core::{Result, Tensor};
#[cfg(feature = "cuda")]
use candle_core::DType;
/// CUDA-accelerated gated delta rule recurrence.
///
/// Inputs (all contiguous, f32):
/// q, k: [BH, S, K] v: [BH, S, V] g, beta: [BH, S]
/// state: [BH, K, V] (mutated in place)
///
/// Returns: output [BH, S, V]
#[cfg(feature = "cuda")]
pub fn gated_delta_rule_recurrence_cuda(
q: &Tensor,
k: &Tensor,
v: &Tensor,
g: &Tensor,
beta: &Tensor,
state: &mut Tensor,
) -> Result<Tensor> {
use candle::cuda_backend::cudarc::driver::DevicePtr;
use candle_core as candle;
let (bh, seq_len, k_dim) = q.dims3()?;
let v_dim = v.dim(2)?;
let dev = q.device().as_cuda_device()?;
let (q_s, q_l) = q.storage_and_layout();
let q_s = match &*q_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("q must be a cuda tensor"),
};
let q_offset = q_l.start_offset();
let (k_s, k_l) = k.storage_and_layout();
let k_s = match &*k_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("k must be a cuda tensor"),
};
let k_offset = k_l.start_offset();
let (v_s, v_l) = v.storage_and_layout();
let v_s = match &*v_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("v must be a cuda tensor"),
};
let v_offset = v_l.start_offset();
let (g_s, g_l) = g.storage_and_layout();
let g_s = match &*g_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("g must be a cuda tensor"),
};
let g_offset = g_l.start_offset();
let (beta_s, beta_l) = beta.storage_and_layout();
let beta_s = match &*beta_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("beta must be a cuda tensor"),
};
let beta_offset = beta_l.start_offset();
let (state_s, state_l) = state.storage_and_layout();
let state_s = match &*state_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("state must be a cuda tensor"),
};
let state_offset = state_l.start_offset();
let output_buf = unsafe { dev.alloc::<f32>(bh * seq_len * v_dim) }?;
let stream = dev.cuda_stream().cu_stream() as i64;
unsafe {
crate::cuda::ffi::gated_delta_rule_recurrence(
q_s.slice(q_offset..).device_ptr(q_s.stream()).0 as *const f32,
k_s.slice(k_offset..).device_ptr(k_s.stream()).0 as *const f32,
v_s.slice(v_offset..).device_ptr(v_s.stream()).0 as *const f32,
g_s.slice(g_offset..).device_ptr(g_s.stream()).0 as *const f32,
beta_s.slice(beta_offset..).device_ptr(beta_s.stream()).0 as *const f32,
state_s.slice(state_offset..).device_ptr(state_s.stream()).0 as *mut f32,
output_buf.device_ptr(output_buf.stream()).0 as *mut f32,
bh as i32,
seq_len as i32,
k_dim as i32,
v_dim as i32,
stream,
);
}
// The kernel wrote state in-place via the raw pointer; rewrap
// (state tensor's underlying CudaSlice was modified directly)
let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone());
Ok(Tensor::from((
candle::Storage::Cuda(output_storage),
(bh, seq_len, v_dim),
)))
}
#[cfg(not(feature = "cuda"))]
#[allow(unused)]
pub fn gated_delta_rule_recurrence_cuda(
_q: &Tensor,
_k: &Tensor,
_v: &Tensor,
_g: &Tensor,
_beta: &Tensor,
_state: &mut Tensor,
) -> Result<Tensor> {
candle_core::bail!("gated_delta_rule_recurrence_cuda requires the cuda feature")
}
/// CUDA-accelerated chunked gated delta rule recurrence (prefill optimization).
///
/// Processes prefill tokens in 64-token chunks instead of one at a time.
/// Same interface as `gated_delta_rule_recurrence_cuda`.
///
/// Inputs (all contiguous, f32):
/// q, k: [BH, S, K] v: [BH, S, V] g, beta: [BH, S]
/// state: [BH, K, V] (mutated in place)
///
/// Returns: output [BH, S, V]
#[cfg(feature = "cuda")]
pub fn chunked_gated_delta_rule_recurrence_cuda(
q: &Tensor,
k: &Tensor,
v: &Tensor,
g: &Tensor,
beta: &Tensor,
state: &mut Tensor,
) -> Result<Tensor> {
use candle::cuda_backend::cudarc::driver::DevicePtr;
use candle_core as candle;
let (bh, seq_len, k_dim) = q.dims3()?;
let v_dim = v.dim(2)?;
let dev = q.device().as_cuda_device()?;
let (q_s, q_l) = q.storage_and_layout();
let q_s = match &*q_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("q must be a cuda tensor"),
};
let q_offset = q_l.start_offset();
let (k_s, k_l) = k.storage_and_layout();
let k_s = match &*k_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("k must be a cuda tensor"),
};
let k_offset = k_l.start_offset();
let (v_s, v_l) = v.storage_and_layout();
let v_s = match &*v_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("v must be a cuda tensor"),
};
let v_offset = v_l.start_offset();
let (g_s, g_l) = g.storage_and_layout();
let g_s = match &*g_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("g must be a cuda tensor"),
};
let g_offset = g_l.start_offset();
let (beta_s, beta_l) = beta.storage_and_layout();
let beta_s = match &*beta_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("beta must be a cuda tensor"),
};
let beta_offset = beta_l.start_offset();
let (state_s, state_l) = state.storage_and_layout();
let state_s = match &*state_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("state must be a cuda tensor"),
};
let state_offset = state_l.start_offset();
let output_buf = unsafe { dev.alloc::<f32>(bh * seq_len * v_dim) }?;
let stream = dev.cuda_stream().cu_stream() as i64;
unsafe {
crate::cuda::ffi::chunked_gated_delta_rule_recurrence(
q_s.slice(q_offset..).device_ptr(q_s.stream()).0 as *const f32,
k_s.slice(k_offset..).device_ptr(k_s.stream()).0 as *const f32,
v_s.slice(v_offset..).device_ptr(v_s.stream()).0 as *const f32,
g_s.slice(g_offset..).device_ptr(g_s.stream()).0 as *const f32,
beta_s.slice(beta_offset..).device_ptr(beta_s.stream()).0 as *const f32,
state_s.slice(state_offset..).device_ptr(state_s.stream()).0 as *mut f32,
output_buf.device_ptr(output_buf.stream()).0 as *mut f32,
bh as i32,
seq_len as i32,
k_dim as i32,
v_dim as i32,
stream,
);
}
let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone());
Ok(Tensor::from((
candle::Storage::Cuda(output_storage),
(bh, seq_len, v_dim),
)))
}
#[cfg(not(feature = "cuda"))]
#[allow(unused)]
pub fn chunked_gated_delta_rule_recurrence_cuda(
_q: &Tensor,
_k: &Tensor,
_v: &Tensor,
_g: &Tensor,
_beta: &Tensor,
_state: &mut Tensor,
) -> Result<Tensor> {
candle_core::bail!("chunked_gated_delta_rule_recurrence_cuda requires the cuda feature")
}
/// CUDA-accelerated causal conv1d (both update and full paths).
///
/// For update (is_update=true):
/// x: [B, conv_dim, 1] weight: [conv_dim, kernel_size]
/// conv_state: [B, conv_dim, kernel_size] (mutated in place for update)
/// Returns: (output [B, conv_dim, 1], updated conv_state)
///
/// For full (is_update=false):
/// x: [B, conv_dim, S] weight: [conv_dim, kernel_size]
/// Returns: (output [B, conv_dim, S], new conv_state [B, conv_dim, kernel_size])
#[cfg(feature = "cuda")]
pub fn causal_conv1d_cuda(
x: &Tensor,
weight: &Tensor,
conv_state: &Tensor,
kernel_size: usize,
is_update: bool,
) -> Result<(Tensor, Tensor)> {
use candle::cuda_backend::cudarc::driver::DevicePtr;
use candle_core as candle;
use core::ffi::c_void;
fn cuda_fwd<
T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
>(
x: &Tensor,
weight: &Tensor,
conv_state: &Tensor,
kernel_size: usize,
is_update: bool,
dtype_code: i32,
) -> Result<(Tensor, Tensor)> {
let dev = x.device().as_cuda_device()?;
let (batch_size, conv_dim, seq_len) = x.dims3()?;
let (x_s, x_l) = x.storage_and_layout();
let x_s = match &*x_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
_ => candle::bail!("x must be a cuda tensor"),
};
let x_offset = x_l.start_offset();
let (w_s, w_l) = weight.storage_and_layout();
let w_s = match &*w_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
_ => candle::bail!("weight must be a cuda tensor"),
};
let w_offset = w_l.start_offset();
let stream = dev.cuda_stream().cu_stream() as i64;
if is_update {
// Clone conv_state so the kernel can mutate it in place
let conv_state_new = conv_state.clone();
let output_buf = unsafe { dev.alloc::<T>(batch_size * conv_dim) }?;
// Scope the borrow of conv_state_new so we can move it later
{
let (cs_s, cs_l) = conv_state_new.storage_and_layout();
let cs_s = match &*cs_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
_ => candle::bail!("conv_state must be a cuda tensor"),
};
let cs_offset = cs_l.start_offset();
unsafe {
crate::cuda::ffi::causal_conv1d_update(
x_s.slice(x_offset..).device_ptr(x_s.stream()).0 as *const c_void,
w_s.slice(w_offset..).device_ptr(w_s.stream()).0 as *const c_void,
cs_s.slice(cs_offset..).device_ptr(cs_s.stream()).0 as *mut c_void,
output_buf.device_ptr(output_buf.stream()).0 as *mut c_void,
batch_size as i32,
conv_dim as i32,
kernel_size as i32,
dtype_code,
stream,
);
}
}
let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone());
let output = Tensor::from((
candle::Storage::Cuda(output_storage),
(batch_size, conv_dim, 1usize),
));
Ok((output, conv_state_new))
} else {
// Full path: allocate new conv_state and output
let output_buf = unsafe { dev.alloc::<T>(batch_size * conv_dim * seq_len) }?;
let cs_buf = unsafe { dev.alloc::<T>(batch_size * conv_dim * kernel_size) }?;
unsafe {
crate::cuda::ffi::causal_conv1d_full(
x_s.slice(x_offset..).device_ptr(x_s.stream()).0 as *const c_void,
w_s.slice(w_offset..).device_ptr(w_s.stream()).0 as *const c_void,
cs_buf.device_ptr(cs_buf.stream()).0 as *mut c_void,
output_buf.device_ptr(output_buf.stream()).0 as *mut c_void,
batch_size as i32,
conv_dim as i32,
seq_len as i32,
kernel_size as i32,
dtype_code,
stream,
);
}
let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone());
let output = Tensor::from((
candle::Storage::Cuda(output_storage),
(batch_size, conv_dim, seq_len),
));
let cs_storage = candle::CudaStorage::wrap_cuda_slice(cs_buf, dev.clone());
let new_conv_state = Tensor::from((
candle::Storage::Cuda(cs_storage),
(batch_size, conv_dim, kernel_size),
));
Ok((output, new_conv_state))
}
}
match x.dtype() {
DType::F16 => cuda_fwd::<half::f16>(x, weight, conv_state, kernel_size, is_update, 0),
DType::BF16 => cuda_fwd::<half::bf16>(x, weight, conv_state, kernel_size, is_update, 1),
other => candle_core::bail!("causal_conv1d_cuda only supports f16/bf16, got {:?}", other),
}
}
#[cfg(not(feature = "cuda"))]
#[allow(unused)]
pub fn causal_conv1d_cuda(
_x: &Tensor,
_weight: &Tensor,
_conv_state: &Tensor,
_kernel_size: usize,
_is_update: bool,
) -> Result<(Tensor, Tensor)> {
candle_core::bail!("causal_conv1d_cuda requires the cuda feature")
}
/// CUDA-accelerated fused GDN gating computation.
///
/// Computes: beta = sigmoid(b), g = -exp(a_log) * softplus(a + dt_bias)
///
/// b, a: [total_elements] in f16/bf16
/// a_log, dt_bias: [num_heads] in f32
///
/// Returns: (beta, g) in original dtype
#[cfg(feature = "cuda")]
pub fn fused_gdn_gating_cuda(
b: &Tensor,
a: &Tensor,
a_log: &Tensor,
dt_bias: &Tensor,
) -> Result<(Tensor, Tensor)> {
use candle::cuda_backend::cudarc::driver::DevicePtr;
use candle_core as candle;
use core::ffi::c_void;
fn cuda_fwd<
T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
>(
b: &Tensor,
a: &Tensor,
a_log: &Tensor,
dt_bias: &Tensor,
dtype_code: i32,
) -> Result<(Tensor, Tensor)> {
let total_elements = b.elem_count();
let num_heads = a_log.elem_count();
let shape = b.shape().clone();
let dev = b.device().as_cuda_device()?;
let (b_s, b_l) = b.storage_and_layout();
let b_s = match &*b_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
_ => candle::bail!("b must be a cuda tensor"),
};
let b_offset = b_l.start_offset();
let (a_s, a_l) = a.storage_and_layout();
let a_s = match &*a_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
_ => candle::bail!("a must be a cuda tensor"),
};
let a_offset = a_l.start_offset();
let (alog_s, alog_l) = a_log.storage_and_layout();
let alog_s = match &*alog_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("a_log must be a cuda tensor"),
};
let alog_offset = alog_l.start_offset();
let (dtb_s, dtb_l) = dt_bias.storage_and_layout();
let dtb_s = match &*dtb_s {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("dt_bias must be a cuda tensor"),
};
let dtb_offset = dtb_l.start_offset();
let beta_buf = unsafe { dev.alloc::<T>(total_elements) }?;
let g_buf = unsafe { dev.alloc::<T>(total_elements) }?;
let stream = dev.cuda_stream().cu_stream() as i64;
unsafe {
crate::cuda::ffi::fused_gdn_gating(
b_s.slice(b_offset..).device_ptr(b_s.stream()).0 as *const c_void,
a_s.slice(a_offset..).device_ptr(a_s.stream()).0 as *const c_void,
alog_s.slice(alog_offset..).device_ptr(alog_s.stream()).0 as *const f32,
dtb_s.slice(dtb_offset..).device_ptr(dtb_s.stream()).0 as *const f32,
beta_buf.device_ptr(beta_buf.stream()).0 as *mut c_void,
g_buf.device_ptr(g_buf.stream()).0 as *mut c_void,
total_elements as i32,
num_heads as i32,
dtype_code,
stream,
);
}
let beta_storage = candle::CudaStorage::wrap_cuda_slice(beta_buf, dev.clone());
let beta = Tensor::from((candle::Storage::Cuda(beta_storage), shape.clone()));
let g_storage = candle::CudaStorage::wrap_cuda_slice(g_buf, dev.clone());
let g = Tensor::from((candle::Storage::Cuda(g_storage), shape));
Ok((beta, g))
}
match b.dtype() {
DType::F16 => cuda_fwd::<half::f16>(b, a, a_log, dt_bias, 0),
DType::BF16 => cuda_fwd::<half::bf16>(b, a, a_log, dt_bias, 1),
other => candle_core::bail!(
"fused_gdn_gating_cuda only supports f16/bf16, got {:?}",
other
),
}
}
#[cfg(not(feature = "cuda"))]
#[allow(unused)]
pub fn fused_gdn_gating_cuda(
_b: &Tensor,
_a: &Tensor,
_a_log: &Tensor,
_dt_bias: &Tensor,
) -> Result<(Tensor, Tensor)> {
candle_core::bail!("fused_gdn_gating_cuda requires the cuda feature")
}

View File

@@ -1,15 +0,0 @@
//! CUDA kernels and their Rust wrappers.
//!
//! Currently scoped to what we need for Qwen3-Next (`qwen3_5`)
//! inference performance — the Gated DeltaNet kernels ported from
//! `EricLBuehler/mistral.rs` (MIT). Each kernel lives in a `.cu`
//! file alongside this module; `build.rs` compiles them all into a
//! static lib via `cudaforge` and links it under the `cuda` feature.
//!
//! When we absorb more upstream kernels (MoE GEMM, top-k, Mamba SSM,
//! etc.) they land here in their own `.cu` + `.rs` pairs.
#[cfg(feature = "cuda")]
pub mod ffi;
#[cfg(feature = "cuda")]
pub mod gdn;

View File

@@ -1,23 +0,0 @@
//! Custom architecture implementations.
//!
//! When candle-transformers ships a model family unchanged
//! (`models::llama`, `models::qwen3`, `models::qwen3_moe`, etc.), the
//! handler in `harness/candle.rs` just wraps the upstream type in a
//! `ModelArch` variant.
//!
//! When candle has nothing for the architecture and we have to write
//! it from scratch — Qwen3-Next / Qwen3.6 (`qwen3_5`) being the
//! motivating example — the implementation lands here, one file per
//! architecture.
//!
//! Each architecture module is expected to expose:
//! - A `Config` type deserialised from the model's `config.json`
//! (some architectures nest the real hyperparams under `text_config`,
//! in which case the module owns the unwrapping).
//! - A `ForCausalLM` struct with `new`, `forward(&mut self, x, offset)
//! -> Result<Tensor>`, and `clear_kv_cache(&mut self)`.
//!
//! TP-aware analogues live in `harness/tp/tp_<family>.rs` and follow
//! the pattern set by `tp_qwen3.rs`.
pub mod qwen3_5;

View File

@@ -1,117 +0,0 @@
//! Qwen3-Next decoder layer.
//!
//! Standard pre-norm transformer block (LN → attention → residual →
//! LN → MLP → residual) where the attention slot dispatches on the
//! per-layer `layer_types[i]` value in the config:
//!
//! - `"full_attention"` → [`Qwen3_5Attention`] (GQA causal + output
//! gate + RoPE + KV cache).
//! - `"linear_attention"` → [`GatedDeltaNet`] (recurrent delta rule +
//! causal conv + per-head state).
//!
//! In Qwen3.6-27B every 4th layer is full_attention; the rest are
//! linear_attention. `full_attention_interval` in the config is a
//! hint; `layer_types` is authoritative.
use anyhow::Result;
use candle_core::{Module, Tensor};
use candle_nn::var_builder::ShardedVarBuilder;
use std::sync::Arc;
use super::TextConfig;
use super::full_attn::Qwen3_5Attention;
use super::linear_attn::GatedDeltaNet;
use super::mlp::Qwen3_5MLP;
use super::rmsnorm::Qwen3_5RmsNorm;
use super::rope::RotaryEmbedding;
/// One of the two attention flavours sitting in a decoder layer's
/// attention slot. Full-attention layers need the rotary table and
/// take an attention mask; linear-attention layers carry their own
/// recurrent state and ignore the mask.
enum AttentionKind {
Full(Qwen3_5Attention),
Linear(GatedDeltaNet),
}
pub struct Qwen3_5DecoderLayer {
input_layernorm: Qwen3_5RmsNorm,
post_attention_layernorm: Qwen3_5RmsNorm,
mlp: Qwen3_5MLP,
attention: AttentionKind,
}
impl Qwen3_5DecoderLayer {
pub fn load(
cfg: &TextConfig,
rotary: Arc<RotaryEmbedding>,
layer_idx: usize,
vb: &ShardedVarBuilder,
) -> Result<Self> {
let layer_type = cfg
.layer_types
.get(layer_idx)
.map(String::as_str)
.ok_or_else(|| {
anyhow::anyhow!(
"layer_types[{layer_idx}] missing (have {} entries)",
cfg.layer_types.len()
)
})?;
let attention = match layer_type {
"full_attention" => {
AttentionKind::Full(Qwen3_5Attention::load(cfg, rotary, &vb.pp("self_attn"))?)
}
"linear_attention" => {
AttentionKind::Linear(GatedDeltaNet::load(cfg, &vb.pp("linear_attn"))?)
}
other => anyhow::bail!(
"unknown layer_type '{other}' for layer {layer_idx} (expected \
'full_attention' or 'linear_attention')"
),
};
let mlp = Qwen3_5MLP::load(cfg, &vb.pp("mlp"))?;
let input_layernorm =
Qwen3_5RmsNorm::load(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?;
let post_attention_layernorm = Qwen3_5RmsNorm::load(
&vb.pp("post_attention_layernorm"),
cfg.hidden_size,
cfg.rms_norm_eps,
)?;
Ok(Self {
input_layernorm,
post_attention_layernorm,
mlp,
attention,
})
}
pub fn forward(
&mut self,
x: &Tensor,
attn_mask: Option<&Tensor>,
offset: usize,
) -> candle_core::Result<Tensor> {
let h = self.input_layernorm.forward(x)?;
let attn_out = match &mut self.attention {
AttentionKind::Full(attn) => attn.forward(&h, attn_mask, offset)?,
// Linear attention ignores attn_mask + offset; its causal
// structure is baked into the recurrent state lifecycle.
AttentionKind::Linear(net) => net.forward(&h)?,
};
let x = (x + attn_out)?;
let h2 = self.post_attention_layernorm.forward(&x)?;
let h2 = self.mlp.forward(&h2)?;
x + h2
}
pub fn clear_kv_cache(&mut self) {
match &mut self.attention {
AttentionKind::Full(attn) => attn.clear_kv_cache(),
AttentionKind::Linear(net) => net.clear_kv_cache(),
}
}
}

View File

@@ -1,179 +0,0 @@
//! Qwen3-Next's `full_attention` layer.
//!
//! Standard GQA causal attention with two Qwen3-Next-specific quirks:
//!
//! 1. **Output gate (`attn_output_gate=True`).** `q_proj` is widened
//! to `num_heads * head_dim * 2`. The second half is reshaped to
//! `(B, L, num_heads * head_dim)` and fed through a sigmoid; the
//! attention output is pointwise-multiplied by this gate before
//! `o_proj`. Effectively a per-head per-position attenuation on
//! the attention output.
//!
//! 2. **`(1 + w) * x` RmsNorm** on q and k (see `rmsnorm::Qwen3_5RmsNorm`).
//! candle_nn's RmsNorm applies `w * x`; the upstream Qwen3-Next
//! checkpoints expect the `(1 + w)` form.
//!
//! Otherwise: GQA with `num_attention_heads / num_key_value_heads`
//! repeat, q_norm + k_norm on the head dim, GLM-style rotary (see
//! `rope::RotaryEmbedding`), and the usual causal mask.
use anyhow::{Context, Result};
use candle_core::{Module, Tensor};
use candle_nn::Linear;
use candle_nn::kv_cache::ConcatKvCache;
use candle_nn::var_builder::ShardedVarBuilder;
use candle_transformers::utils::repeat_kv;
use std::sync::Arc;
use super::TextConfig;
use super::rmsnorm::Qwen3_5RmsNorm;
use super::rope::RotaryEmbedding;
pub struct Qwen3_5Attention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
q_norm: Qwen3_5RmsNorm,
k_norm: Qwen3_5RmsNorm,
num_heads: usize,
num_kv_heads: usize,
num_kv_groups: usize,
head_dim: usize,
hidden_size: usize,
rotary: Arc<RotaryEmbedding>,
kv_cache: ConcatKvCache,
}
impl Qwen3_5Attention {
pub fn load(
cfg: &TextConfig,
rotary: Arc<RotaryEmbedding>,
vb: &ShardedVarBuilder,
) -> Result<Self> {
let head_dim = cfg.head_dim;
let num_heads = cfg.num_attention_heads;
let num_kv_heads = cfg.num_key_value_heads;
if num_kv_heads == 0 || !num_heads.is_multiple_of(num_kv_heads) {
anyhow::bail!(
"num_attention_heads ({num_heads}) must be a positive multiple of \
num_key_value_heads ({num_kv_heads})"
);
}
let num_kv_groups = num_heads / num_kv_heads;
// q_proj is 2x wide: the extra `num_heads * head_dim` slice is
// the gate (see attn_output_gate notes above).
let q_proj = load_linear_no_bias(vb, "q_proj", cfg.hidden_size, num_heads * head_dim * 2)?;
let k_proj = load_linear_no_bias(vb, "k_proj", cfg.hidden_size, num_kv_heads * head_dim)?;
let v_proj = load_linear_no_bias(vb, "v_proj", cfg.hidden_size, num_kv_heads * head_dim)?;
let o_proj = load_linear_no_bias(vb, "o_proj", num_heads * head_dim, cfg.hidden_size)?;
let q_norm = Qwen3_5RmsNorm::load(&vb.pp("q_norm"), head_dim, cfg.rms_norm_eps)?;
let k_norm = Qwen3_5RmsNorm::load(&vb.pp("k_norm"), head_dim, cfg.rms_norm_eps)?;
let hidden_size = head_dim * num_heads;
let kv_cache = ConcatKvCache::new(2);
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
q_norm,
k_norm,
num_heads,
num_kv_heads,
num_kv_groups,
head_dim,
hidden_size,
rotary,
kv_cache,
})
}
pub fn forward(
&mut self,
x: &Tensor,
attn_mask: Option<&Tensor>,
offset: usize,
) -> candle_core::Result<Tensor> {
let (b, l, _) = x.dims3()?;
// 1. q_proj — widened output, split into (query, gate).
let q_raw = self
.q_proj
.forward(x)?
.reshape((b, l, self.num_heads, self.head_dim * 2))?;
let q = q_raw.narrow(3, 0, self.head_dim)?;
let gate = q_raw.narrow(3, self.head_dim, self.head_dim)?;
// Flatten the gate's head dim back into hidden_size for the
// post-attention pointwise multiply.
let gate = gate
.contiguous()?
.reshape((b, l, self.num_heads * self.head_dim))?;
// 2. q_norm + k_norm + reshape to (B, H, L, D).
let q = self.q_norm.forward(&q.contiguous()?)?;
let q = q.transpose(1, 2)?.contiguous()?; // (B, H, L, D)
let k = self
.k_proj
.forward(x)?
.reshape((b, l, self.num_kv_heads, self.head_dim))?;
let k = self.k_norm.forward(&k.contiguous()?)?;
let k = k.transpose(1, 2)?.contiguous()?;
let v = self
.v_proj
.forward(x)?
.reshape((b, l, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
// 3. RoPE on q, k.
let (q, k) = self.rotary.apply(&q, &k, offset)?;
// 4. KV cache.
let (k, v) = self.kv_cache.append(&k, &v)?;
// 5. GQA repeat (cheap shape op).
let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;
// 6. Scaled dot-product + causal mask.
let scale = 1.0_f64 / (self.head_dim as f64).sqrt();
let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
if let Some(m) = attn_mask {
scores = scores.broadcast_add(m)?;
}
let probs = candle_nn::ops::softmax_last_dim(&scores)?;
let ctx = probs.matmul(&v)?; // (B, H, L, D)
// 7. Reshape back, apply the output gate, project.
let ctx = ctx
.transpose(1, 2)?
.contiguous()?
.reshape((b, l, self.hidden_size))?;
let gate_sig = candle_nn::ops::sigmoid(&gate)?;
let gated = (ctx * gate_sig)?;
self.o_proj.forward(&gated)
}
pub fn clear_kv_cache(&mut self) {
self.kv_cache.reset();
}
}
fn load_linear_no_bias(
vb: &ShardedVarBuilder,
name: &str,
in_dim: usize,
out_dim: usize,
) -> Result<Linear> {
let weight = vb
.pp(name)
.get((out_dim, in_dim), "weight")
.with_context(|| format!("load '{}/{name}/weight'", vb.prefix()))?;
Ok(Linear::new(weight, None))
}

View File

@@ -1,793 +0,0 @@
//! Qwen3-Next's `linear_attention` layer: Gated DeltaNet.
//!
//! The recurrent linear-attention block that occupies 3 out of every 4
//! decoder layers in Qwen3.6 (`layer_types[i] == "linear_attention"`).
//! Implemented against the reference Python in
//! `huggingface/transformers/src/transformers/models/qwen3_5/modeling_qwen3_5.py`
//! (class `Qwen3_5GatedDeltaNet`).
//!
//! ## Block structure
//!
//! ```text
//! x ── in_proj_qkv ── transpose ─► (B, conv_dim, L)
//! │
//! ┌──────────────── conv_state ──┤ prepend cached state (decode)
//! ▼
//! depthwise causal Conv1d (k=4) → SiLU
//! │
//! └─ split → q (k_dim), k (k_dim), v (v_dim) ─► per-head reshape
//!
//! x ── in_proj_z ────────────────► z (gate for the output RMSNorm)
//! x ── in_proj_b ── sigmoid ─────► beta (per-head per-token update rate)
//! x ── in_proj_a ── softplus ────► g (decay; see eqn below)
//!
//! g = -exp(A_log) * softplus(a + dt_bias) # discretisation
//! beta = sigmoid(b)
//!
//! (q, k) ─── L2norm ─── delta rule loop ──── core_attn_out
//! (per-token, per-head):
//! state *= exp(g_t)
//! mem = state^T · k_t
//! delta = (v_t - mem) * beta_t
//! state += outer(k_t, delta)
//! out_t = state^T · q_t
//!
//! core_attn_out ── RMSNormGated(z) ── reshape ── out_proj ── y
//! ```
//!
//! ## State
//!
//! Two tensors persist across decode steps:
//! - `conv_state`: `(B, conv_dim, conv_kernel_size)` — left-padded
//! tail of the input to the depthwise conv, so the next causal
//! window has the right left-context.
//! - `recurrent_state`: `(B, num_v_heads, head_k_dim, head_v_dim)` —
//! the delta-rule outer-product memory.
//!
//! Both are cleared via [`GatedDeltaNet::clear_kv_cache`] at the start
//! of every new request.
//!
//! ## Performance note
//!
//! This impl is the **recurrent** delta-rule for both prefill and
//! decode — i.e. the algorithm in `torch_recurrent_gated_delta_rule`.
//! Correctness-first. The chunked algorithm (chunk_size=64) in
//! `torch_chunk_gated_delta_rule` is a perf optimisation for long
//! prefill; can be added later without changing the surface.
use anyhow::{Context, Result};
use candle_core::{Module, Tensor};
use candle_nn::Linear;
use candle_nn::var_builder::ShardedVarBuilder;
#[cfg(test)]
use super::RopeParameters;
use super::TextConfig;
use super::rmsnorm::{Qwen3_5RmsNormGated, l2norm};
/// Per-rank, per-layer state for the linear-attention block.
///
/// `conv_state` is left-padded with zeros on first use; `recurrent_state`
/// is initialised lazily to zeros once we know the batch size.
#[derive(Default)]
pub struct GatedDeltaNetState {
pub conv_state: Option<Tensor>,
pub recurrent_state: Option<Tensor>,
}
pub struct GatedDeltaNet {
// Projections.
in_proj_qkv: Linear,
in_proj_z: Linear,
in_proj_b: Linear,
in_proj_a: Linear,
out_proj: Linear,
// Depthwise causal Conv1d weight; shape (conv_dim, 1, kernel_size).
// No bias (Python sets bias=False).
conv1d_weight: Tensor,
// Per-head discretisation params.
dt_bias: Tensor,
a_log: Tensor,
// Output norm + gate.
norm: Qwen3_5RmsNormGated,
// Shape hyperparams (cached for forward).
num_v_heads: usize,
num_k_heads: usize,
head_k_dim: usize,
head_v_dim: usize,
key_dim: usize,
value_dim: usize,
conv_dim: usize,
conv_kernel_size: usize,
// Recurrent state held inline. Each request resets via
// `clear_kv_cache`; otherwise the state persists across forwards
// and the per-token offset advances naturally.
state: GatedDeltaNetState,
}
impl GatedDeltaNet {
pub fn load(cfg: &TextConfig, vb: &ShardedVarBuilder) -> Result<Self> {
let num_v_heads = cfg.linear_num_value_heads;
let num_k_heads = cfg.linear_num_key_heads;
let head_k_dim = cfg.linear_key_head_dim;
let head_v_dim = cfg.linear_value_head_dim;
let conv_kernel_size = cfg.linear_conv_kernel_dim;
if num_v_heads == 0 || num_k_heads == 0 {
anyhow::bail!(
"Qwen3-Next linear_num_*_heads must be set; got v={num_v_heads}, k={num_k_heads}"
);
}
if !num_v_heads.is_multiple_of(num_k_heads) {
anyhow::bail!(
"linear_num_value_heads ({num_v_heads}) must be a multiple of \
linear_num_key_heads ({num_k_heads}) for GQA-style head expansion"
);
}
let key_dim = head_k_dim * num_k_heads;
let value_dim = head_v_dim * num_v_heads;
let conv_dim = key_dim * 2 + value_dim;
// ----- Linear projections (all `bias=False` in the reference). -----
let in_proj_qkv = load_linear_no_bias(vb, "in_proj_qkv", cfg.hidden_size, conv_dim)?;
let in_proj_z = load_linear_no_bias(vb, "in_proj_z", cfg.hidden_size, value_dim)?;
let in_proj_b = load_linear_no_bias(vb, "in_proj_b", cfg.hidden_size, num_v_heads)?;
let in_proj_a = load_linear_no_bias(vb, "in_proj_a", cfg.hidden_size, num_v_heads)?;
let out_proj = load_linear_no_bias(vb, "out_proj", value_dim, cfg.hidden_size)?;
// ----- Conv1d weight (depthwise, bias=False). -----
let conv1d_weight = vb
.pp("conv1d")
.get((conv_dim, 1, conv_kernel_size), "weight")
.with_context(|| format!("load '{}/conv1d/weight'", vb.prefix()))?;
// ----- dt_bias + A_log: per-head 1D params. -----
let dt_bias = vb
.get(num_v_heads, "dt_bias")
.with_context(|| format!("load '{}/dt_bias'", vb.prefix()))?;
let a_log = vb
.get(num_v_heads, "A_log")
.with_context(|| format!("load '{}/A_log'", vb.prefix()))?;
// ----- Output gated RMSNorm (per-head_v_dim). -----
let norm = Qwen3_5RmsNormGated::load(&vb.pp("norm"), head_v_dim, cfg.rms_norm_eps)?;
Ok(Self {
in_proj_qkv,
in_proj_z,
in_proj_b,
in_proj_a,
out_proj,
conv1d_weight,
dt_bias,
a_log,
norm,
num_v_heads,
num_k_heads,
head_k_dim,
head_v_dim,
key_dim,
value_dim,
conv_dim,
conv_kernel_size,
state: GatedDeltaNetState::default(),
})
}
pub fn clear_kv_cache(&mut self) {
self.state = GatedDeltaNetState::default();
}
/// `x` shape: `(B, L, hidden_size)`. Returns the same shape.
pub fn forward(&mut self, x: &Tensor) -> candle_core::Result<Tensor> {
let (batch_size, seq_len, _) = x.dims3()?;
let dtype = x.dtype();
let device = x.device().clone();
// ----- Projections. -----
// mixed_qkv: (B, L, conv_dim)
let mixed_qkv = self.in_proj_qkv.forward(x)?;
// (B, conv_dim, L) for the conv1d.
let mixed_qkv_chw = mixed_qkv.transpose(1, 2)?.contiguous()?;
// z: (B, L, value_dim) → (B, L, num_v_heads, head_v_dim)
let z = self.in_proj_z.forward(x)?.reshape((
batch_size,
seq_len,
self.num_v_heads,
self.head_v_dim,
))?;
// b, a: (B, L, num_v_heads)
let b = self.in_proj_b.forward(x)?;
let a = self.in_proj_a.forward(x)?;
// ----- Depthwise causal Conv1d + SiLU (with state continuation). -----
// Dispatches to a cuda kernel that fuses conv1d + silu when
// available; falls back to candle's `conv1d` + `silu` on cpu.
let (conv_out, new_state) = run_causal_conv1d(
&mixed_qkv_chw,
&self.conv1d_weight,
self.state.conv_state.take(),
batch_size,
self.conv_dim,
seq_len,
self.conv_kernel_size,
)?;
self.state.conv_state = Some(new_state);
// Back to (B, L, conv_dim).
let mixed_qkv = conv_out.transpose(1, 2)?.contiguous()?;
// ----- Split into q, k, v. -----
let q = mixed_qkv.narrow(2, 0, self.key_dim)?;
let k = mixed_qkv.narrow(2, self.key_dim, self.key_dim)?;
let v = mixed_qkv.narrow(2, 2 * self.key_dim, self.value_dim)?;
let q = q.reshape((batch_size, seq_len, self.num_k_heads, self.head_k_dim))?;
let k = k.reshape((batch_size, seq_len, self.num_k_heads, self.head_k_dim))?;
let v = v.reshape((batch_size, seq_len, self.num_v_heads, self.head_v_dim))?;
// ----- beta + g (per-head, per-token gates). -----
// Fused on cuda; per-op Rust on cpu. Both paths produce:
// beta = sigmoid(b)
// g = -exp(A_log) * softplus(a + dt_bias)
let (beta, g) = run_fused_gating(&b, &a, &self.a_log, &self.dt_bias)?;
// ----- GQA-style key expansion if num_v_heads > num_k_heads. -----
let (q, k) = if self.num_v_heads > self.num_k_heads {
let rep = self.num_v_heads / self.num_k_heads;
(
repeat_interleave(&q, rep, 2)?,
repeat_interleave(&k, rep, 2)?,
)
} else {
(q, k)
};
// ----- L2-norm on q, k (use_qk_l2norm_in_kernel=True in ref). -----
let q = l2norm(&q, 1e-6)?;
let k = l2norm(&k, 1e-6)?;
// ----- Recurrent delta rule. -----
// Inputs: q, k (B, L, H, D_k); v (B, L, H, D_v); g (B, L, H); beta (B, L, H).
// The reference transposes to (B, H, L, D) before the loop. We
// do the same — it makes per-token indexing trivial.
let q = q.transpose(1, 2)?.contiguous()?; // (B, H, L, D_k)
let k = k.transpose(1, 2)?.contiguous()?;
let v = v.transpose(1, 2)?.contiguous()?; // (B, H, L, D_v)
let g = g.transpose(1, 2)?.contiguous()?; // (B, H, L)
let beta = beta.transpose(1, 2)?.contiguous()?; // (B, H, L)
// Pre-scale q by 1/sqrt(D_k) once. Everything goes to f32 here
// since the delta rule mixes broadcast_mul ops that candle won't
// accept across mixed dtypes. On the cuda gating path both beta
// and g come back in model dtype; on the cpu path g is already
// f32 — both casts are cheap idempotent ops.
let scale = 1.0_f64 / (self.head_k_dim as f64).sqrt();
let q = (q.to_dtype(candle_core::DType::F32)? * scale)?;
let k = k.to_dtype(candle_core::DType::F32)?;
let v = v.to_dtype(candle_core::DType::F32)?;
let g = g.to_dtype(candle_core::DType::F32)?;
let beta = beta.to_dtype(candle_core::DType::F32)?;
// Initialise the recurrent state from cache or zeros.
let state_init = match self.state.recurrent_state.take() {
Some(s) => s.to_dtype(candle_core::DType::F32)?,
None => Tensor::zeros(
(
batch_size,
self.num_v_heads,
self.head_k_dim,
self.head_v_dim,
),
candle_core::DType::F32,
&device,
)?,
};
// The delta-rule body: cuda-accelerated `gated_delta_rule_recurrence`
// kernel when we have a cuda device + the kernels are linked in,
// pure-Rust per-token fallback otherwise.
let (core_attn_out, new_state) = run_delta_rule(
&q,
&k,
&v,
&g,
&beta,
state_init,
batch_size,
self.num_v_heads,
seq_len,
self.head_k_dim,
self.head_v_dim,
)?;
// Stash the updated recurrent state for the next call.
self.state.recurrent_state = Some(new_state.to_dtype(dtype)?);
// core_attn_out: (B, H, L, D_v) → (B, L, H, D_v) → (B*L*H, D_v).
let core_attn_out = core_attn_out.transpose(1, 2)?.contiguous()?; // (B, L, H, D_v)
let core_attn_out = core_attn_out.to_dtype(dtype)?;
let core_attn_flat =
core_attn_out.reshape((batch_size * seq_len * self.num_v_heads, self.head_v_dim))?;
let z_flat = z.reshape((batch_size * seq_len * self.num_v_heads, self.head_v_dim))?;
// RMSNormGated: (out * silu(z) * weight) with the norm.
let normed = self.norm.forward(&core_attn_flat, &z_flat)?;
let normed = normed.reshape((batch_size, seq_len, self.num_v_heads * self.head_v_dim))?;
// Output projection: (B, L, value_dim) → (B, L, hidden_size).
self.out_proj.forward(&normed)
}
}
/// Run the per-token delta-rule recurrence.
///
/// `q`, `k`: `(B, H, L, D_k)` (F32). `v`: `(B, H, L, D_v)`. `g`,
/// `beta`: `(B, H, L)`. `state`: `(B, H, D_k, D_v)`.
///
/// Returns `(core_attn_out: (B, H, L, D_v), state: (B, H, D_k, D_v))`,
/// both F32. Caller is responsible for cast back to model dtype.
///
/// Cuda path: dispatches to the `gated_delta_rule_recurrence` kernel
/// ported from `EricLBuehler/mistral.rs::mistralrs-core/src/cuda/gdn.cu`.
/// All five inputs must be cuda f32 tensors. The kernel is V-tiled
/// with compile-time BK; one block per (V-tile, batch*head) and one
/// thread per V-column. Each thread holds BK state floats in
/// registers — eliminates the launch-overhead floor we hit with
/// candle's per-op dispatch (was ~12s/token on Qwen3.6-27B).
///
/// CPU path: pure-Rust per-token loop. Correct, slow.
#[allow(clippy::too_many_arguments)]
pub(crate) fn run_delta_rule(
q: &Tensor,
k: &Tensor,
v: &Tensor,
g: &Tensor,
beta: &Tensor,
state: Tensor,
batch_size: usize,
num_heads: usize,
seq_len: usize,
head_k_dim: usize,
head_v_dim: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
#[cfg(feature = "cuda")]
{
// Only dispatch to the kernel if the inputs are on a CUDA
// device — CPU tests fall back to the Rust loop below.
if q.device().is_cuda() {
return run_delta_rule_cuda(
q, k, v, g, beta, state, batch_size, num_heads, seq_len, head_k_dim, head_v_dim,
);
}
}
let _ = (batch_size, num_heads, head_k_dim, head_v_dim);
run_delta_rule_rust(q, k, v, g, beta, state, seq_len)
}
/// CUDA path. Flattens (B, H, ...) → (BH, ...) at the kernel boundary
/// (the kernel uses BH = batch*heads as its outer batch axis) and
/// reshapes the kernel's outputs back to (B, H, ...) for the caller.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
fn run_delta_rule_cuda(
q: &Tensor,
k: &Tensor,
v: &Tensor,
g: &Tensor,
beta: &Tensor,
state: Tensor,
batch_size: usize,
num_heads: usize,
seq_len: usize,
head_k_dim: usize,
head_v_dim: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
let q_bh = q.flatten(0, 1)?.contiguous()?;
let k_bh = k.flatten(0, 1)?.contiguous()?;
let v_bh = v.flatten(0, 1)?.contiguous()?;
let g_bh = g.flatten(0, 1)?.contiguous()?;
let beta_bh = beta.flatten(0, 1)?.contiguous()?;
let mut state_bh = state.flatten(0, 1)?.contiguous()?;
// For long prefills, the chunked kernel (BT=64) processes a chunk
// of tokens at a time instead of one-by-one — same delta-rule math,
// far fewer block launches. Threshold matches mistralrs.
const CHUNK_THRESHOLD: usize = 64;
let output_bh = if seq_len >= CHUNK_THRESHOLD {
crate::cuda::gdn::chunked_gated_delta_rule_recurrence_cuda(
&q_bh,
&k_bh,
&v_bh,
&g_bh,
&beta_bh,
&mut state_bh,
)?
} else {
crate::cuda::gdn::gated_delta_rule_recurrence_cuda(
&q_bh,
&k_bh,
&v_bh,
&g_bh,
&beta_bh,
&mut state_bh,
)?
};
let core_attn_out = output_bh.reshape((batch_size, num_heads, seq_len, head_v_dim))?;
let new_state = state_bh.reshape((batch_size, num_heads, head_k_dim, head_v_dim))?;
Ok((core_attn_out, new_state))
}
#[allow(clippy::too_many_arguments)]
fn run_delta_rule_rust(
q: &Tensor,
k: &Tensor,
v: &Tensor,
g: &Tensor,
beta: &Tensor,
mut state: Tensor,
seq_len: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
use candle_core::IndexOp;
let mut outputs: Vec<Tensor> = Vec::with_capacity(seq_len);
for t in 0..seq_len {
let q_t = q.i((.., .., t, ..))?;
let k_t = k.i((.., .., t, ..))?;
let v_t = v.i((.., .., t, ..))?;
let g_t = g.i((.., .., t))?;
let beta_t = beta.i((.., .., t))?;
let decay = g_t
.exp()?
.unsqueeze(candle_core::D::Minus1)?
.unsqueeze(candle_core::D::Minus1)?;
state = state.broadcast_mul(&decay)?;
let k_col = k_t.unsqueeze(candle_core::D::Minus1)?;
let kv_mem = state.broadcast_mul(&k_col)?.sum(2)?;
let beta_col = beta_t.unsqueeze(candle_core::D::Minus1)?;
let delta = (v_t - kv_mem)?.broadcast_mul(&beta_col)?;
let delta_row = delta.unsqueeze(2)?;
let outer = k_col.broadcast_mul(&delta_row)?;
state = (state + outer)?;
let q_col = q_t.unsqueeze(candle_core::D::Minus1)?;
let out_t = state.broadcast_mul(&q_col)?.sum(2)?;
outputs.push(out_t.unsqueeze(2)?);
}
let core_attn_out = Tensor::cat(&outputs, 2)?; // (B, H, L, D_v)
Ok((core_attn_out, state))
}
/// Depthwise causal conv1d + SiLU, with rolling `conv_state`.
///
/// `x`: `(B, conv_dim, L)` model dtype (f16/bf16 on cuda, anything on cpu).
/// `weight`: `(conv_dim, 1, kernel_size)` model dtype.
/// `conv_state`: `Some((B, conv_dim, kernel_size))` for decode continuation,
/// or `None` for fresh prefill.
///
/// Returns `(conv_out: (B, conv_dim, L), new_conv_state: (B, conv_dim, kernel_size))`.
/// SiLU is baked in.
///
/// Cuda path: dispatches to `causal_conv1d_update` (decode, seq_len=1 with
/// existing state) or `causal_conv1d_full` (prefill / first call), both
/// ported from mistralrs `gdn.cu`. Each kernel fuses the depthwise conv
/// and SiLU activation in one launch — that's ~4× fewer cuda launches per
/// linear-attention layer than the candle `conv1d` + `silu` combo.
///
/// CPU path: the original prepend-narrow-conv1d-silu sequence.
pub(crate) fn run_causal_conv1d(
x: &Tensor,
weight: &Tensor,
conv_state: Option<Tensor>,
batch_size: usize,
conv_dim: usize,
seq_len: usize,
conv_kernel_size: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
#[cfg(feature = "cuda")]
{
if x.device().is_cuda() {
return run_causal_conv1d_cuda(
x,
weight,
conv_state,
batch_size,
conv_dim,
seq_len,
conv_kernel_size,
);
}
}
run_causal_conv1d_rust(
x,
weight,
conv_state,
batch_size,
conv_dim,
seq_len,
conv_kernel_size,
)
}
#[cfg(feature = "cuda")]
fn run_causal_conv1d_cuda(
x: &Tensor,
weight: &Tensor,
conv_state: Option<Tensor>,
batch_size: usize,
conv_dim: usize,
seq_len: usize,
conv_kernel_size: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
// Kernel expects weight as (conv_dim, kernel_size) — squeeze the
// depthwise channel-multiplier dim.
let w = weight.squeeze(1)?.to_dtype(x.dtype())?.contiguous()?;
// Decode path: seq_len == 1 AND we have an existing conv_state.
// Otherwise (prefill or fresh-start decode), use the full path which
// zero-pads on the left internally.
if let Some(cs) = conv_state
&& seq_len == 1
{
let cs = cs.contiguous()?;
let (output, new_conv_state) =
crate::cuda::gdn::causal_conv1d_cuda(x, &w, &cs, conv_kernel_size, true)?;
return Ok((output, new_conv_state));
}
// Prefill / fresh-start: the kernel ignores any prior conv_state and
// zero-pads. If we had a non-zero prior state and >1 input tokens
// (multi-turn continuation), we'd need to fall back to Rust. Match
// mistralrs's behaviour: fresh prefill always.
let device = x.device().clone();
let zeros_cs = Tensor::zeros((batch_size, conv_dim, conv_kernel_size), x.dtype(), &device)?;
let (output, new_conv_state) =
crate::cuda::gdn::causal_conv1d_cuda(x, &w, &zeros_cs, conv_kernel_size, false)?;
Ok((output, new_conv_state))
}
/// Fused GDN gating: computes `beta = sigmoid(b)` and
/// `g = -exp(a_log) * softplus(a + dt_bias)` together.
///
/// `b`, `a`: `(B, L, num_heads)` model dtype.
/// `a_log`, `dt_bias`: `(num_heads,)` model dtype (cast to f32 internally).
///
/// Returns `(beta, g)` both in model dtype on the cuda path, both in f32
/// on the cpu fallback. The caller casts to f32 before the delta rule.
///
/// Cuda path: dispatches to `fused_gdn_gating_cuda` — one kernel
/// replaces sigmoid + neg(exp) + softplus + broadcast_mul (≈10 candle
/// launches per layer).
pub(crate) fn run_fused_gating(
b: &Tensor,
a: &Tensor,
a_log: &Tensor,
dt_bias: &Tensor,
) -> candle_core::Result<(Tensor, Tensor)> {
#[cfg(feature = "cuda")]
{
if b.device().is_cuda() {
let a_log_f32 = a_log.to_dtype(candle_core::DType::F32)?.contiguous()?;
let dt_bias_f32 = dt_bias.to_dtype(candle_core::DType::F32)?.contiguous()?;
return crate::cuda::gdn::fused_gdn_gating_cuda(b, a, &a_log_f32, &dt_bias_f32);
}
}
run_fused_gating_rust(b, a, a_log, dt_bias)
}
fn run_fused_gating_rust(
b: &Tensor,
a: &Tensor,
a_log: &Tensor,
dt_bias: &Tensor,
) -> candle_core::Result<(Tensor, Tensor)> {
let beta = candle_nn::ops::sigmoid(b)?;
let a_log_f32 = a_log.to_dtype(candle_core::DType::F32)?;
let neg_a_exp = a_log_f32.exp()?.neg()?;
let dt_b_f32 = dt_bias.to_dtype(candle_core::DType::F32)?;
let a_f32 = a.to_dtype(candle_core::DType::F32)?;
let a_plus_dt = a_f32.broadcast_add(&dt_b_f32)?;
let softplus_val = softplus(&a_plus_dt)?;
let neg_a_exp_b = neg_a_exp.unsqueeze(0)?.unsqueeze(0)?;
let g = neg_a_exp_b.broadcast_mul(&softplus_val)?;
Ok((beta, g))
}
fn run_causal_conv1d_rust(
x: &Tensor,
weight: &Tensor,
conv_state: Option<Tensor>,
batch_size: usize,
conv_dim: usize,
seq_len: usize,
conv_kernel_size: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
let dtype = x.dtype();
let device = x.device().clone();
let prepended = match &conv_state {
Some(prev) => Tensor::cat(&[prev, x], 2)?,
None => x.clone(),
};
let prep_len = prepended.dims()[2];
let new_state = if prep_len >= conv_kernel_size {
prepended.narrow(2, prep_len - conv_kernel_size, conv_kernel_size)?
} else {
let pad = Tensor::zeros(
(batch_size, conv_dim, conv_kernel_size - prep_len),
dtype,
&device,
)?;
Tensor::cat(&[&pad, &prepended], 2)?
};
let conv_out = prepended.conv1d(weight, conv_kernel_size - 1, 1, 1, conv_dim)?;
let conv_out = conv_out.narrow(2, 0, prep_len)?;
let conv_out = candle_nn::ops::silu(&conv_out)?;
let conv_out = conv_out.narrow(2, prep_len - seq_len, seq_len)?;
Ok((conv_out, new_state))
}
/// Load a no-bias linear from the ShardedVarBuilder. Weight shape is
/// the standard `[out, in]` order.
fn load_linear_no_bias(
vb: &ShardedVarBuilder,
name: &str,
in_dim: usize,
out_dim: usize,
) -> Result<Linear> {
let weight = vb
.pp(name)
.get((out_dim, in_dim), "weight")
.with_context(|| format!("load '{}/{name}/weight'", vb.prefix()))?;
Ok(Linear::new(weight, None))
}
/// Numerically-stable `softplus(x) = ln(1 + exp(x))`. Matches PyTorch's
/// `F.softplus` default (beta=1, threshold=20: for large positive x,
/// returns x as-is to avoid overflow in the exp).
pub(crate) fn softplus(x: &Tensor) -> candle_core::Result<Tensor> {
let threshold = 20.0_f64;
let big = x.ge(threshold)?; // Tensor<u8> mask
let safe = x.minimum(&x.affine(0.0, 0.0)?.affine(0.0, threshold)?)?; // min(x, threshold)
let small = ((safe.exp()? + 1.0_f64)?).log()?;
// Select x where big, else small.
big.where_cond(x, &small)
}
/// `repeat_interleave` along a single dim. Candle has no built-in for
/// this; emulate with unsqueeze + expand + reshape.
pub(crate) fn repeat_interleave(
x: &Tensor,
repeats: usize,
dim: usize,
) -> candle_core::Result<Tensor> {
if repeats == 1 {
return Ok(x.clone());
}
let mut shape = x.dims().to_vec();
let orig = shape[dim];
shape.insert(dim + 1, repeats);
let mut expanded_shape = shape.clone();
expanded_shape[dim + 1] = repeats;
let x = x.unsqueeze(dim + 1)?;
let x = x.expand(expanded_shape)?;
let mut out_shape = x.dims().to_vec();
out_shape.remove(dim + 1);
out_shape[dim] = orig * repeats;
x.contiguous()?.reshape(out_shape)
}
#[cfg(test)]
mod tests {
use super::*;
use candle_core::{DType, Device};
#[test]
fn softplus_small_x() {
// softplus(0) = ln(2) ≈ 0.6931
let x = Tensor::new(&[0.0_f32], &Device::Cpu).unwrap();
let out: Vec<f32> = softplus(&x).unwrap().to_vec1().unwrap();
assert!((out[0] - 2.0_f32.ln()).abs() < 1e-4);
}
#[test]
fn softplus_large_x_returns_x() {
// For x = 30, softplus(x) ≈ x (the threshold branch).
let x = Tensor::new(&[30.0_f32], &Device::Cpu).unwrap();
let out: Vec<f32> = softplus(&x).unwrap().to_vec1().unwrap();
assert!((out[0] - 30.0).abs() < 1e-4);
}
#[test]
fn repeat_interleave_doubles_dim() {
let x = Tensor::new(&[[1.0_f32, 2.0], [3.0, 4.0]], &Device::Cpu).unwrap(); // shape (2, 2)
let out = repeat_interleave(&x, 2, 1).unwrap(); // each col duplicated
let v: Vec<Vec<f32>> = out.to_vec2().unwrap();
// Row 0: 1, 1, 2, 2
// Row 1: 3, 3, 4, 4
assert_eq!(v[0], vec![1.0, 1.0, 2.0, 2.0]);
assert_eq!(v[1], vec![3.0, 3.0, 4.0, 4.0]);
}
/// Sanity: the recurrent path produces a finite tensor of the right
/// shape on tiny dimensions. Doesn't validate numerical correctness
/// against the Python reference — that would need a fixed-weight
/// fixture to compare against. Catches structural mistakes
/// (broadcasting shapes, off-by-one slices) early.
#[test]
fn forward_smoke_with_tiny_dimensions() {
let dev = Device::Cpu;
let dtype = DType::F32;
let (b, l) = (1, 3);
let cfg = TextConfig {
vocab_size: 100,
hidden_size: 16,
intermediate_size: 32,
num_hidden_layers: 1,
num_attention_heads: 4,
num_key_value_heads: 1,
head_dim: 4,
max_position_embeddings: 32,
rope_parameters: RopeParameters {
rope_theta: 10000.0,
partial_rotary_factor: 1.0,
rope_type: None,
},
rms_norm_eps: 1e-6,
tie_word_embeddings: false,
attn_output_gate: true,
layer_types: vec!["linear_attention".into()],
full_attention_interval: Some(4),
hidden_act: "silu".into(),
linear_num_value_heads: 4,
linear_num_key_heads: 2,
linear_key_head_dim: 4,
linear_value_head_dim: 4,
linear_conv_kernel_dim: 4,
};
// Build a synthetic VarBuilder with all-zeros weights.
// Easier path: skip the load and construct GatedDeltaNet
// manually by hand-rolling the Linear/Tensor inputs.
let zeros = |shape: &[usize]| Tensor::zeros(shape, dtype, &dev).unwrap();
let key_dim = cfg.linear_key_head_dim * cfg.linear_num_key_heads;
let value_dim = cfg.linear_value_head_dim * cfg.linear_num_value_heads;
let conv_dim = key_dim * 2 + value_dim;
let mut net = GatedDeltaNet {
in_proj_qkv: Linear::new(zeros(&[conv_dim, cfg.hidden_size]), None),
in_proj_z: Linear::new(zeros(&[value_dim, cfg.hidden_size]), None),
in_proj_b: Linear::new(zeros(&[cfg.linear_num_value_heads, cfg.hidden_size]), None),
in_proj_a: Linear::new(zeros(&[cfg.linear_num_value_heads, cfg.hidden_size]), None),
out_proj: Linear::new(zeros(&[cfg.hidden_size, value_dim]), None),
conv1d_weight: zeros(&[conv_dim, 1, cfg.linear_conv_kernel_dim]),
dt_bias: zeros(&[cfg.linear_num_value_heads]),
a_log: zeros(&[cfg.linear_num_value_heads]),
norm: {
let weight = Tensor::ones(&[cfg.linear_value_head_dim], dtype, &dev).unwrap();
Qwen3_5RmsNormGated::from_weight(weight, cfg.rms_norm_eps)
},
num_v_heads: cfg.linear_num_value_heads,
num_k_heads: cfg.linear_num_key_heads,
head_k_dim: cfg.linear_key_head_dim,
head_v_dim: cfg.linear_value_head_dim,
key_dim,
value_dim,
conv_dim,
conv_kernel_size: cfg.linear_conv_kernel_dim,
state: GatedDeltaNetState::default(),
};
let x = Tensor::ones(&[b, l, cfg.hidden_size], dtype, &dev).unwrap();
let y = net.forward(&x).unwrap();
assert_eq!(y.dims(), &[b, l, cfg.hidden_size]);
// All zero weights → output should be zero. Confirms no NaN/Inf
// poisoning from the f32 promotions.
let v: Vec<f32> = y.flatten_all().unwrap().to_vec1().unwrap();
assert!(v.iter().all(|x| x.is_finite()));
}
}

View File

@@ -1,53 +0,0 @@
//! SwiGLU MLP block for Qwen3-Next.
//!
//! Identical to plain Qwen3's MLP: `down(silu(gate(x)) * up(x))` with
//! no bias on any of the three projections.
use anyhow::{Context, Result};
use candle_core::{Module, Tensor};
use candle_nn::Linear;
use candle_nn::var_builder::ShardedVarBuilder;
use super::TextConfig;
pub struct Qwen3_5MLP {
gate_proj: Linear,
up_proj: Linear,
down_proj: Linear,
}
impl Qwen3_5MLP {
pub fn load(cfg: &TextConfig, vb: &ShardedVarBuilder) -> Result<Self> {
let h = cfg.hidden_size;
let i = cfg.intermediate_size;
let gate_proj = load_linear_no_bias(vb, "gate_proj", h, i)?;
let up_proj = load_linear_no_bias(vb, "up_proj", h, i)?;
let down_proj = load_linear_no_bias(vb, "down_proj", i, h)?;
Ok(Self {
gate_proj,
up_proj,
down_proj,
})
}
}
impl Module for Qwen3_5MLP {
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
let lhs = candle_nn::ops::silu(&self.gate_proj.forward(x)?)?;
let rhs = self.up_proj.forward(x)?;
self.down_proj.forward(&(lhs * rhs)?)
}
}
fn load_linear_no_bias(
vb: &ShardedVarBuilder,
name: &str,
in_dim: usize,
out_dim: usize,
) -> Result<Linear> {
let weight = vb
.pp(name)
.get((out_dim, in_dim), "weight")
.with_context(|| format!("load '{}/{name}/weight'", vb.prefix()))?;
Ok(Linear::new(weight, None))
}

View File

@@ -1,397 +0,0 @@
//! Qwen3-Next (`model_type = "qwen3_5"`) architecture — Qwen3.6's
//! upstream architecture revision.
//!
//! ## Naming
//!
//! The model release this targets is `Qwen/Qwen3.6-*` but the
//! architecture name in HuggingFace's `config.json` is `qwen3_5`.
//! mistralrs calls the same architecture `qwen3_next`; that label
//! ages poorly the next time Qwen ship a new arch, so we key on the
//! canonical `qwen3_5` from the model's own config.
//!
//! ## Status
//!
//! **Single-GPU dense path is real**. Both attention flavours
//! (`full_attention` with the output-gated GQA causal attention and
//! `linear_attention` with the Gated DeltaNet recurrent block) are
//! implemented. The model loads from upstream safetensors via the
//! existing `load_arch_dense` dispatch and runs forward end to end.
//!
//! Numerical correctness vs the reference Python is **not yet
//! validated** — the structural code path is right, weight tensor
//! names match the upstream layout, shapes flow through cleanly, but
//! the Tbilisi probe (and any other downstream test) is the next
//! step. Likely places a bug would surface:
//! - Per-rank vs per-token-position offsets in the recurrent delta
//! rule (`linear_attn.rs`).
//! - Off-by-one in the conv state continuation across decode steps.
//! - RoPE phase mismatch from MRoPE simplification (we treat the
//! three position grids as collapsed, which is correct only for
//! text-only inference).
//!
//! ## Submodules
//!
//! - [`rmsnorm`] — `Qwen3_5RmsNorm` (`(1+w)*x` variant), the
//! `Qwen3_5RmsNormGated` used after the delta rule, and the
//! `l2norm` helper.
//! - [`rope`] — text-side rotary embedding (mrope simplified, GLM
//! rotate-half).
//! - [`mlp`] — SwiGLU MLP (gate/up/down, no bias).
//! - [`full_attn`] — `Qwen3_5Attention` with the output-gate
//! widening on `q_proj`.
//! - [`linear_attn`] — `GatedDeltaNet` recurrent delta-rule block
//! (causal depthwise Conv1d → silu → split → L2norm → per-token
//! delta rule → RMSNormGated → out_proj).
//! - [`decoder`] — `Qwen3_5DecoderLayer` dispatching to one of the
//! two attention flavours per layer index.
//!
//! ## Open work
//!
//! - **TP variant.** `harness/tp/tp_qwen3_5.rs` is the next step.
//! Sharding strategy diverges by layer type:
//! - Full-attention layers: column-parallel q/k/v (including the
//! gate half of `q_proj`) + row-parallel `o_proj`, mirroring
//! `tp_qwen3.rs`.
//! - Linear-attention layers: the recurrent state is per-V-head, so
//! V-head-dimension sharding works cleanly — split `num_v_heads`
//! across ranks (`num_v_heads / world_size` per rank), shard
//! `in_proj_qkv` / `in_proj_z` / `in_proj_b` / `in_proj_a` along
//! the V-head dim, and row-parallel `out_proj`. The `A_log` /
//! `dt_bias` per-head params shard with the heads.
//!
//! - **Chunked delta-rule prefill.** `linear_attn.rs` runs the
//! per-token recurrent path for prefill too — correct but O(L).
//! Porting `torch_chunk_gated_delta_rule` (chunk_size=64) speeds
//! prefill substantially with no surface change.
use anyhow::{Context, Result};
use candle_core::{DType, Device, IndexOp, Module, Tensor};
use candle_nn::Embedding;
use candle_nn::Linear;
use candle_nn::var_builder::ShardedVarBuilder;
use serde::Deserialize;
use std::sync::Arc;
pub mod decoder;
pub mod full_attn;
pub mod linear_attn;
pub mod mlp;
pub mod rmsnorm;
pub mod rope;
use decoder::Qwen3_5DecoderLayer;
use rmsnorm::Qwen3_5RmsNorm;
use rope::RotaryEmbedding;
/// `model_type` we deserialise from `config.json`. Const so the
/// dispatch in `candle.rs::load_arch_dense` can pattern-match without
/// magic strings.
pub const MODEL_TYPE: &str = "qwen3_5";
/// Top-level shape of Qwen3-Next's `config.json`. The real
/// hyperparameters live in `text_config`; the rest is multimodal /
/// tokeniser glue we don't need for the language-model forward.
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
/// Always `"qwen3_5"` for this architecture. Kept on the struct
/// so the (eventual) dispatch / logging code can show it without
/// re-parsing the JSON.
pub model_type: String,
/// The text-side hyperparameters. Everything we actually need.
pub text_config: TextConfig,
}
/// Inner config (the `text_config` block). Mirrors the Qwen3 layout
/// but with the extras Qwen3-Next adds (`attn_output_gate`,
/// `layer_types`, `full_attention_interval`, larger `head_dim`).
#[derive(Debug, Clone, Deserialize)]
pub struct TextConfig {
pub vocab_size: usize,
pub hidden_size: usize,
pub intermediate_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub num_key_value_heads: usize,
pub head_dim: usize,
pub max_position_embeddings: usize,
/// Nested RoPE settings. Qwen3-Next puts `rope_theta` and
/// `partial_rotary_factor` inside this block rather than at the
/// top level — important because the partial rotary means only
/// `head_dim * partial_rotary_factor` dims get RoPE applied (the
/// rest pass through unchanged).
pub rope_parameters: RopeParameters,
pub rms_norm_eps: f64,
#[serde(default)]
pub tie_word_embeddings: bool,
/// New in Qwen3-Next: a sigmoid gate multiplied into the attention
/// output before the o_proj. The Python reference applies it
/// pointwise after softmax+matmul.
#[serde(default)]
pub attn_output_gate: bool,
/// One entry per decoder layer; values are `"full_attention"` or
/// `"linear_attention"`. Length must equal `num_hidden_layers`.
/// `full_attention_interval` is a derived hint (every 4th layer
/// by default) — `layer_types` is authoritative.
#[serde(default)]
pub layer_types: Vec<String>,
/// Hint for the layer-type pattern (defaults to 4). Kept for
/// logging / validation; the forward dispatches on `layer_types`.
#[serde(default)]
pub full_attention_interval: Option<usize>,
/// Hidden activation (`"silu"` for Qwen3-Next). Used by the MLP
/// and the linear-attention conv1d.
#[serde(default = "default_hidden_act")]
pub hidden_act: String,
// --- Gated DeltaNet (linear-attention) hyperparams -----------------
/// Per-layer linear-attention V-head count (Qwen3.6-27B: 48).
/// More V-heads than K-heads is fine — query/key get
/// `repeat_interleave`'d to match before the delta rule.
#[serde(default)]
pub linear_num_value_heads: usize,
/// Per-layer linear-attention K-head count (Qwen3.6-27B: 16).
#[serde(default)]
pub linear_num_key_heads: usize,
/// Per-head key dimension for the linear-attention path
/// (Qwen3.6-27B: 128). Separate from `head_dim` which the
/// full-attention layers use.
#[serde(default)]
pub linear_key_head_dim: usize,
/// Per-head value dimension for the linear-attention path
/// (Qwen3.6-27B: 128).
#[serde(default)]
pub linear_value_head_dim: usize,
/// Causal Conv1d kernel size used before the delta rule
/// (Qwen3.6-27B: 4).
#[serde(default)]
pub linear_conv_kernel_dim: usize,
}
fn default_hidden_act() -> String {
"silu".into()
}
/// Nested `rope_parameters` block from a Qwen3-Next `config.json`.
/// `mrope_section` and `mrope_interleaved` are accepted via the
/// `#[serde(default)]` flatten-tolerance below but ignored — we treat
/// MRoPE as plain RoPE for text-only inference (the three position
/// grids carry identical ids when there's no vision input, so the
/// interleaving is a no-op).
#[derive(Debug, Clone, Deserialize)]
pub struct RopeParameters {
/// Base for the inverse-frequency computation. Qwen3.6: 10_000_000.
#[serde(default = "default_rope_theta")]
pub rope_theta: f64,
/// Fraction of `head_dim` that gets the rotation applied. The
/// remaining `head_dim * (1 - partial_rotary_factor)` dims pass
/// through unchanged. Qwen3.6 / Qwen3.5: 0.25.
#[serde(default = "default_partial_rotary_factor")]
pub partial_rotary_factor: f32,
/// `"default"` for the standard inv_freq RoPE; other values (e.g.
/// `"linear"`, `"dynamic"`) are upstream-supported but not yet
/// implemented here.
#[serde(default)]
pub rope_type: Option<String>,
}
fn default_rope_theta() -> f64 {
10_000.0
}
fn default_partial_rotary_factor() -> f32 {
1.0
}
/// Qwen3-Next base transformer (embedding + decoder stack + final
/// norm). Public so a TP variant in `harness/tp/tp_qwen3_5.rs` can
/// also build on it later — for now only `Qwen3_5ForCausalLM` is the
/// loaded handle.
pub struct Qwen3_5Model {
embed_tokens: Embedding,
layers: Vec<Qwen3_5DecoderLayer>,
norm: Qwen3_5RmsNorm,
device: Device,
dtype: DType,
}
impl Qwen3_5Model {
pub fn load(cfg: &TextConfig, vb: &ShardedVarBuilder) -> Result<Self> {
let dtype = vb.dtype();
let device = vb.device().clone();
// Qwen3-Next is a multimodal architecture whose text core lives
// under `model.language_model.*` — sibling to `model.visual.*`
// (the vision tower) and to top-level `lm_head` / `mtp.*`.
// Every text-side tensor in the safetensors files is under
// this prefix; we ignore the vision and MTP weights for
// language-model inference.
let text_vb = vb.pp("model.language_model");
let embed_vb = text_vb.pp("embed_tokens");
let embed_weight = embed_vb
.get((cfg.vocab_size, cfg.hidden_size), "weight")
.with_context(|| format!("load '{}/weight'", embed_vb.prefix()))?;
let embed_tokens = Embedding::new(embed_weight, cfg.hidden_size);
let rotary = Arc::new(RotaryEmbedding::new(dtype, cfg, &device)?);
if cfg.layer_types.len() != cfg.num_hidden_layers {
anyhow::bail!(
"config.text_config.layer_types must have num_hidden_layers ({}) entries; \
got {}",
cfg.num_hidden_layers,
cfg.layer_types.len()
);
}
let vb_l = text_vb.pp("layers");
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
for i in 0..cfg.num_hidden_layers {
layers.push(Qwen3_5DecoderLayer::load(
cfg,
rotary.clone(),
i,
&vb_l.pp(i),
)?);
}
let norm = Qwen3_5RmsNorm::load(&text_vb.pp("norm"), cfg.hidden_size, cfg.rms_norm_eps)?;
Ok(Self {
embed_tokens,
layers,
norm,
device,
dtype,
})
}
pub fn embed_weight(&self) -> &Tensor {
self.embed_tokens.embeddings()
}
pub fn clear_kv_cache(&mut self) {
for l in &mut self.layers {
l.clear_kv_cache();
}
}
fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> candle_core::Result<Tensor> {
let minf = f32::NEG_INFINITY;
let mask: Vec<_> = (0..tgt)
.flat_map(|i| (0..(tgt + offset)).map(move |j| if j <= i + offset { 0. } else { minf }))
.collect();
Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype)
}
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
let (b, l) = input.dims2()?;
let mut h = self.embed_tokens.forward(input)?;
// Causal mask only needed for L > 1 prefill; full-attention
// layers consume it via broadcast_add. Linear-attention layers
// ignore the mask.
let causal = if l == 1 {
None
} else {
Some(self.causal_mask(b, l, offset)?)
};
for layer in &mut self.layers {
h = layer.forward(&h, causal.as_ref(), offset)?;
}
self.norm.forward(&h)
}
}
pub struct Qwen3_5ForCausalLM {
base: Qwen3_5Model,
lm_head: Linear,
}
impl Qwen3_5ForCausalLM {
pub fn new(config: Config, vb: ShardedVarBuilder) -> Result<Self> {
let cfg = &config.text_config;
let base = Qwen3_5Model::load(cfg, &vb)?;
let lm_head = if cfg.tie_word_embeddings {
Linear::new(base.embed_weight().clone(), None)
} else {
let weight = vb
.pp("lm_head")
.get((cfg.vocab_size, cfg.hidden_size), "weight")
.with_context(|| format!("load '{}/lm_head/weight'", vb.prefix()))?;
Linear::new(weight, None)
};
Ok(Self { base, lm_head })
}
/// `input`: token-id tensor of shape `(B, L)`. Returns logits at
/// the last position, shape `(B, 1, vocab_size)` — same contract
/// as `qwen3::ModelForCausalLM::forward` so the harness's
/// `squeeze_to_vocab` helper handles both uniformly.
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
let (_, l) = input.dims2()?;
let hidden = self.base.forward(input, offset)?;
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
}
pub fn clear_kv_cache(&mut self) {
self.base.clear_kv_cache();
}
}
#[cfg(test)]
mod tests {
use super::*;
/// Confirms we can deserialise the real upstream config shape.
/// Sample taken from `Qwen/Qwen3.6-27B/config.json`, trimmed to
/// the fields the architecture cares about. Note `rope_theta` and
/// `partial_rotary_factor` are nested under `rope_parameters` —
/// Qwen3-Next does NOT have a top-level `rope_theta`.
#[test]
fn config_deserialises_the_real_qwen3_6_shape() {
let raw = r#"{
"architectures": ["Qwen3_5ForConditionalGeneration"],
"model_type": "qwen3_5",
"image_token_id": 248056,
"language_model_only": false,
"text_config": {
"vocab_size": 248064,
"hidden_size": 5120,
"intermediate_size": 17408,
"num_hidden_layers": 64,
"num_attention_heads": 64,
"num_key_value_heads": 8,
"head_dim": 256,
"max_position_embeddings": 32768,
"rope_parameters": {
"mrope_interleaved": true,
"mrope_section": [11, 11, 10],
"partial_rotary_factor": 0.25,
"rope_theta": 10000000,
"rope_type": "default"
},
"rms_norm_eps": 1e-6,
"tie_word_embeddings": false,
"attn_output_gate": true,
"full_attention_interval": 4,
"layer_types": [
"linear_attention", "linear_attention",
"linear_attention", "full_attention"
]
}
}"#;
let cfg: Config = serde_json::from_str(raw).expect("parse Qwen3.6 config");
assert_eq!(cfg.model_type, "qwen3_5");
assert_eq!(cfg.text_config.hidden_size, 5120);
assert_eq!(cfg.text_config.head_dim, 256);
assert!(cfg.text_config.attn_output_gate);
assert_eq!(cfg.text_config.full_attention_interval, Some(4));
assert_eq!(cfg.text_config.layer_types.len(), 4);
assert_eq!(cfg.text_config.rope_parameters.rope_theta, 10_000_000.0);
assert!((cfg.text_config.rope_parameters.partial_rotary_factor - 0.25).abs() < 1e-6);
}
}

View File

@@ -1,161 +0,0 @@
//! Norm primitives for Qwen3-Next.
//!
//! Two reasons we can't reuse `candle_nn::RmsNorm` directly:
//!
//! 1. **`(1.0 + weight)` scaling.** Qwen3-Next's `Qwen3_5RMSNorm`
//! initialises `weight` to zeros and applies `(1.0 + weight)` to
//! the normalised vector. `candle_nn::RmsNorm` applies `weight`
//! directly. The two are equivalent only when the operator has
//! pre-shifted the weights — the upstream checkpoints have not. See
//! `huggingface/transformers#29402` for the upstream PR that
//! introduced the `(1 + w)` form to recover from the zero-init.
//!
//! 2. **Gated variant.** The linear-attention layer post-normalises
//! its output by an RMSNorm *gated* with a per-element SiLU on
//! a sibling `z` projection — fused for numerical reasons (the
//! norm's float32 promotion has to happen before the SiLU
//! multiply). Not a single existing candle op.
//!
//! Both ops accept inputs in any compute dtype; promotion to f32 for
//! the variance calculation matches the Python reference.
use anyhow::{Context, Result};
use candle_core::{D, Module, Tensor};
use candle_nn::var_builder::ShardedVarBuilder;
/// L2-normalise along the last dim with a small epsilon. Matches the
/// `l2norm` helper in `transformers/models/qwen3_5/modeling_qwen3_5.py`
/// — `x * rsqrt(sum(x*x) + eps)`. The linear-attention path uses this
/// on Q and K before the delta rule when
/// `use_qk_l2norm_in_kernel=True` (which Qwen3-Next always sets).
pub fn l2norm(x: &Tensor, eps: f32) -> candle_core::Result<Tensor> {
let dtype = x.dtype();
let x_f32 = x.to_dtype(candle_core::DType::F32)?;
let sq = x_f32.sqr()?;
let sum = sq.sum_keepdim(D::Minus1)?;
let inv = (sum + eps as f64)?.sqrt()?.recip()?;
x_f32.broadcast_mul(&inv)?.to_dtype(dtype)
}
/// Qwen3-Next's RMSNorm. Stores the raw weight tensor; forward applies
/// `(1.0 + weight) * x_normed`.
pub struct Qwen3_5RmsNorm {
weight: Tensor,
eps: f32,
size: usize,
}
impl Qwen3_5RmsNorm {
/// Load `weight` from the ShardedVarBuilder. `vb` should already be
/// `.pp(...)`-ed to the norm's tensor prefix.
pub fn load(vb: &ShardedVarBuilder, size: usize, eps: f64) -> Result<Self> {
let weight = vb
.get(size, "weight")
.with_context(|| format!("load '{}/weight'", vb.prefix()))?;
Ok(Self {
weight,
eps: eps as f32,
size,
})
}
pub fn size(&self) -> usize {
self.size
}
}
impl Module for Qwen3_5RmsNorm {
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
let dtype = x.dtype();
let x_f32 = x.to_dtype(candle_core::DType::F32)?;
let var = x_f32.sqr()?.mean_keepdim(D::Minus1)?;
let normed = x_f32.broadcast_mul(&(var + self.eps as f64)?.sqrt()?.recip()?)?;
// Promote weight to f32 and shift by 1.0 *before* multiplying.
// Doing the (1 + w) operation in fp16 lands at -inf for the
// bottom-of-range weights at load time.
let w_f32 = self.weight.to_dtype(candle_core::DType::F32)?;
let scale = (w_f32 + 1.0_f64)?;
normed.broadcast_mul(&scale)?.to_dtype(dtype)
}
}
/// Gated RMSNorm used at the tail of `Qwen3_5GatedDeltaNet`. Equivalent
/// to `x_normed * weight * silu(gate)` but with both the norm and the
/// gate evaluated in float32 to avoid mid-pipeline underflow.
///
/// Note: unlike `Qwen3_5RmsNorm`, this variant matches the Python
/// reference's `Qwen3_5RMSNormGated` which uses `weight` directly (not
/// `1.0 + weight`).
pub struct Qwen3_5RmsNormGated {
weight: Tensor,
eps: f32,
size: usize,
}
impl Qwen3_5RmsNormGated {
pub fn load(vb: &ShardedVarBuilder, size: usize, eps: f64) -> Result<Self> {
let weight = vb
.get(size, "weight")
.with_context(|| format!("load '{}/weight'", vb.prefix()))?;
Ok(Self {
weight,
eps: eps as f32,
size,
})
}
/// Direct constructor — used by unit tests that build a layer
/// without going through a VarBuilder.
#[cfg(test)]
pub(crate) fn from_weight(weight: Tensor, eps: f64) -> Self {
let size = weight.dims()[0];
Self {
weight,
eps: eps as f32,
size,
}
}
pub fn size(&self) -> usize {
self.size
}
/// `x` and `gate` share the same last-dim shape (`size`).
pub fn forward(&self, x: &Tensor, gate: &Tensor) -> candle_core::Result<Tensor> {
let dtype = x.dtype();
let x_f32 = x.to_dtype(candle_core::DType::F32)?;
let var = x_f32.sqr()?.mean_keepdim(D::Minus1)?;
let normed = x_f32.broadcast_mul(&(var + self.eps as f64)?.sqrt()?.recip()?)?;
let w = self.weight.to_dtype(candle_core::DType::F32)?;
let out = normed.broadcast_mul(&w)?;
// SiLU on the float32 gate, multiply back into the normed
// tensor, then cast to the model dtype.
let g = gate.to_dtype(candle_core::DType::F32)?;
let silu_gate = candle_nn::ops::silu(&g)?;
(out * silu_gate)?.to_dtype(dtype)
}
}
#[cfg(test)]
mod tests {
use super::*;
use candle_core::Device;
#[test]
fn l2norm_matches_hand_calc() {
let x = Tensor::new(&[3.0_f32, 4.0_f32], &Device::Cpu).unwrap();
let out = l2norm(&x, 1e-6).unwrap();
let v: Vec<f32> = out.to_vec1().unwrap();
// |x| = 5, so x/|x| = [0.6, 0.8] (eps is tiny).
assert!((v[0] - 0.6).abs() < 1e-4);
assert!((v[1] - 0.8).abs() < 1e-4);
}
#[test]
fn l2norm_zero_vector_is_safe_via_epsilon() {
let x = Tensor::new(&[0.0_f32, 0.0_f32], &Device::Cpu).unwrap();
let out = l2norm(&x, 1e-6).unwrap();
let v: Vec<f32> = out.to_vec1().unwrap();
assert!(v.iter().all(|x| x.is_finite()));
}
}

View File

@@ -1,114 +0,0 @@
//! Rotary position embedding for Qwen3-Next's full-attention layers.
//!
//! Qwen3.6 ships with MRoPE (multimodal RoPE) machinery in the
//! reference Python — three position grids interleaved per
//! `mrope_section`. For text-only inference all three grids carry the
//! same position ids and the interleave is a no-op, so this module
//! implements the plain (non-mrope) flavour: the standard inv_freq
//! cosine/sine tables driven by `rope_theta` and `head_dim`.
//!
//! Rotation flavour: **GLM-style** rotate-half (the second half of the
//! head dim is negated and swapped into the first). The reference
//! Python uses `apply_rotary_pos_emb` with `rotate_half`; candle's
//! `rope_slow` is the matching helper.
use anyhow::Result;
use candle_core::{DType, Device, Tensor};
use super::TextConfig;
#[derive(Debug, Clone)]
pub struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
/// Number of dims at the head's leading edge that the rotation
/// covers. The remaining `head_dim - rotary_dim` dims pass through
/// unchanged. Qwen3-Next uses `partial_rotary_factor = 0.25`, so
/// for `head_dim = 256` only 64 dims rotate.
rotary_dim: usize,
head_dim: usize,
}
impl RotaryEmbedding {
pub fn new(dtype: DType, cfg: &TextConfig, dev: &Device) -> Result<Self> {
let head_dim = cfg.head_dim;
let rope = &cfg.rope_parameters;
let rotary_dim = (head_dim as f32 * rope.partial_rotary_factor) as usize;
if !rotary_dim.is_multiple_of(2) {
anyhow::bail!(
"rotary_dim = head_dim * partial_rotary_factor = {head_dim} * {} = {rotary_dim} \
must be even (cos/sin are paired)",
rope.partial_rotary_factor
);
}
if rotary_dim == 0 {
anyhow::bail!(
"rotary_dim = 0 (partial_rotary_factor = {} too small)",
rope.partial_rotary_factor
);
}
let max_seq_len = cfg.max_position_embeddings;
let inv_freq: Vec<f32> = (0..rotary_dim)
.step_by(2)
.map(|i| 1f32 / rope.rope_theta.powf(i as f64 / rotary_dim as f64) as f32)
.collect();
let n = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, n), dev)?.to_dtype(DType::F32)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(DType::F32)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
Ok(Self {
sin: freqs.sin()?.to_dtype(dtype)?,
cos: freqs.cos()?.to_dtype(dtype)?,
rotary_dim,
head_dim,
})
}
/// Apply RoPE to q, k.
///
/// `q`, `k` shape: `(B, H, L, head_dim)`. `offset` is the index
/// into the cached cos/sin table — the position of the first token
/// in the current step.
///
/// When `rotary_dim < head_dim` the rotation is applied only to the
/// first `rotary_dim` dims of each head; the tail passes through
/// unchanged (matches the reference Python's
/// `apply_rotary_pos_emb` with non-trivial `partial_rotary_factor`).
pub fn apply(
&self,
q: &Tensor,
k: &Tensor,
offset: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
let (_, _, seq_len, head_dim_in) = q.dims4()?;
debug_assert_eq!(head_dim_in, self.head_dim, "q head_dim mismatch");
let cos = self.cos.narrow(0, offset, seq_len)?;
let sin = self.sin.narrow(0, offset, seq_len)?;
if self.rotary_dim == self.head_dim {
// Full rotation.
let q_embed = candle_nn::rotary_emb::rope_slow(&q.contiguous()?, &cos, &sin)?;
let k_embed = candle_nn::rotary_emb::rope_slow(&k.contiguous()?, &cos, &sin)?;
Ok((q_embed, k_embed))
} else {
// Partial rotation: narrow → rotate → cat the untouched tail.
let tail = self.head_dim - self.rotary_dim;
let q_rot = q
.narrow(candle_core::D::Minus1, 0, self.rotary_dim)?
.contiguous()?;
let q_pass = q.narrow(candle_core::D::Minus1, self.rotary_dim, tail)?;
let k_rot = k
.narrow(candle_core::D::Minus1, 0, self.rotary_dim)?
.contiguous()?;
let k_pass = k.narrow(candle_core::D::Minus1, self.rotary_dim, tail)?;
let q_rotated = candle_nn::rotary_emb::rope_slow(&q_rot, &cos, &sin)?;
let k_rotated = candle_nn::rotary_emb::rope_slow(&k_rot, &cos, &sin)?;
let q_embed =
Tensor::cat(&[&q_rotated, &q_pass.contiguous()?], candle_core::D::Minus1)?;
let k_embed =
Tensor::cat(&[&k_rotated, &k_pass.contiguous()?], candle_core::D::Minus1)?;
Ok((q_embed, k_embed))
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,392 +0,0 @@
//! Chat-template rendering for the model-supplied Jinja templates
//! HuggingFace tokenizers ship in `tokenizer_config.json`.
//!
//! ## Background
//!
//! Every modern open-weight model bundles a `chat_template` field
//! in its `tokenizer_config.json` — a Jinja2 template string that
//! converts a sequence of `{role, content}` messages into the
//! exact prompt the model was trained on. Examples:
//!
//! - Qwen3-Coder: `<|im_start|>{role}\n{content}<|im_end|>\n…`
//! with conditional `enable_thinking` handling that injects an
//! empty `<think>\n\n</think>` block when set false.
//! - DeepSeek-R1: similar im_start framing with different special-
//! token names.
//! - Mistral / Magistral: a `[INST]` / `[/INST]` framing.
//! - Claude / Llama: another shape again.
//!
//! Rendering the model's own template is the only way to get the
//! *exact* prompt format the model was trained on plus the
//! model-specific kwargs (`enable_thinking`, `tools`, …) without
//! hardcoding per-model logic. The alternative — neuron's previous
//! `format_qwen3_prompt` — was a hardcoded Qwen3 ChatML glue that
//! ignored kwargs entirely.
//!
//! ## Scope
//!
//! This module is request-side only: it builds the prompt string
//! the tokenizer ingests before inference. The reasoning- and
//! tool-call-marker token routing (issues #6, #8) is response-side
//! and stays in `wire::openai_chat` / the streaming inference
//! loops.
//!
//! ## Fallback
//!
//! When the model's `tokenizer_config.json` is missing, doesn't
//! parse, lacks a `chat_template`, or renders an error, the caller
//! falls back to `format_qwen3_prompt`. The
//! `NEURON_USE_CHAT_TEMPLATE=false` env var is a global kill
//! switch — if a deploy goes sideways and the renderer is to
//! blame, an operator can flip the env and restart neuron without
//! shipping a new build.
use anyhow::{Context, Result};
use cortex_core::openai::{ChatMessage, MessageContent};
use minijinja::Environment;
use serde_json::Value;
use std::path::Path;
/// Environment variable that, when set to `false`/`0`/`no`,
/// forces every model to skip its `chat_template` and fall back
/// to `format_qwen3_prompt`. Default (unset) is "use chat
/// templates where available".
pub const KILL_SWITCH_ENV: &str = "NEURON_USE_CHAT_TEMPLATE";
/// Read the global kill switch. `true` means chat templates are
/// enabled; `false` forces the fallback path everywhere.
pub fn chat_templates_enabled() -> bool {
match std::env::var(KILL_SWITCH_ENV).ok().as_deref() {
Some(s) => !matches!(
s.trim().to_ascii_lowercase().as_str(),
"false" | "0" | "no" | "off"
),
None => true,
}
}
/// Convenience: probe for `tokenizer_config.json` in the same
/// directory the tokenizer was loaded from. Both files come from
/// the same HuggingFace snapshot in the hf-hub cache, so the
/// sibling path is reliable.
pub fn load_chat_template_alongside(tokenizer_json_path: &Path) -> Option<String> {
let parent = tokenizer_json_path.parent()?;
let config_path = parent.join("tokenizer_config.json");
load_chat_template_from(&config_path)
}
/// Best-effort load of `chat_template` from a HuggingFace
/// `tokenizer_config.json`. Returns `None` when the file is
/// absent, doesn't parse, or lacks the `chat_template` field —
/// in all of those cases the caller falls back to
/// `format_qwen3_prompt`. Warnings are logged so an operator can
/// see why the fallback fired.
pub fn load_chat_template_from(path: &Path) -> Option<String> {
let text = match std::fs::read_to_string(path) {
Ok(t) => t,
Err(e) => {
tracing::debug!(
path = %path.display(),
error = %e,
"chat_template: tokenizer_config.json absent or unreadable; falling back"
);
return None;
}
};
let value: Value = match serde_json::from_str(&text) {
Ok(v) => v,
Err(e) => {
tracing::warn!(
path = %path.display(),
error = %e,
"chat_template: tokenizer_config.json failed to parse; falling back"
);
return None;
}
};
// Some tokenizer_config.json files carry `chat_template` as an
// array of `{name, template}` objects (multi-template models —
// tool-use variant, default variant). For now we pick the first
// entry; future iterations could honour a name hint.
match value.get("chat_template") {
Some(Value::String(s)) => Some(s.clone()),
Some(Value::Array(arr)) => {
for entry in arr {
if let Some(t) = entry.get("template").and_then(|v| v.as_str()) {
return Some(t.to_string());
}
}
tracing::warn!(
path = %path.display(),
"chat_template: array form had no usable template entry; falling back"
);
None
}
_ => None,
}
}
/// Render the chat template into the prompt the model expects.
///
/// `template` is the raw Jinja string from `tokenizer_config.json`.
/// `messages` is the conversation in order. `kwargs` is the
/// `chat_template_kwargs` object the client supplied on the
/// request (or `Value::Null` when absent). The function expands
/// the kwargs into the Jinja context alongside the standard
/// `messages` and `add_generation_prompt` variables HF templates
/// expect.
///
/// `tools` is the request's `tools` array (or `Value::Null`).
/// Some chat templates iterate it to emit native tool definitions
/// (Qwen3-Coder's tool-use template, Mistral's [TOOL_DEFINITIONS]
/// frame). We forward whatever the client sent without
/// interpretation.
pub fn render_chat_template(
template: &str,
messages: &[ChatMessage],
tools: &Value,
kwargs: &Value,
) -> Result<String> {
let mut env = Environment::new();
// Compile the template against a fixed name so error messages
// surface "chat_template" rather than `<template>`.
env.add_template("chat_template", template)
.context("compile chat_template")?;
let tmpl = env.get_template("chat_template").unwrap();
// Convert our internal ChatMessage shape into the
// `[{role, content}]` shape HF templates iterate. Text content
// becomes a string; Parts becomes an array of content blocks.
// The HF templates handle both shapes via `content is string`
// checks or content-array iteration.
let messages_json: Vec<Value> = messages
.iter()
.map(|m| {
let content_value = match &m.content {
MessageContent::Text(s) => Value::String(s.clone()),
MessageContent::Parts(parts) => Value::Array(parts.clone()),
};
let mut obj = serde_json::Map::new();
obj.insert("role".into(), Value::String(m.role.clone()));
obj.insert("content".into(), content_value);
// Forward extras (e.g. tool_calls on assistant turns,
// tool_call_id on tool result turns). HF templates that
// need them read e.g. `message.tool_calls`.
if let Value::Object(extras) = &m.extra {
for (k, v) in extras {
obj.insert(k.clone(), v.clone());
}
}
Value::Object(obj)
})
.collect();
// Build the kwargs context. Add base bindings the template
// expects (`messages`, `add_generation_prompt`, `tools`) plus
// anything the caller passed in `chat_template_kwargs`. Caller
// kwargs override the defaults so `add_generation_prompt: false`
// from the request actually wins.
let mut ctx_map = serde_json::Map::new();
ctx_map.insert("messages".into(), Value::Array(messages_json));
ctx_map.insert("add_generation_prompt".into(), Value::Bool(true));
if !tools.is_null() {
ctx_map.insert("tools".into(), tools.clone());
}
if let Value::Object(kwargs_obj) = kwargs {
for (k, v) in kwargs_obj {
ctx_map.insert(k.clone(), v.clone());
}
}
// `Template::render` takes any Serialize value; serde_json's
// `Value` implements it natively, so we pass the assembled
// context object directly without going through the
// `context!` macro (which expects minijinja-native values).
tmpl.render(Value::Object(ctx_map))
.context("render chat_template")
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
fn user_msg(text: &str) -> ChatMessage {
ChatMessage {
role: "user".into(),
content: MessageContent::Text(text.into()),
extra: Value::Object(Default::default()),
}
}
fn assistant_msg(text: &str) -> ChatMessage {
ChatMessage {
role: "assistant".into(),
content: MessageContent::Text(text.into()),
extra: Value::Object(Default::default()),
}
}
/// Minimal Qwen3-style template — enough surface to confirm
/// our renderer threads role + content correctly without
/// loading a real model's tokenizer_config.json.
const QWEN3_LIKE: &str = "{%- for message in messages -%}\
<|im_start|>{{ message.role }}\n{{ message.content }}<|im_end|>\n\
{%- endfor -%}\
{%- if add_generation_prompt -%}<|im_start|>assistant\n{%- endif -%}";
#[test]
fn renders_basic_conversation() {
let prompt = render_chat_template(
QWEN3_LIKE,
&[user_msg("hello"), assistant_msg("hi"), user_msg("bye")],
&Value::Null,
&Value::Null,
)
.unwrap();
// Structural assertions — the exact whitespace produced
// by a given template is a Jinja-trim concern that varies
// per real chat_template. What matters is that every
// turn's role + content thread through in order, and that
// the generation cue lands at the end.
assert!(
prompt.contains("<|im_start|>user\nhello<|im_end|>"),
"first user turn missing: {prompt}"
);
assert!(
prompt.contains("<|im_start|>assistant\nhi<|im_end|>"),
"assistant turn missing: {prompt}"
);
assert!(
prompt.contains("<|im_start|>user\nbye<|im_end|>"),
"second user turn missing: {prompt}"
);
assert!(
prompt.ends_with("<|im_start|>assistant")
|| prompt.ends_with("<|im_start|>assistant\n"),
"generation cue missing at end: {prompt}"
);
}
#[test]
fn kwargs_are_threaded_into_template_context() {
// Replica of Qwen3's enable_thinking branch in
// simplified form. When the kwarg is false, the model's
// template injects an empty `<think>...</think>` block
// before the generation cue — pre-filling the model's
// reasoning slot with "no thinking" so the model emits
// the answer directly.
let template = "{%- if enable_thinking is defined and enable_thinking is false -%}\
NO_THINK\
{%- else -%}\
THINK_OK\
{%- endif -%}";
let r_disabled = render_chat_template(
template,
&[],
&Value::Null,
&json!({ "enable_thinking": false }),
)
.unwrap();
assert_eq!(r_disabled, "NO_THINK");
let r_default = render_chat_template(template, &[], &Value::Null, &Value::Null).unwrap();
assert_eq!(r_default, "THINK_OK");
}
#[test]
fn missing_template_field_returns_none() {
let tmp = std::env::temp_dir().join("neuron-test-tokenizer-missing-field.json");
std::fs::write(&tmp, r#"{"some_other_field": 1}"#).unwrap();
assert!(load_chat_template_from(&tmp).is_none());
let _ = std::fs::remove_file(tmp);
}
#[test]
fn load_template_from_string_field() {
let tmp = std::env::temp_dir().join("neuron-test-tokenizer-string.json");
std::fs::write(
&tmp,
r#"{"chat_template": "hello {{ messages[0].content }}"}"#,
)
.unwrap();
let t = load_chat_template_from(&tmp).expect("template loaded");
assert!(t.contains("messages[0].content"));
let _ = std::fs::remove_file(tmp);
}
#[test]
fn load_template_from_array_form() {
// Some HF models ship `chat_template` as `[{name, template}, ...]`.
let tmp = std::env::temp_dir().join("neuron-test-tokenizer-array.json");
std::fs::write(
&tmp,
r#"{"chat_template": [{"name": "default", "template": "ARR"}]}"#,
)
.unwrap();
let t = load_chat_template_from(&tmp).expect("template loaded");
assert_eq!(t, "ARR");
let _ = std::fs::remove_file(tmp);
}
#[test]
fn missing_file_returns_none_quietly() {
let absent = std::path::PathBuf::from("/definitely/not/a/real/path.json");
assert!(load_chat_template_from(&absent).is_none());
}
#[test]
fn unparseable_returns_none() {
let tmp = std::env::temp_dir().join("neuron-test-tokenizer-garbage.json");
std::fs::write(&tmp, b"{not valid json").unwrap();
assert!(load_chat_template_from(&tmp).is_none());
let _ = std::fs::remove_file(tmp);
}
#[test]
fn kill_switch_recognises_truthy_falsy_values() {
// Test against the actual env var so callers see the
// same behaviour as production. Serialise via a
// mutex — see path_util.rs for the pattern.
use std::sync::Mutex;
static LOCK: Mutex<()> = Mutex::new(());
let _g = LOCK.lock().unwrap();
let prior = std::env::var(KILL_SWITCH_ENV).ok();
unsafe {
std::env::remove_var(KILL_SWITCH_ENV);
}
assert!(chat_templates_enabled());
for value in ["false", "0", "no", "off", "FALSE", " no "] {
unsafe { std::env::set_var(KILL_SWITCH_ENV, value) };
assert!(!chat_templates_enabled(), "value {value:?} should disable");
}
for value in ["true", "1", "yes", ""] {
unsafe { std::env::set_var(KILL_SWITCH_ENV, value) };
assert!(chat_templates_enabled(), "value {value:?} should enable");
}
unsafe {
match prior {
Some(p) => std::env::set_var(KILL_SWITCH_ENV, p),
None => std::env::remove_var(KILL_SWITCH_ENV),
}
}
}
#[test]
fn message_extras_thread_through_for_tool_calls() {
// HF templates read assistant.tool_calls and tool
// turns' tool_call_id. Confirm our extras flatten into
// the message object the template iterates.
let mut extras = serde_json::Map::new();
extras.insert(
"tool_calls".into(),
json!([{"id": "t1", "function": {"name": "x", "arguments": "{}"}}]),
);
let msg = ChatMessage {
role: "assistant".into(),
content: MessageContent::Text(String::new()),
extra: Value::Object(extras),
};
let template = "{{ messages[0].tool_calls[0].id }}";
let rendered = render_chat_template(template, &[msg], &Value::Null, &Value::Null).unwrap();
assert_eq!(rendered, "t1");
}
}

View File

@@ -1,810 +0,0 @@
//! Synchronous dispatch loop running on the device worker thread.
//!
//! `run()` is the thread's entry point. It binds the CUDA context for
//! its device on startup, then pulls `Job`s off the channel one at a
//! time and runs the corresponding handler. The handlers are
//! synchronous by design — the only async on this thread is the
//! one-line `oneshot::Sender::send` call to ship the reply back, which
//! is non-blocking.
//!
//! Phase 2 handles QueryVram, TransferIn, DropArch, ClearKv,
//! ForwardLogits, Shutdown. Phase 3 will add the TP variants
//! (NcclInit, NcclSanity, TpLoadShard, TpForward, TpClearKv) and the
//! ARCH model state in this state slab will gain a companion
//! `tp_models: HashMap<TpHandle, Box<TpLeaderModel>>`.
use crate::harness::candle::ModelArch;
#[cfg(feature = "cuda")]
use crate::harness::device_worker::jobs::TpHandle;
use crate::harness::device_worker::jobs::{ArchHandle, Job};
#[cfg(feature = "cuda")]
use crate::harness::tp::TpLeaderModel;
use crate::harness::tp::nccl_state::NcclState;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::Receiver;
/// Per-thread state owned by the worker. On CUDA builds the `Arc<CudaContext>`
/// is created and bound at thread startup; on CPU builds the struct
/// is mostly empty.
struct DeviceWorkerState {
#[allow(dead_code)]
device_index: u32,
/// Candle `Device` constructed at startup. Used by handlers (e.g.
/// `ForwardLogits`) to build input tensors against the right
/// device. Falls back to `Device::Cpu` if CUDA init fails.
device: candle_core::Device,
/// Boxed `ModelArch` slab. Indexed by an opaque `ArchHandle` minted
/// by `TransferIn`. The Box means the entry's address is stable
/// across HashMap rehashes (relevant only when we later hand out
/// `&mut ModelArch` references — for Phase 2 every handler runs
/// `&mut` via `get_mut`, no long-lived borrows).
models: HashMap<ArchHandle, Box<ModelArch>>,
/// Counter for minting fresh `ArchHandle`s. Each `TransferIn`
/// increments and returns the new value. Wraps at u64::MAX after
/// ~10^19 model loads — not a practical concern.
next_handle: u64,
/// Leader's NCCL state. Populated by `Job::NcclInit`; the
/// underlying `Comm`'s libnccl handle lives bound to this thread
/// for its entire lifetime. Subprocess workers maintain their own
/// `NcclState` in their own processes — that's not visible from
/// here.
#[allow(dead_code)] // Read only via methods on NcclState
nccl: NcclState,
/// TP leader model slab. Same lifecycle as `models`; separate
/// namespace so `ArchHandle` and `TpHandle` can't collide.
#[cfg(feature = "cuda")]
tp_models: HashMap<TpHandle, Box<TpLeaderModel>>,
/// Counter for minting fresh `TpHandle`s.
#[cfg(feature = "cuda")]
next_tp_handle: u64,
#[cfg(feature = "cuda")]
#[allow(dead_code)]
/// `None` only if `CudaContext::new()` failed — in that case the
/// thread still runs so the handle's lifecycle stays uniform, but
/// every job that touches CUDA falls through to a zero reply with
/// a log warning.
ctx: Option<Arc<candle_core::cuda::cudarc::driver::CudaContext>>,
}
/// Worker thread entry point. Runs until `Job::Shutdown` arrives or
/// the channel sender is dropped (which happens when the last
/// `DeviceWorkerHandle` `Arc` is dropped without an explicit
/// `shutdown()`).
pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool>) {
let mut state = init_state(device_index);
tracing::info!(device_index, "device worker started");
while let Ok(job) = rx.recv() {
// Shutdown is processed unconditionally so a poisoned worker
// still exits when asked. Matching by reference first so we
// can fall through to the consume-match below.
if matches!(&job, Job::Shutdown) {
break;
}
if poisoned.load(Ordering::Acquire) {
// Drain-only mode: reply with a poisoned error without
// touching CUDA. Phase 1/2 never set the flag from the
// dispatch loop itself (no driver errors classified yet),
// but tests use `DeviceWorkerHandle::set_poisoned()` to
// simulate this state.
drain_poisoned(job, device_index);
continue;
}
match job {
Job::QueryVram { reply } => {
let result = query_vram(&state);
// If the caller dropped its receiver (request cancelled,
// gateway timed out) the send fails — fine, we just
// discard the reply.
let _ = reply.send(result);
}
Job::LoadGguf {
gguf_path,
model_id,
reply,
} => {
let result = load_gguf_inner(&state.device, &gguf_path, &model_id)
.map(|arch| insert_arch(&mut state, Box::new(arch)));
let _ = reply.send(result);
}
Job::LoadDense {
config_path,
safetensors_paths,
model_id,
reply,
} => {
let result =
load_dense_inner(&state.device, &config_path, &safetensors_paths, &model_id)
.map(|arch| insert_arch(&mut state, Box::new(arch)));
let _ = reply.send(result);
}
Job::DropArch { handle, reply } => {
let removed = state.models.remove(&handle);
let was_present = removed.is_some();
// Explicit drop on this thread — runs the Box<ModelArch>
// Drop with the CUDA context bound here, which frees
// all device tensors on the right context. The Drop is
// implicit on the `removed` value going out of scope at
// the end of the arm; calling drop() explicitly just
// makes the intent visible.
drop(removed);
tracing::debug!(
device_index,
handle = handle.0,
was_present,
slab_size = state.models.len(),
"device worker: model dropped"
);
let _ = reply.send(());
}
Job::ClearKv { handle, reply } => {
let result = match state.models.get_mut(&handle) {
Some(arch) => arch.clear_kv_cache(),
None => Err(anyhow::anyhow!("ClearKv: no model for handle {}", handle.0)),
};
if result.is_ok() {
trim_device_pool(&state);
}
let _ = reply.send(result);
}
Job::ForwardLogits {
handle,
tokens,
offset,
reply,
} => {
let result = forward_logits(&mut state, handle, &tokens, offset);
let _ = reply.send(result);
}
Job::NcclInit {
cfg,
comm_id_hex,
reply,
} => {
let resp = state.nccl.init(cfg, &comm_id_hex);
let _ = reply.send(resp);
}
Job::NcclSanity { reply } => {
let resp = state.nccl.sanity_check();
let _ = reply.send(resp);
}
#[cfg(feature = "cuda")]
Job::TpLoadShard {
model_id,
config_json,
safetensors_paths,
dtype,
quant,
world_size,
reply,
} => {
let result = tp_load_shard_inner(
&mut state,
&model_id,
&config_json,
&safetensors_paths,
dtype,
quant.as_deref(),
world_size,
);
let _ = reply.send(result);
}
#[cfg(feature = "cuda")]
Job::DropTp { handle, reply } => {
let removed = state.tp_models.remove(&handle);
let was_present = removed.is_some();
drop(removed);
tracing::debug!(
device_index,
tp_handle = handle.0,
was_present,
slab_size = state.tp_models.len(),
"device worker: TP model dropped"
);
let _ = reply.send(());
}
#[cfg(feature = "cuda")]
Job::TpClearKv { handle, reply } => {
let result = match state.tp_models.get_mut(&handle) {
Some(model) => {
model.clear_kv_cache();
Ok(())
}
None => Err(anyhow::anyhow!(
"TpClearKv: no TP model for handle {}",
handle.0
)),
};
if result.is_ok() {
trim_device_pool(&state);
}
let _ = reply.send(result);
}
#[cfg(feature = "cuda")]
Job::TpForwardLogits {
handle,
tokens,
offset,
reply,
} => {
let result = tp_forward_logits(&mut state, handle, &tokens, offset);
let _ = reply.send(result);
}
// Handled by the matches!() check above; reaching here
// means a Shutdown slipped past which is a bug.
Job::Shutdown => unreachable!("Shutdown should break above"),
}
}
#[cfg(feature = "cuda")]
let tp_slab_size = state.tp_models.len();
#[cfg(not(feature = "cuda"))]
let tp_slab_size = 0_usize;
tracing::info!(
device_index,
slab_size = state.models.len(),
tp_slab_size,
"device worker exiting; dropping remaining models"
);
// Drops every model in the slab on this thread before the function
// returns. Critical for CUDA tensors: dropping on a thread that
// doesn't have the context bound is UB. Phase 2 still runs Drop
// via the slab going out of scope, which is correct as long as no
// pre-poisoned state lurks in here — see the poisoned-mode
// semantics in mod.rs for the Phase 3+ refinement.
}
fn init_state(device_index: u32) -> DeviceWorkerState {
#[cfg(feature = "cuda")]
{
use candle_core::cuda::cudarc::driver::CudaContext;
// Construct a candle Device first — cudarc returns the
// primary context for this index on subsequent calls, so
// CudaContext::new and Device::new_cuda end up sharing state.
let (device, ctx) = match candle_core::Device::new_cuda(device_index as usize) {
Ok(device) => match CudaContext::new(device_index as usize) {
Ok(ctx) => {
if let Err(e) = ctx.bind_to_thread() {
tracing::warn!(
device_index,
error = ?e,
"device worker: bind_to_thread failed; \
operations will still rebind per-call"
);
} else {
tracing::info!(device_index, "device worker bound CUDA context");
}
(device, Some(ctx))
}
Err(e) => {
tracing::warn!(
device_index,
error = ?e,
"device worker: CudaContext::new failed; \
vram queries will return (0, 0), forward will error"
);
(device, None)
}
},
Err(e) => {
tracing::warn!(
device_index,
error = %e,
"device worker: Device::new_cuda failed; falling back to CPU device"
);
(candle_core::Device::Cpu, None)
}
};
DeviceWorkerState {
device_index,
device,
models: HashMap::new(),
next_handle: 1,
nccl: NcclState::new(),
tp_models: HashMap::new(),
next_tp_handle: 1,
ctx,
}
}
#[cfg(not(feature = "cuda"))]
{
DeviceWorkerState {
device_index,
device: candle_core::Device::Cpu,
models: HashMap::new(),
next_handle: 1,
nccl: NcclState::new(),
}
}
}
#[cfg(feature = "cuda")]
fn query_vram(state: &DeviceWorkerState) -> anyhow::Result<(u64, u64)> {
use candle_core::cuda::cudarc::driver::result;
if state.ctx.is_none() {
return Ok((0, 0));
}
// The context was bound in init_state. cudarc's `mem_get_info`
// reads from the current context on the calling thread; since we
// bound on startup and we never spawn child threads from this
// worker, the binding holds.
match result::mem_get_info() {
Ok((free, total)) => Ok((
(free / (1024 * 1024)) as u64,
(total / (1024 * 1024)) as u64,
)),
Err(e) => Err(anyhow::anyhow!("mem_get_info: {e:?}")),
}
}
#[cfg(not(feature = "cuda"))]
fn query_vram(_state: &DeviceWorkerState) -> anyhow::Result<(u64, u64)> {
Ok((0, 0))
}
/// Force cudarc's stream-ordered memory pool to release every block it
/// is holding back to the system. After `ConcatKvCache::reset()` drops
/// its tensors, the underlying `CudaSlice::drop` calls `cuMemFreeAsync`,
/// which returns the blocks to the device's default mempool but not to
/// the OS — `mem_get_info` still reports them as used. The next
/// request's prefill then sees a falsely-small free pool and either
/// OOMs or trips cuBLAS into `CUBLAS_STATUS_INTERNAL_ERROR`.
///
/// Calling `cuMemPoolTrimTo(pool, 0)` after each `clear_kv_cache`
/// returns those blocks. We synchronize first so any pending
/// `cuMemFreeAsync` operations have settled. Failures are non-fatal:
/// the pool may not exist on legacy drivers, or a transient driver
/// error may prevent the trim — neither breaks correctness, the next
/// request just sees a less-recovered free pool.
#[cfg(feature = "cuda")]
fn trim_device_pool(state: &DeviceWorkerState) {
use candle_core::cuda::cudarc::driver::result::{device, mem_pool};
let Some(ctx) = state.ctx.as_ref() else {
return;
};
let (before_free, _) = match query_vram(state) {
Ok(v) => v,
Err(_) => (0, 0),
};
if let Err(e) = ctx.synchronize() {
tracing::debug!(
device_index = state.device_index,
error = ?e,
"trim_device_pool: synchronize failed; skipping trim"
);
return;
}
let dev = ctx.cu_device();
let pool = match unsafe { device::get_default_mem_pool(dev) } {
Ok(p) => p,
Err(e) => {
tracing::debug!(
device_index = state.device_index,
error = ?e,
"trim_device_pool: get_default_mem_pool failed"
);
return;
}
};
if let Err(e) = unsafe { mem_pool::trim_to(pool, 0) } {
tracing::debug!(
device_index = state.device_index,
error = ?e,
"trim_device_pool: cuMemPoolTrimTo failed"
);
return;
}
let (after_free, _) = match query_vram(state) {
Ok(v) => v,
Err(_) => (0, 0),
};
let freed_mb = after_free.saturating_sub(before_free);
tracing::debug!(
device_index = state.device_index,
before_free_mb = before_free,
after_free_mb = after_free,
freed_mb,
"trim_device_pool: trimmed pool"
);
}
#[cfg(not(feature = "cuda"))]
fn trim_device_pool(_state: &DeviceWorkerState) {}
/// Insert a freshly-built `ModelArch` into the slab and mint a fresh
/// `ArchHandle`. Used by both `LoadGguf` and `LoadDense` dispatch
/// handlers — they differ only in *how* the arch is built; the
/// post-construction bookkeeping is identical.
fn insert_arch(state: &mut DeviceWorkerState, arch: Box<ModelArch>) -> ArchHandle {
let handle = ArchHandle(state.next_handle);
state.next_handle = state.next_handle.wrapping_add(1);
state.models.insert(handle, arch);
tracing::debug!(
device_index = state.device_index,
handle = handle.0,
slab_size = state.models.len(),
"device worker: model inserted"
);
handle
}
/// Load a GGUF (pre-quantized) model on the worker thread. Pulled
/// verbatim from the spawn_blocking closure that used to live in
/// `CandleHarness::load_arch_gguf`; the only change is that `device`
/// is now `state.device` (the worker's permanently-bound device).
fn load_gguf_inner(
device: &candle_core::Device,
gguf_path: &std::path::Path,
model_id: &str,
) -> anyhow::Result<ModelArch> {
use anyhow::Context;
use candle_core::DType;
use candle_core::quantized::gguf_file;
use candle_transformers::models::quantized_llama::ModelWeights as QuantizedLlamaWeights;
use candle_transformers::models::quantized_qwen3::ModelWeights as QuantizedQwen3Weights;
use candle_transformers::models::quantized_qwen3_moe::GGUFQWenMoE;
tracing::info!(model = %model_id, path = ?gguf_path, "loading GGUF");
let mut file = std::fs::File::open(gguf_path).context("open GGUF file")?;
let content =
gguf_file::Content::read(&mut file).map_err(|e| anyhow::anyhow!("parse GGUF: {e}"))?;
let architecture = content
.metadata
.get("general.architecture")
.and_then(|v| v.to_string().ok().cloned())
.unwrap_or_default();
tracing::info!(architecture = %architecture, "GGUF architecture");
// The `general.architecture` GGUF metadata key follows
// llama.cpp conventions (lowercase, no underscores in some
// cases) — `qwen3moe`, not `qwen3_moe`.
match architecture.as_str() {
"qwen3" => {
let weights = QuantizedQwen3Weights::from_gguf(content, &mut file, device)
.map_err(|e| anyhow::anyhow!("from_gguf qwen3: {e}"))?;
Ok(ModelArch::Qwen3Quantized(weights))
}
"qwen3moe" => {
// GGUFQWenMoE takes an explicit compute dtype alongside
// the device — F16 matches the GGUF weights' typical
// accumulation precision and gives the best tokens/sec on
// consumer cards.
let weights = GGUFQWenMoE::from_gguf(content, &mut file, device, DType::F16)
.map_err(|e| anyhow::anyhow!("from_gguf qwen3_moe: {e}"))?;
Ok(ModelArch::Qwen3MoeQuantized(weights))
}
"llama" => {
let weights = QuantizedLlamaWeights::from_gguf(content, &mut file, device)
.map_err(|e| anyhow::anyhow!("from_gguf llama: {e}"))?;
Ok(ModelArch::LlamaQuantized(weights))
}
other => anyhow::bail!(
"unsupported GGUF architecture '{other}'; quantized path supports \
qwen3, qwen3moe, llama"
),
}
}
/// Load a dense safetensors model on the worker thread.
fn load_dense_inner(
device: &candle_core::Device,
config_path: &std::path::Path,
safetensors_paths: &[std::path::PathBuf],
model_id: &str,
) -> anyhow::Result<ModelArch> {
use anyhow::Context;
use candle_core::DType;
use candle_nn::VarBuilder;
use candle_transformers::models::llama as llama_dense;
use candle_transformers::models::qwen3 as qwen3_dense;
use candle_transformers::models::qwen3_moe as qwen3_moe_dense;
let cfg_text = std::fs::read_to_string(config_path).context("read config.json")?;
crate::harness::candle::check_dense_config_supported(&cfg_text, model_id)?;
// Peek at model_type to choose the family before the typed
// deserialize — each family has its own Config.
let model_type = serde_json::from_str::<serde_json::Value>(&cfg_text)
.ok()
.as_ref()
.and_then(|v| v.get("model_type"))
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
tracing::info!(
model = %model_id,
model_type = %model_type,
shards = safetensors_paths.len(),
"loading dense model from safetensors"
);
// bf16 is the canonical distribution dtype for Qwen3 / Llama 3 /
// Qwen3 MoE. CUDA on Ada+ has hardware bf16; Ampere has it too.
// CPU emulates.
let dtype = DType::BF16;
// SAFETY: VarBuilder::from_mmaped_safetensors mmaps the files;
// mutation by another process while we hold the mapping is UB.
// We trust the HF cache is immutable-by-design.
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(safetensors_paths, dtype, device)
.context("build VarBuilder over safetensors")?
};
match model_type.as_str() {
"qwen3" => {
let cfg: qwen3_dense::Config =
serde_json::from_str(&cfg_text).context("parse Qwen3 config.json")?;
let model = qwen3_dense::ModelForCausalLM::new(&cfg, vb)
.map_err(|e| anyhow::anyhow!("build Qwen3 dense model: {e}"))?;
Ok(ModelArch::Qwen3Dense(model))
}
"qwen3_moe" => {
let cfg: qwen3_moe_dense::Config =
serde_json::from_str(&cfg_text).context("parse Qwen3 MoE config.json")?;
let model = qwen3_moe_dense::ModelForCausalLM::new(&cfg, vb)
.map_err(|e| anyhow::anyhow!("build Qwen3 MoE dense model: {e}"))?;
Ok(ModelArch::Qwen3MoeDense(model))
}
"llama" => {
let cfg: llama_dense::LlamaConfig =
serde_json::from_str(&cfg_text).context("parse Llama config.json")?;
let config = cfg.into_config(false);
let cache = llama_dense::Cache::new(true, dtype, &config, device)
.context("build Llama Cache")?;
let model = llama_dense::Llama::load(vb, &config)
.map_err(|e| anyhow::anyhow!("build Llama dense model: {e}"))?;
Ok(ModelArch::LlamaDense(Box::new(
crate::harness::candle::LlamaDense::from_parts(
model,
cache,
config,
dtype,
device.clone(),
),
)))
}
"qwen3_5" => {
let cfg: crate::harness::arch::qwen3_5::Config = serde_json::from_str(&cfg_text)
.context("parse Qwen3-Next (qwen3_5) config.json")?;
let sharded_vb = unsafe {
candle_nn::var_builder::ShardedSafeTensors::var_builder(
safetensors_paths,
dtype,
device,
)
.context("build ShardedVarBuilder for Qwen3-Next")?
};
let model = crate::harness::arch::qwen3_5::Qwen3_5ForCausalLM::new(cfg, sharded_vb)
.context("build Qwen3-Next dense model")?;
Ok(ModelArch::Qwen3_5Dense(model))
}
other => anyhow::bail!(
"unrouted supported model_type '{other}' — \
DENSE_SUPPORTED_MODEL_TYPES and load_dense_inner \
must stay in sync"
),
}
}
/// Load the leader's TP shard on the worker thread. Reads the Comm
/// directly from `state.nccl`; no cross-thread Arc<Comm> transfer.
#[cfg(feature = "cuda")]
fn tp_load_shard_inner(
state: &mut DeviceWorkerState,
model_id: &str,
config_json: &str,
safetensors_paths: &[std::path::PathBuf],
dtype: candle_core::DType,
quant: Option<&str>,
world_size: u32,
) -> anyhow::Result<TpHandle> {
use anyhow::Context;
use candle_nn::var_builder::ShardedSafeTensors;
let comm = state.nccl.comm().ok_or_else(|| {
anyhow::anyhow!("TpLoadShard: NcclState has no Comm; call NcclInit first")
})?;
let model_type = serde_json::from_str::<serde_json::Value>(config_json)
.ok()
.as_ref()
.and_then(|v| v.get("model_type"))
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
// SAFETY: same invariant as the single-GPU dense path — the HF
// cache files are treated as immutable while the mmap is held.
let vb = unsafe {
ShardedSafeTensors::var_builder(safetensors_paths, dtype, &state.device)
.context("build ShardedVarBuilder over safetensors")?
};
let mmap = unsafe {
candle_core::safetensors::MmapedSafetensors::multi(safetensors_paths)
.context("build MmapedSafetensors for leader load")?
};
let loaded = match model_type.as_str() {
"qwen3" => {
let cfg: crate::harness::tp::tp_qwen3::Config = serde_json::from_str(config_json)
.context("parse Qwen3 Config JSON for leader load")?;
TpLeaderModel::Qwen3(crate::harness::tp::tp_qwen3::TpQwen3ForCausalLM::load(
&cfg, &vb, 0, world_size, comm,
)?)
}
"qwen3_5" => {
let cfg: crate::harness::tp::tp_qwen3_5::Config = serde_json::from_str(config_json)
.context("parse Qwen3-Next Config JSON for leader load")?;
let quant_dtype = crate::harness::tp::worker::parse_quant_string(quant)?;
TpLeaderModel::Qwen3_5(crate::harness::tp::tp_qwen3_5::TpQwen3_5ForCausalLM::load(
cfg,
&vb,
&mmap,
0,
world_size,
comm,
quant_dtype,
)?)
}
other => anyhow::bail!(
"TP dispatch: unsupported model_type '{other}' on leader (supported: qwen3, qwen3_5)"
),
};
tracing::info!(
rank = 0,
model = %model_id,
model_type = %model_type,
"loaded TP shard (leader)"
);
let handle = TpHandle(state.next_tp_handle);
state.next_tp_handle = state.next_tp_handle.wrapping_add(1);
state.tp_models.insert(handle, Box::new(loaded));
tracing::debug!(
device_index = state.device_index,
tp_handle = handle.0,
slab_size = state.tp_models.len(),
"device worker: TP model inserted"
);
Ok(handle)
}
/// TP-equivalent of [`forward_logits`]: looks up the leader's
/// [`TpLeaderModel`] in the slab, runs its forward, copies the
/// `[vocab]` logits to a CPU `Vec<f32>`. The leader's `Arc<Comm>`
/// clones embedded in the TP layers' AllReduce ops fire from this
/// thread — same thread that bound the CUDA context and that holds
/// the `Comm` in `state.nccl`.
#[cfg(feature = "cuda")]
fn tp_forward_logits(
state: &mut DeviceWorkerState,
handle: TpHandle,
tokens: &[u32],
offset: usize,
) -> anyhow::Result<Vec<f32>> {
use candle_core::{DType, Tensor};
let input = Tensor::new(tokens, &state.device)?.unsqueeze(0)?;
let model = state
.tp_models
.get_mut(&handle)
.ok_or_else(|| anyhow::anyhow!("TpForwardLogits: no model for handle {}", handle.0))?;
let logits = model.forward(&input, offset)?;
// ForCausalLM forward returns [B, 1, V] after the trailing
// .i((.., l - 1.., ..))?.apply(lm_head); squeeze both leading
// singleton dims to a rank-1 [V] tensor for sampling.
let logits = logits.squeeze(0)?.squeeze(0)?;
let logits = logits.to_dtype(DType::F32)?.flatten_all()?;
let values = logits.to_vec1::<f32>()?;
Ok(values)
}
/// Forward step + copy the `[vocab]` logits to a CPU `Vec<f32>` ready
/// for sampling on the async caller. The model's `device()` (CUDA or
/// CPU) determines where the kernel runs; this fn doesn't care.
///
/// On CUDA, the `to_dtype(F32).flatten_all().to_vec1::<f32>()` chain
/// triggers the device → host copy. The copy runs synchronously on
/// this worker thread; the bound context owns the source allocation
/// so the transfer is straightforward.
fn forward_logits(
state: &mut DeviceWorkerState,
handle: ArchHandle,
tokens: &[u32],
offset: usize,
) -> anyhow::Result<Vec<f32>> {
use candle_core::{DType, Tensor};
// Build the input tensor on the worker's own device. cudarc's
// primary-context model means `Device::new_cuda(idx)` shares state
// with the `CudaContext` we bound at startup, so this is the same
// device the ModelArch was loaded against.
let input = Tensor::new(tokens, &state.device)?.unsqueeze(0)?;
let arch = state
.models
.get_mut(&handle)
.ok_or_else(|| anyhow::anyhow!("ForwardLogits: no model for handle {}", handle.0))?;
let logits = arch.forward(&input, offset)?;
// Copy to CPU f32. logits is already `[vocab]` (squeeze_to_vocab
// inside ModelArch::forward). The to_dtype handles bf16/f16 →
// f32 promotion for the sampler.
let logits = logits.to_dtype(DType::F32)?.flatten_all()?;
let values = logits.to_vec1::<f32>()?;
Ok(values)
}
/// Reply to a job with the poisoned-worker error. Used when the worker
/// has flipped into drain-only mode after a CUDA driver error.
///
/// `Job::Shutdown` is filtered before reaching this fn so the match
/// only needs the data-carrying variants. As phases 24 add more
/// variants the match here grows; every variant must reply with the
/// poisoned error so callers never hang waiting for a worker that's
/// no longer running CUDA.
fn drain_poisoned(job: Job, device_index: u32) {
let err = || anyhow::anyhow!("device worker for device {device_index} is poisoned");
match job {
Job::QueryVram { reply } => {
let _ = reply.send(Err(err()));
}
Job::LoadGguf { reply, .. } => {
let _ = reply.send(Err(err()));
}
Job::LoadDense { reply, .. } => {
let _ = reply.send(Err(err()));
}
Job::DropArch { reply, .. } => {
// Drop reply is `()` — no error path. Send the unit so the
// caller's await resolves; the model handle is leaked in
// the worker's slab, but the whole slab gets `mem::forget`
// on shutdown anyway per the poisoned-thread design.
let _ = reply.send(());
}
Job::ClearKv { reply, .. } => {
let _ = reply.send(Err(err()));
}
Job::ForwardLogits { reply, .. } => {
let _ = reply.send(Err(err()));
}
Job::NcclInit { reply, .. } => {
let _ = reply.send(crate::harness::tp::rpc::WorkerResponse::Error {
kind: "device_worker_poisoned".into(),
message: format!("device worker {device_index} poisoned"),
});
}
Job::NcclSanity { reply } => {
let _ = reply.send(crate::harness::tp::rpc::WorkerResponse::Error {
kind: "device_worker_poisoned".into(),
message: format!("device worker {device_index} poisoned"),
});
}
#[cfg(feature = "cuda")]
Job::TpLoadShard { reply, .. } => {
let _ = reply.send(Err(err()));
}
#[cfg(feature = "cuda")]
Job::DropTp { reply, .. } => {
let _ = reply.send(());
}
#[cfg(feature = "cuda")]
Job::TpClearKv { reply, .. } => {
let _ = reply.send(Err(err()));
}
#[cfg(feature = "cuda")]
Job::TpForwardLogits { reply, .. } => {
let _ = reply.send(Err(err()));
}
Job::Shutdown => {
// Filtered by the matches!() guard in run(); reaching
// here would be a logic error.
unreachable!("Shutdown is filtered before drain_poisoned");
}
}
}

View File

@@ -1,169 +0,0 @@
//! Job variants accepted by the per-device worker thread.
//!
//! Each variant carries the inputs the synchronous dispatch handler
//! needs plus a `tokio::sync::oneshot::Sender` for the reply. The
//! async-side `DeviceWorkerHandle` constructs a job, sends it down the
//! `std::sync::mpsc` channel, and `await`s the oneshot for the reply.
use anyhow::Result;
use std::path::PathBuf;
use tokio::sync::oneshot;
/// Opaque handle to a `ModelArch` stored in the worker thread's state
/// slab. Cheap to copy; `Send + Sync` so it crosses task boundaries
/// freely. The actual `Box<ModelArch>` it points to is owned by the
/// worker thread for the duration of the handle's lifetime — the only
/// way to drop the model is to send `Job::DropArch { handle }` so the
/// `Drop` impl runs on the thread with the bound CUDA context (the
/// invariant the whole refactor exists to guarantee).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ArchHandle(pub u64);
/// Opaque handle to a `TpLeaderModel` stored in the worker thread's
/// state slab. Same shape as [`ArchHandle`] but in a separate
/// namespace so the two slabs can coexist without ambiguity. Phase 3
/// introduces it; Phase 4 may unify the two slabs after the TP forward
/// path proves out.
#[cfg(feature = "cuda")]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TpHandle(pub u64);
/// One unit of work for the device worker.
///
/// Phase 1 had only `QueryVram` and `Shutdown`. Phase 2 adds the
/// single-GPU inference primitives: transfer-in a freshly-loaded
/// `ModelArch`, drop it, clear its KV cache, and run one forward step
/// returning CPU-side logits ready for sampling on the async caller.
///
/// Sampling stays on the async side intentionally. The worker copies
/// logits to CPU (`Vec<f32>`) before reply, so the device-resident
/// tensor never escapes the worker thread and the async caller's
/// `LogitsProcessor::sample` runs entirely on the CPU candle backend
/// — no incidental context binding on a tokio worker thread.
pub enum Job {
/// Query free / total VRAM on the device. Returns
/// `(free_mb, total_mb)`. CPU builds and contexts that failed to
/// initialise reply with `(0, 0)` — matches today's
/// `device_vram_mb` sentinel so the log field values don't change.
QueryVram {
reply: oneshot::Sender<Result<(u64, u64)>>,
},
/// Load a GGUF (pre-quantized) single-GPU model on the worker
/// thread. The dispatch handler opens the GGUF file, parses
/// metadata, dispatches on `general.architecture`, and inserts
/// the resulting `ModelArch` into the slab. Returns the fresh
/// `ArchHandle`.
LoadGguf {
gguf_path: PathBuf,
model_id: String,
reply: oneshot::Sender<Result<ArchHandle>>,
},
/// Load a dense safetensors single-GPU model on the worker
/// thread. The dispatch handler reads `config.json`, dispatches on
/// `model_type`, builds a `VarBuilder` over the mmap'd
/// safetensors, and inserts the resulting `ModelArch`.
LoadDense {
config_path: PathBuf,
safetensors_paths: Vec<PathBuf>,
model_id: String,
reply: oneshot::Sender<Result<ArchHandle>>,
},
/// Remove the model from the slab and drop it. The `Drop` runs on
/// the worker thread so CUDA tensors release their memory on the
/// same context that allocated them.
DropArch {
handle: ArchHandle,
reply: oneshot::Sender<()>,
},
/// Reset the KV cache for this model. Called at the start of every
/// chat completion so a new request doesn't attend over the
/// previous one's tokens.
ClearKv {
handle: ArchHandle,
reply: oneshot::Sender<Result<()>>,
},
/// Run one forward step and copy the resulting `[vocab]` logits to
/// CPU. The caller takes the returned `Vec<f32>`, wraps it in a
/// CPU `Tensor`, and runs `apply_repeat_penalty` + sampling
/// without touching the device context. `offset` is the KV-cache
/// position before this step (0 for prefill, `prompt_len + i` for
/// the i-th decode step).
ForwardLogits {
handle: ArchHandle,
tokens: Vec<u32>,
offset: usize,
reply: oneshot::Sender<Result<Vec<f32>>>,
},
/// Initialize the leader's NCCL communicator. The worker's
/// `NcclState` mints the `Comm` here so its underlying
/// `ncclComm_t` and `CudaContext` live on the same thread as
/// every later `Comm::all_reduce` call. Reply is the worker
/// response shape used by the subprocess workers (`InitOk` on
/// success, `Error` on failure) so the calling
/// `WorkerPool::init_nccl` orchestration stays uniform.
///
/// Available on both cuda and no-cuda builds — the dispatch
/// handler calls `NcclState::init` which has a no-cuda stub that
/// replies with `cuda_feature_not_enabled`. Keeping the Job
/// variant ungated lets `WorkerPool::init_nccl` stay uniform.
NcclInit {
cfg: crate::harness::tp::worker::WorkerConfig,
comm_id_hex: String,
reply: oneshot::Sender<crate::harness::tp::rpc::WorkerResponse>,
},
/// Run NCCL's all_reduce sanity check on the leader's rank 0.
/// Same response shape as `NcclInit`; also available on both
/// builds via the no-cuda `NcclState::sanity_check` stub.
NcclSanity {
reply: oneshot::Sender<crate::harness::tp::rpc::WorkerResponse>,
},
/// Load the leader's TP shard on the worker thread. The dispatch
/// handler reads `state.nccl.comm()` directly (no cross-thread
/// `Arc<Comm>` transfer, no `SendComm` wrapper) and builds the
/// `TpLeaderModel` against that Comm. The model's embedded
/// `Arc<Comm>` clones, `CudaContext`, and all per-rank CUDA
/// tensors live on this thread for the model's lifetime.
/// Inserts into the TP slab and returns the fresh `TpHandle`.
#[cfg(feature = "cuda")]
TpLoadShard {
model_id: String,
config_json: String,
safetensors_paths: Vec<PathBuf>,
dtype: candle_core::DType,
quant: Option<String>,
world_size: u32,
reply: oneshot::Sender<Result<TpHandle>>,
},
/// Drop the TP leader model on the worker thread. CUDA tensors
/// and `Arc<Comm>` clones held inside the model release on the
/// thread that allocated them.
#[cfg(feature = "cuda")]
DropTp {
handle: TpHandle,
reply: oneshot::Sender<()>,
},
/// Reset the leader's KV cache for a TP model. Mirrors `ClearKv`
/// for single-GPU.
#[cfg(feature = "cuda")]
TpClearKv {
handle: TpHandle,
reply: oneshot::Sender<Result<()>>,
},
/// Run one TP forward step on the leader's shard. Returns CPU-
/// side logits as a `Vec<f32>` so the async caller can sample
/// without holding a device tensor. The caller is also
/// responsible for fan-out to subprocess ranks and drain — only
/// the leader's forward moves into the worker thread.
#[cfg(feature = "cuda")]
TpForwardLogits {
handle: TpHandle,
tokens: Vec<u32>,
offset: usize,
reply: oneshot::Sender<Result<Vec<f32>>>,
},
/// Tell the worker to break its dispatch loop and exit. Any jobs
/// queued after this in the channel reply `Err` to their oneshot
/// senders (the senders are dropped on the worker's exit, which
/// the async-side `Receiver::await` maps to `WorkerError::Gone`).
Shutdown,
}

View File

@@ -1,592 +0,0 @@
//! Per-device CUDA worker thread.
//!
//! One dedicated OS thread per CUDA device the leader uses. The thread
//! binds the device's `CudaContext` once at startup and owns it for the
//! daemon's lifetime; all GPU operations and VRAM queries for that
//! device route through a `std::sync::mpsc` channel into this thread.
//! Tensors never escape the thread alive — replies cross the channel
//! as plain values (`u32` tokens, `(u64, u64)` mb numbers, `()`).
//!
//! Rationale, in order of weight:
//!
//! 1. **Context locality.** cudarc binds the CUDA context per OS thread
//! via `cuCtxSetCurrent`. With `tokio::task::spawn_blocking`, the
//! blocking thread chosen is arbitrary, so the context gets bound
//! onto a different thread each time and `device_vram_mb()` from an
//! async task binds it again on the *caller's* thread as a side
//! effect. Pinning the context to one named thread ends that.
//!
//! 2. **Drop safety.** `cudarc::driver::CudaContext`, every `CudaSlice`
//! inside a `Tensor`, and every `cudarc::nccl::Comm` call `cuMemFree`
//! / `cuCtxDestroy` / `ncclCommDestroy` during `Drop`. These must
//! run with the right context current. Owning everything in this
//! thread's state slab and dropping it via `Job::DropArch` /
//! `Job::Shutdown` is the only safe pattern.
//!
//! 3. **Poisoning blast radius.** When a CUDA driver error (illegal
//! address, OOM cascade) makes the context unrecoverable, today the
//! spawn_blocking thread carrying that bad state simply returns to
//! tokio's pool — invisible. With the per-device thread, the
//! poisoned flag lives on the thread itself; subsequent
//! `submit()` calls fast-reject at the channel boundary with a
//! clear "device worker is poisoned" error before any further CUDA
//! work is attempted.
//!
//! The TP worker subprocesses (`harness/tp/worker.rs`) are already this
//! pattern, just out-of-process. The in-process variant uses the same
//! discipline for rank 0.
//!
//! Phase 1 of the refactor exposes only `Job::QueryVram` + `Job::Shutdown`.
//! Forward, kv-cache clear, model load, and NCCL bring-up move in later
//! phases. See `/home/grenade/.claude/plans/plan-the-per-device-worker-abstract-micali.md`.
pub mod dispatch;
pub mod jobs;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::{self, Sender};
use std::thread::JoinHandle;
use tokio::sync::oneshot;
#[cfg(feature = "cuda")]
pub use jobs::TpHandle;
pub use jobs::{ArchHandle, Job};
/// Errors returned by `DeviceWorkerHandle` submit methods.
#[derive(Debug, thiserror::Error)]
pub enum WorkerError {
/// The worker's CUDA context was poisoned by an earlier driver
/// error. The thread is still alive (dropping it would re-touch
/// the broken context); it returns this error for every job
/// submitted until the daemon is restarted.
#[error(
"device worker for device {device_index} is poisoned \
(a prior CUDA driver error left the context unrecoverable); \
restart the daemon to recover"
)]
Poisoned { device_index: u32 },
/// The worker thread has exited (`Job::Shutdown` was processed or
/// the thread panicked). Subsequent `submit()` calls fail here
/// rather than blocking forever.
#[error("device worker for device {device_index} is no longer running")]
Gone { device_index: u32 },
/// The dispatched job returned an `Err`. Forwarded verbatim.
#[error(transparent)]
Job(#[from] anyhow::Error),
}
/// Shared handle to a per-device CUDA worker thread.
///
/// Cloning the `Arc` lets multiple `LoadedModel`s (and `TpLoadedModel`s)
/// share the same worker — there's one worker per CUDA device index,
/// not one per model.
pub struct DeviceWorkerHandle {
device_index: u32,
tx: Sender<Job>,
poisoned: Arc<AtomicBool>,
/// `Mutex<Option<JoinHandle>>` so `shutdown()` can take the handle
/// out without `&mut self` and so the inevitable `Drop` after
/// `shutdown()` doesn't double-join. The mutex is uncontended in
/// practice: only one caller ever takes the handle.
join: std::sync::Mutex<Option<JoinHandle<()>>>,
}
impl DeviceWorkerHandle {
/// Spawn a new worker for the given CUDA device index.
///
/// The thread is named `cuda-dev-N` so it shows up legibly in
/// `top -H`, `pidstat -t`, and gdb backtraces. On CUDA builds, the
/// thread binds `CudaContext::new(N)` on startup; on CPU builds
/// (`--no-default-features`) the thread runs without a context and
/// every job that touches CUDA falls through to a zero return.
pub fn spawn(device_index: u32) -> anyhow::Result<Arc<Self>> {
let (tx, rx) = mpsc::channel::<Job>();
let poisoned = Arc::new(AtomicBool::new(false));
let poisoned_for_thread = Arc::clone(&poisoned);
let join = std::thread::Builder::new()
.name(format!("cuda-dev-{device_index}"))
.spawn(move || {
dispatch::run(device_index, rx, poisoned_for_thread);
})?;
Ok(Arc::new(Self {
device_index,
tx,
poisoned,
join: std::sync::Mutex::new(Some(join)),
}))
}
pub fn device_index(&self) -> u32 {
self.device_index
}
pub fn is_poisoned(&self) -> bool {
self.poisoned.load(Ordering::Acquire)
}
/// Mark the worker's context as poisoned. Future `submit()` calls
/// short-circuit to `WorkerError::Poisoned` before sending. The
/// dispatch loop also flips into drain-only mode when it sees this
/// flag, so any jobs already in flight on the channel reply with
/// the same error without touching CUDA.
#[allow(dead_code)]
pub(crate) fn set_poisoned(&self) {
self.poisoned.store(true, Ordering::Release);
}
/// Send `Job::QueryVram`, await the worker's reply.
///
/// Returns `Ok((free_mb, total_mb))` on success, `Ok((0, 0))` on
/// CPU builds or when the device lacks a bound context, or an
/// error if the worker is poisoned, gone, or the query itself
/// failed inside cudarc.
pub async fn query_vram(&self) -> Result<(u64, u64), WorkerError> {
if self.poisoned.load(Ordering::Acquire) {
return Err(WorkerError::Poisoned {
device_index: self.device_index,
});
}
let (reply_tx, reply_rx) = oneshot::channel();
self.tx
.send(Job::QueryVram { reply: reply_tx })
.map_err(|_| WorkerError::Gone {
device_index: self.device_index,
})?;
match reply_rx.await {
Ok(result) => result.map_err(WorkerError::from),
Err(_) => Err(WorkerError::Gone {
device_index: self.device_index,
}),
}
}
/// Load a GGUF (pre-quantized) single-GPU model on the worker
/// thread. The hf-hub resolution happens on the async caller; the
/// resolved local `gguf_path` plus the spec's model_id are sent
/// into the worker which opens, parses, and constructs the
/// `ModelArch` on the right thread.
pub async fn load_gguf(
&self,
gguf_path: std::path::PathBuf,
model_id: String,
) -> Result<ArchHandle, WorkerError> {
if self.poisoned.load(Ordering::Acquire) {
return Err(WorkerError::Poisoned {
device_index: self.device_index,
});
}
let (reply_tx, reply_rx) = oneshot::channel();
self.tx
.send(Job::LoadGguf {
gguf_path,
model_id,
reply: reply_tx,
})
.map_err(|_| WorkerError::Gone {
device_index: self.device_index,
})?;
match reply_rx.await {
Ok(result) => result.map_err(WorkerError::from),
Err(_) => Err(WorkerError::Gone {
device_index: self.device_index,
}),
}
}
/// Load a dense safetensors single-GPU model on the worker thread.
pub async fn load_dense(
&self,
config_path: std::path::PathBuf,
safetensors_paths: Vec<std::path::PathBuf>,
model_id: String,
) -> Result<ArchHandle, WorkerError> {
if self.poisoned.load(Ordering::Acquire) {
return Err(WorkerError::Poisoned {
device_index: self.device_index,
});
}
let (reply_tx, reply_rx) = oneshot::channel();
self.tx
.send(Job::LoadDense {
config_path,
safetensors_paths,
model_id,
reply: reply_tx,
})
.map_err(|_| WorkerError::Gone {
device_index: self.device_index,
})?;
match reply_rx.await {
Ok(result) => result.map_err(WorkerError::from),
Err(_) => Err(WorkerError::Gone {
device_index: self.device_index,
}),
}
}
/// Tell the worker to drop the `ModelArch` for `handle` on the
/// worker thread (so CUDA tensors release on the right context).
/// Returns `Ok(())` even if the handle wasn't in the slab — Drop
/// is idempotent. Reports `Gone` if the worker isn't running.
pub async fn drop_arch(&self, handle: ArchHandle) -> Result<(), WorkerError> {
// Poisoning doesn't block DropArch — even on a poisoned
// context we want callers to unblock and proceed with the
// unload bookkeeping. The dispatch handler under poison just
// replies `()` without touching the model (the actual Drop
// happens via mem::forget at thread exit per the poison
// protocol).
let (reply_tx, reply_rx) = oneshot::channel();
self.tx
.send(Job::DropArch {
handle,
reply: reply_tx,
})
.map_err(|_| WorkerError::Gone {
device_index: self.device_index,
})?;
match reply_rx.await {
Ok(()) => Ok(()),
Err(_) => Err(WorkerError::Gone {
device_index: self.device_index,
}),
}
}
/// Reset the KV cache for the model at `handle`. Called at the
/// start of every chat completion so the new prompt doesn't
/// attend over the previous request's tokens.
pub async fn clear_kv_cache(&self, handle: ArchHandle) -> Result<(), WorkerError> {
if self.poisoned.load(Ordering::Acquire) {
return Err(WorkerError::Poisoned {
device_index: self.device_index,
});
}
let (reply_tx, reply_rx) = oneshot::channel();
self.tx
.send(Job::ClearKv {
handle,
reply: reply_tx,
})
.map_err(|_| WorkerError::Gone {
device_index: self.device_index,
})?;
match reply_rx.await {
Ok(result) => result.map_err(WorkerError::from),
Err(_) => Err(WorkerError::Gone {
device_index: self.device_index,
}),
}
}
/// Run one forward step and return the resulting `[vocab]` logits
/// as a CPU-side `Vec<f32>`. The caller then samples on a CPU
/// candle Tensor without ever binding the device context on its
/// tokio thread.
pub async fn forward_logits(
&self,
handle: ArchHandle,
tokens: Vec<u32>,
offset: usize,
) -> Result<Vec<f32>, WorkerError> {
if self.poisoned.load(Ordering::Acquire) {
return Err(WorkerError::Poisoned {
device_index: self.device_index,
});
}
let (reply_tx, reply_rx) = oneshot::channel();
self.tx
.send(Job::ForwardLogits {
handle,
tokens,
offset,
reply: reply_tx,
})
.map_err(|_| WorkerError::Gone {
device_index: self.device_index,
})?;
match reply_rx.await {
Ok(result) => result.map_err(WorkerError::from),
Err(_) => Err(WorkerError::Gone {
device_index: self.device_index,
}),
}
}
/// Initialise the leader's NCCL communicator. The reply uses
/// `WorkerResponse` (same shape subprocess workers use over stdio
/// RPC) so `WorkerPool::init_nccl`'s aggregation treats leader +
/// subprocess responses uniformly. Available on no-cuda builds
/// too — the dispatch handler calls the no-cuda `NcclState::init`
/// stub which replies `cuda_feature_not_enabled`.
pub async fn nccl_init(
&self,
cfg: crate::harness::tp::worker::WorkerConfig,
comm_id_hex: String,
) -> Result<crate::harness::tp::rpc::WorkerResponse, WorkerError> {
if self.poisoned.load(Ordering::Acquire) {
return Err(WorkerError::Poisoned {
device_index: self.device_index,
});
}
let (reply_tx, reply_rx) = oneshot::channel();
self.tx
.send(Job::NcclInit {
cfg,
comm_id_hex,
reply: reply_tx,
})
.map_err(|_| WorkerError::Gone {
device_index: self.device_index,
})?;
reply_rx.await.map_err(|_| WorkerError::Gone {
device_index: self.device_index,
})
}
/// Run an NCCL sanity all_reduce on the leader's rank 0.
/// Available on no-cuda builds; replies with an error response.
pub async fn nccl_sanity(
&self,
) -> Result<crate::harness::tp::rpc::WorkerResponse, WorkerError> {
if self.poisoned.load(Ordering::Acquire) {
return Err(WorkerError::Poisoned {
device_index: self.device_index,
});
}
let (reply_tx, reply_rx) = oneshot::channel();
self.tx
.send(Job::NcclSanity { reply: reply_tx })
.map_err(|_| WorkerError::Gone {
device_index: self.device_index,
})?;
reply_rx.await.map_err(|_| WorkerError::Gone {
device_index: self.device_index,
})
}
/// Load the leader's TP shard on the worker thread. The dispatch
/// handler reads its own `NcclState`'s `Arc<Comm>` directly — no
/// cross-thread Comm transfer — and builds the `TpLeaderModel`
/// against it. Phase 4 replaces the Phase 3 Clone/TransferIn
/// bridge with this single Job.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub async fn tp_load_shard(
&self,
model_id: String,
config_json: String,
safetensors_paths: Vec<std::path::PathBuf>,
dtype: candle_core::DType,
quant: Option<String>,
world_size: u32,
) -> Result<TpHandle, WorkerError> {
if self.poisoned.load(Ordering::Acquire) {
return Err(WorkerError::Poisoned {
device_index: self.device_index,
});
}
let (reply_tx, reply_rx) = oneshot::channel();
self.tx
.send(Job::TpLoadShard {
model_id,
config_json,
safetensors_paths,
dtype,
quant,
world_size,
reply: reply_tx,
})
.map_err(|_| WorkerError::Gone {
device_index: self.device_index,
})?;
match reply_rx.await {
Ok(result) => result.map_err(WorkerError::from),
Err(_) => Err(WorkerError::Gone {
device_index: self.device_index,
}),
}
}
/// Drop the TP model at `handle` on the worker thread.
#[cfg(feature = "cuda")]
pub async fn drop_tp(&self, handle: TpHandle) -> Result<(), WorkerError> {
let (reply_tx, reply_rx) = oneshot::channel();
self.tx
.send(Job::DropTp {
handle,
reply: reply_tx,
})
.map_err(|_| WorkerError::Gone {
device_index: self.device_index,
})?;
match reply_rx.await {
Ok(()) => Ok(()),
Err(_) => Err(WorkerError::Gone {
device_index: self.device_index,
}),
}
}
/// Reset the leader's KV cache for a TP model.
#[cfg(feature = "cuda")]
pub async fn tp_clear_kv(&self, handle: TpHandle) -> Result<(), WorkerError> {
if self.poisoned.load(Ordering::Acquire) {
return Err(WorkerError::Poisoned {
device_index: self.device_index,
});
}
let (reply_tx, reply_rx) = oneshot::channel();
self.tx
.send(Job::TpClearKv {
handle,
reply: reply_tx,
})
.map_err(|_| WorkerError::Gone {
device_index: self.device_index,
})?;
match reply_rx.await {
Ok(result) => result.map_err(WorkerError::from),
Err(_) => Err(WorkerError::Gone {
device_index: self.device_index,
}),
}
}
/// Run one TP forward step on the leader's shard. Returns CPU-side
/// logits as `Vec<f32>` ready for sampling. The caller is
/// responsible for fan-out / drain of the subprocess workers
/// concurrently with this call.
#[cfg(feature = "cuda")]
pub async fn tp_forward_logits(
&self,
handle: TpHandle,
tokens: Vec<u32>,
offset: usize,
) -> Result<Vec<f32>, WorkerError> {
if self.poisoned.load(Ordering::Acquire) {
return Err(WorkerError::Poisoned {
device_index: self.device_index,
});
}
let (reply_tx, reply_rx) = oneshot::channel();
self.tx
.send(Job::TpForwardLogits {
handle,
tokens,
offset,
reply: reply_tx,
})
.map_err(|_| WorkerError::Gone {
device_index: self.device_index,
})?;
match reply_rx.await {
Ok(result) => result.map_err(WorkerError::from),
Err(_) => Err(WorkerError::Gone {
device_index: self.device_index,
}),
}
}
/// Send `Job::Shutdown` and join the thread. Idempotent — calling
/// twice is a no-op the second time.
pub fn shutdown(&self) -> anyhow::Result<()> {
// Best-effort send: if the channel is already closed (thread
// exited after a prior shutdown or panic) the send fails and
// we fall through to the join which returns the panic, if any.
let _ = self.tx.send(Job::Shutdown);
let join = self.join.lock().unwrap().take();
if let Some(j) = join {
j.join()
.map_err(|_| anyhow::anyhow!("worker thread panicked during shutdown"))?;
}
Ok(())
}
}
impl Drop for DeviceWorkerHandle {
fn drop(&mut self) {
// Best-effort: send Shutdown so the thread breaks its loop
// and exits. We do NOT join here — Drop may run on a tokio
// worker thread, and joining a thread that's still processing
// the last job would block the runtime. The OS reaps the
// thread on detach.
let _ = self.tx.send(Job::Shutdown);
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
#[tokio::test]
async fn spawn_query_vram_shutdown() {
let handle = DeviceWorkerHandle::spawn(0).expect("spawn ok");
// CPU build (the only one CI runs) returns (0, 0) by design;
// a CUDA build with a real device would return real values.
let result = handle.query_vram().await.expect("query ok");
// We assert >= 0 — the field width matters more than the value.
let _ = result.0;
let _ = result.1;
handle.shutdown().expect("shutdown ok");
}
#[tokio::test]
async fn thread_is_named_correctly() {
// The thread name lets `top -H` / pidstat / gdb show
// `cuda-dev-N` instead of an opaque tokio worker name. Verify
// by spawning and reading proc-self thread comms — but on
// platforms without /proc, just confirm we don't crash.
let handle = DeviceWorkerHandle::spawn(7).expect("spawn ok");
// Round-trip a job to ensure the thread is alive and processing.
handle.query_vram().await.expect("query ok");
handle.shutdown().expect("shutdown ok");
}
#[tokio::test]
async fn submit_after_shutdown_returns_gone() {
let handle = DeviceWorkerHandle::spawn(0).expect("spawn ok");
handle.shutdown().expect("shutdown ok");
// Channel closed; submit should map to Gone rather than block.
let result = handle.query_vram().await;
match result {
Err(WorkerError::Gone { device_index: 0 }) => {}
other => panic!("expected Gone, got {other:?}"),
}
}
#[tokio::test]
async fn poisoned_flag_short_circuits_submit() {
let handle = DeviceWorkerHandle::spawn(0).expect("spawn ok");
handle.set_poisoned();
let result = handle.query_vram().await;
match result {
Err(WorkerError::Poisoned { device_index: 0 }) => {}
other => panic!("expected Poisoned, got {other:?}"),
}
// The channel is still alive; shutdown should still succeed.
handle.shutdown().expect("shutdown ok");
}
#[tokio::test]
async fn shutdown_drains_pending_jobs() {
let handle = DeviceWorkerHandle::spawn(0).expect("spawn ok");
// Submit many concurrent jobs; they should all complete even
// though a Shutdown is racing them.
let mut futures = Vec::new();
for _ in 0..16 {
let h = Arc::clone(&handle);
futures.push(tokio::spawn(async move { h.query_vram().await }));
}
// Small yield to give the senders a chance to actually send
// before we issue the shutdown; not strictly necessary because
// the channel is FIFO, but makes the test's intent clearer.
tokio::time::sleep(Duration::from_millis(10)).await;
handle.shutdown().expect("shutdown ok");
for f in futures {
// Each query should have completed (Ok or Gone, never panic).
let _ = f.await.expect("task did not panic");
}
}
}

View File

@@ -0,0 +1 @@
// llama.cpp harness implementation — Phase 11.

View File

@@ -0,0 +1,163 @@
//! mistral.rs harness implementation.
//!
//! Wraps the mistral.rs HTTP API for model lifecycle management
//! and optionally manages the process via systemd.
use anyhow::Result;
use async_trait::async_trait;
use cortex_core::harness::{Harness, HarnessConfig, HarnessHealth, ModelInfo, ModelSpec};
use reqwest::Client;
use serde::Deserialize;
pub struct MistralRsHarness {
endpoint: String,
systemd_unit: Option<String>,
client: Client,
}
impl MistralRsHarness {
pub fn new(endpoint: String, systemd_unit: Option<String>) -> Self {
Self {
endpoint,
systemd_unit,
client: Client::builder()
.timeout(std::time::Duration::from_secs(30))
.build()
.expect("failed to build HTTP client"),
}
}
}
/// Response from mistral.rs `GET /v1/models`.
#[derive(Debug, Deserialize)]
struct ModelsResponse {
data: Vec<ModelEntry>,
}
#[derive(Debug, Deserialize)]
struct ModelEntry {
id: String,
#[serde(default)]
status: Option<String>,
}
#[async_trait]
impl Harness for MistralRsHarness {
fn name(&self) -> &str {
"mistralrs"
}
async fn start(&self, _config: &HarnessConfig) -> Result<()> {
let Some(unit) = &self.systemd_unit else {
anyhow::bail!("no systemd unit configured for mistralrs harness");
};
let output = tokio::process::Command::new("systemctl")
.args(["start", unit])
.output()
.await?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
anyhow::bail!("systemctl start {unit} failed: {stderr}");
}
// Wait for the health endpoint to respond (up to 30s).
let url = format!("{}/health", self.endpoint);
for _ in 0..30 {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
if self.client.get(&url).send().await.is_ok() {
tracing::info!(unit, "mistralrs started and healthy");
return Ok(());
}
}
anyhow::bail!("mistralrs started but health endpoint did not respond within 30s");
}
async fn stop(&self) -> Result<()> {
let Some(unit) = &self.systemd_unit else {
anyhow::bail!("no systemd unit configured for mistralrs harness");
};
let output = tokio::process::Command::new("systemctl")
.args(["stop", unit])
.output()
.await?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
anyhow::bail!("systemctl stop {unit} failed: {stderr}");
}
Ok(())
}
async fn health(&self) -> HarnessHealth {
let url = format!("{}/health", self.endpoint);
let running = self.client.get(&url).send().await.is_ok();
HarnessHealth {
name: "mistralrs".into(),
running,
uptime_secs: None,
}
}
async fn list_models(&self) -> Result<Vec<ModelInfo>> {
let url = format!("{}/v1/models", self.endpoint);
let resp = self.client.get(&url).send().await?;
if !resp.status().is_success() {
anyhow::bail!("GET /v1/models returned {}", resp.status());
}
let models_resp: ModelsResponse = resp.json().await?;
Ok(models_resp
.data
.into_iter()
.map(|m| ModelInfo {
id: m.id,
harness: "mistralrs".into(),
status: m.status.unwrap_or_else(|| "loaded".into()),
devices: vec![],
vram_used_mb: None,
})
.collect())
}
async fn load_model(&self, spec: &ModelSpec) -> Result<()> {
let url = format!("{}/v1/models/reload", self.endpoint);
let resp = self
.client
.post(&url)
.json(&serde_json::json!({ "model_id": spec.model_id }))
.send()
.await?;
if !resp.status().is_success() {
let body = resp.text().await.unwrap_or_default();
anyhow::bail!("POST /v1/models/reload failed: {body}");
}
Ok(())
}
async fn unload_model(&self, model_id: &str) -> Result<()> {
let url = format!("{}/v1/models/unload", self.endpoint);
let resp = self
.client
.post(&url)
.json(&serde_json::json!({ "model_id": model_id }))
.send()
.await?;
if !resp.status().is_success() {
let body = resp.text().await.unwrap_or_default();
anyhow::bail!("POST /v1/models/unload failed: {body}");
}
Ok(())
}
async fn inference_endpoint(&self, _model_id: &str) -> Option<String> {
// mistral.rs routes internally by model name in the request body,
// so the inference endpoint is always the base URL.
Some(self.endpoint.clone())
}
}

View File

@@ -1,27 +1,15 @@
//! Harness registry — maps harness names to trait implementations.
pub mod arch;
pub mod candle;
pub mod chat_template;
pub mod device_worker;
pub mod preflight;
pub mod tp;
pub mod llamacpp;
pub mod mistralrs;
use anyhow::Result;
use cortex_core::harness::{Harness, HarnessConfig, ModelInfo, ModelSpec};
use std::collections::HashMap;
use std::sync::Arc;
/// Registry of available harness implementations.
///
/// Holds an `Arc<dyn Harness>` per harness for generic lifecycle dispatch
/// (load/unload/list_models). When a candle harness is registered, a typed
/// `Arc<CandleHarness>` is also cached so inference routes can bypass the
/// dyn-Trait dispatch and reach harness-specific methods (chat completion,
/// streaming, etc.).
pub struct HarnessRegistry {
harnesses: HashMap<String, Arc<dyn Harness>>,
candle: Option<Arc<candle::CandleHarness>>,
harnesses: HashMap<String, Box<dyn Harness>>,
}
impl Default for HarnessRegistry {
@@ -34,11 +22,10 @@ impl HarnessRegistry {
pub fn new() -> Self {
Self {
harnesses: HashMap::new(),
candle: None,
}
}
pub fn register(&mut self, harness: Arc<dyn Harness>) {
pub fn register(&mut self, harness: Box<dyn Harness>) {
self.harnesses.insert(harness.name().to_string(), harness);
}
@@ -47,12 +34,6 @@ impl HarnessRegistry {
self.harnesses.keys().cloned().collect()
}
/// Typed handle to the candle harness, if registered. Used by inference
/// routes that need methods beyond the `Harness` trait surface.
pub fn candle(&self) -> Option<Arc<candle::CandleHarness>> {
self.candle.clone()
}
/// List models from all registered harnesses.
pub async fn list_all_models(&self) -> Result<Vec<ModelInfo>> {
let mut all = Vec::new();
@@ -100,25 +81,19 @@ impl HarnessRegistry {
}
/// Build a registry from harness configs.
///
/// `bind_url` is the URL where this neuron serves inference (its own
/// listen address). In-process harnesses (currently the only kind)
/// return this URL from `inference_endpoint`.
pub fn from_configs(
configs: &[HarnessConfig],
bind_url: &str,
settings: &crate::config::HarnessSettings,
) -> Self {
pub fn from_configs(configs: &[HarnessConfig]) -> Self {
let mut registry = Self::new();
for config in configs {
match config.name.as_str() {
"candle" => {
let harness = Arc::new(candle::CandleHarness::new(
bind_url.to_string(),
settings.candle.hf_cache.clone(),
));
registry.candle = Some(Arc::clone(&harness));
registry.harnesses.insert("candle".into(), harness);
"mistralrs" => {
if let Some(endpoint) = &config.endpoint {
registry.register(Box::new(mistralrs::MistralRsHarness::new(
endpoint.clone(),
config.systemd_unit.clone(),
)));
} else {
tracing::warn!("mistralrs harness missing endpoint, skipping");
}
}
other => {
tracing::warn!(harness = other, "unknown harness type, skipping");

View File

@@ -1,575 +0,0 @@
//! Placement feasibility check that runs before any device allocation,
//! NCCL handshake, or weight download.
//!
//! The loader path in `candle.rs` historically discovers an
//! incompatibility *after* it has already started fetching files —
//! "fetch config.json from HauhauCS/...: 404 Not Found" surfaces hours
//! after operators set `tensor_parallel = 2` on a GGUF-only repo, with
//! no hint about what's actually wrong. Preflight closes that gap:
//!
//! 1. one `repo.info()` round-trip (siblings listing, no blob fetch)
//! 2. classify the repo: GGUF-only, dense safetensors, mixed, empty
//! 3. apply the feasibility table against the requested
//! `ModelSpec` (tp_size, quant)
//! 4. return a structured `PreflightError` the API layer can map to
//! 422 + JSON, or `Ok(PlacementPlan)` carrying the decisions the
//! downstream load path needs (which GGUF file to fetch, etc.).
//!
//! Phase 2 of plan-source-aware-loader-preflight. The Phase 1 scheme
//! work — `ModelSourceId` and per-scheme `SourceConfig` — is a
//! separate PR; preflight runs against the single configured
//! HuggingFace source for now and the scheme threading drops in
//! cleanly when Phase 1 lands.
use cortex_core::harness::ModelSpec;
use hf_hub::api::tokio::Api;
use serde::Serialize;
/// What the repo's siblings listing tells us about how to load it.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum SourceFormat {
/// Only GGUF files present. Single-GPU load path. `quants` is the
/// lowercased filename list so the operator can be told what's
/// actually available when their `quant=` choice doesn't match.
Gguf { quants: Vec<String> },
/// Dense safetensors (single-file or sharded via index.json).
/// Goes through `load_arch_dense` on single-GPU, or `load_tp` (with
/// optional in-situ quantization) when `tensor_parallel > 1`.
DenseSafetensors { sharded: bool },
/// Both safetensors and GGUF present — prefer the dense path
/// because it composes with TP and ISQ. We surface the GGUF
/// filenames anyway so operators with a strong preference can
/// see they exist.
Mixed { gguf_quants: Vec<String> },
/// No recognised weight files. Either a tokenizer-only repo
/// (e.g. some base-model repos that only host `tokenizer.json` and
/// expect the operator to use a `-GGUF` sibling repo) or a
/// genuinely empty entry.
Empty,
}
/// Output of `preflight` for a load that can proceed. Carries the
/// decisions downstream resolve_* paths would otherwise re-derive.
#[derive(Debug, Clone, Serialize)]
pub struct PlacementPlan {
pub model_id: String,
pub format: SourceFormat,
pub tp_size: u32,
/// Filename of the GGUF to fetch, populated when `format` is
/// `Gguf` and a single-GPU load was requested. None for the
/// dense/TP path.
pub picked_quant_file: Option<String>,
}
/// Structured failure modes. Each variant carries the fields the API
/// layer needs to produce an actionable 422 body.
#[derive(Debug, Clone, Serialize, thiserror::Error)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum PreflightError {
/// `repo.info()` failed. Captures the underlying cause as a string
/// so the operator log shows whether it's auth, 404, or transport.
#[error("failed to fetch repo info for '{model_id}': {cause}")]
RepoFetchFailed { model_id: String, cause: String },
/// The repo exists but has no recognised weight files.
#[error(
"repo '{model_id}' has no recognised weight files (no .gguf, no .safetensors); \
a tokenizer-only repo cannot be loaded directly"
)]
EmptyRepo { model_id: String },
/// Operator asked for `tensor_parallel > 1` on a GGUF-only repo.
/// The TP path requires safetensors+config for in-situ
/// quantization; GGUF-TP isn't implemented (see CLAUDE.md).
#[error(
"cannot load '{model_id}' with tensor_parallel={tp_size}: repo is GGUF-only \
({} .gguf files); TP requires dense safetensors. {suggestion}",
gguf_quants.len()
)]
TpRequiresSafetensors {
model_id: String,
tp_size: u32,
gguf_quants: Vec<String>,
suggestion: String,
},
/// Operator asked for a GGUF quant whose substring doesn't match
/// any filename in the repo. `nearest` is a best-effort Levenshtein
/// suggestion against the available quant names.
#[error(
"no GGUF file in '{model_id}' matches quant '{requested}'; \
available: {available:?}{}",
nearest.as_ref().map(|n| format!("; did you mean '{n}'?")).unwrap_or_default()
)]
QuantNotFound {
model_id: String,
requested: String,
available: Vec<String>,
nearest: Option<String>,
},
}
/// Run the placement check.
///
/// One network round-trip (`repo.info()`); no blob fetches. Returns
/// `Ok(PlacementPlan)` when the requested combination is feasible, or
/// a structured `PreflightError` describing what's wrong.
pub async fn preflight(api: &Api, spec: &ModelSpec) -> Result<PlacementPlan, PreflightError> {
let repo = api.model(spec.model_id.clone());
let info = repo
.info()
.await
.map_err(|e| PreflightError::RepoFetchFailed {
model_id: spec.model_id.clone(),
cause: format!("{e}"),
})?;
let filenames: Vec<&str> = info.siblings.iter().map(|s| s.rfilename.as_str()).collect();
let format = classify(&filenames);
let tp_size = spec.tensor_parallel.unwrap_or(1);
match (&format, tp_size, spec.quant.as_deref()) {
// No weights at all — nothing to do.
(SourceFormat::Empty, _, _) => Err(PreflightError::EmptyRepo {
model_id: spec.model_id.clone(),
}),
// GGUF-only + TP: not supported. Today's HauhauCS failure.
(SourceFormat::Gguf { quants }, tp, _) if tp > 1 => {
Err(PreflightError::TpRequiresSafetensors {
model_id: spec.model_id.clone(),
tp_size: tp,
gguf_quants: quants.clone(),
suggestion: format!(
"Set tensor_parallel=1 and pick a quant from {quants:?}, \
or use a dense safetensors release of this model."
),
})
}
// GGUF-only + single-GPU: pick the file that matches the
// operator's quant. Empty quant matches the first GGUF.
(SourceFormat::Gguf { quants }, _, requested) => {
let picked = pick_gguf_file(&filenames, requested.unwrap_or(""));
match picked {
Some(fname) => Ok(PlacementPlan {
model_id: spec.model_id.clone(),
format: format.clone(),
tp_size,
picked_quant_file: Some(fname),
}),
None => Err(PreflightError::QuantNotFound {
model_id: spec.model_id.clone(),
requested: requested.unwrap_or("").to_string(),
available: quants.clone(),
nearest: nearest_quant(requested.unwrap_or(""), quants),
}),
}
}
// Dense or mixed: dense path handles both single-GPU and TP.
// The architecture compatibility check stays where it is —
// `check_dense_config_supported` runs once `config.json` is
// on disk, since it needs the parsed JSON.
(SourceFormat::DenseSafetensors { .. } | SourceFormat::Mixed { .. }, _, _) => {
Ok(PlacementPlan {
model_id: spec.model_id.clone(),
format: format.clone(),
tp_size,
picked_quant_file: None,
})
}
}
}
/// Classify a siblings file list into a `SourceFormat`. Pulled out so
/// the unit tests can exercise it against fixture JSON without
/// spinning up an Api.
pub fn classify(filenames: &[&str]) -> SourceFormat {
let mut gguf_quants: Vec<String> = filenames
.iter()
.filter(|f| f.to_lowercase().ends_with(".gguf"))
.map(|f| f.to_lowercase())
.collect();
gguf_quants.sort();
gguf_quants.dedup();
let has_safetensors = filenames.iter().any(|f| f.ends_with(".safetensors"));
let sharded = filenames
.iter()
.any(|f| f.ends_with("model.safetensors.index.json"));
match (has_safetensors, gguf_quants.is_empty()) {
(true, true) => SourceFormat::DenseSafetensors { sharded },
(true, false) => SourceFormat::Mixed { gguf_quants },
(false, false) => SourceFormat::Gguf {
quants: gguf_quants,
},
(false, true) => SourceFormat::Empty,
}
}
/// Mirror of the quant-matching logic in `candle.rs::resolve_files` so
/// preflight picks the same file the downstream loader would. Empty
/// quant returns the first `.gguf` (any quant). Lowercased substring
/// match otherwise.
fn pick_gguf_file(filenames: &[&str], quant_lc: &str) -> Option<String> {
filenames
.iter()
.filter(|f| f.to_lowercase().ends_with(".gguf"))
.find(|f| quant_lc.is_empty() || f.to_lowercase().contains(quant_lc))
.map(|f| f.to_string())
}
/// Best-effort suggestion when the operator's quant name doesn't
/// substring-match any filename. Extracts the quant-ish token from
/// each `.gguf` filename and picks the one with the smallest
/// Levenshtein distance to the requested string. Returns None when
/// the input is empty or no candidates exist.
fn nearest_quant(requested: &str, candidates: &[String]) -> Option<String> {
if requested.is_empty() || candidates.is_empty() {
return None;
}
// Pull the "Q6_K_P"/"IQ4_XS"-ish token out of each filename for a
// fairer comparison. Filenames look like
// `Qwen3.6-27B-Uncensored-HauhauCS-Aggressive-Q6_K_P.gguf`, so the
// quant is the last `-`-separated segment before the extension,
// lowercased.
let tokens: Vec<(String, String)> = candidates
.iter()
.map(|f| (extract_quant_token(f), f.clone()))
.collect();
let req_lc = requested.to_lowercase();
tokens
.into_iter()
.min_by_key(|(token, _)| levenshtein(&req_lc, token))
.map(|(token, _)| token)
}
fn extract_quant_token(filename: &str) -> String {
let stem = filename
.rsplit_once('.')
.map(|(s, _)| s)
.unwrap_or(filename);
let token = stem.rsplit('-').next().unwrap_or(stem);
token.to_lowercase()
}
/// Iterative Levenshtein. Small inputs (quant names are <=12 chars),
/// no need for the `levenshtein` crate.
fn levenshtein(a: &str, b: &str) -> usize {
let a: Vec<char> = a.chars().collect();
let b: Vec<char> = b.chars().collect();
let (m, n) = (a.len(), b.len());
if m == 0 {
return n;
}
if n == 0 {
return m;
}
let mut prev: Vec<usize> = (0..=n).collect();
let mut curr = vec![0usize; n + 1];
for i in 1..=m {
curr[0] = i;
for j in 1..=n {
let cost = if a[i - 1] == b[j - 1] { 0 } else { 1 };
curr[j] = (prev[j] + 1).min(curr[j - 1] + 1).min(prev[j - 1] + cost);
}
std::mem::swap(&mut prev, &mut curr);
}
prev[n]
}
#[cfg(test)]
mod tests {
use super::*;
fn spec(model_id: &str, tp: Option<u32>, quant: Option<&str>) -> ModelSpec {
ModelSpec {
model_id: model_id.into(),
harness: "candle".into(),
quant: quant.map(String::from),
tensor_parallel: tp,
devices: None,
}
}
#[test]
fn classify_gguf_only() {
let files = [
"README.md",
".gitattributes",
"Qwen3.6-27B-Q6_K_P.gguf",
"Qwen3.6-27B-Q4_K_P.gguf",
];
match classify(&files) {
SourceFormat::Gguf { quants } => {
assert_eq!(quants.len(), 2);
assert!(quants.iter().any(|q| q.contains("q6_k_p")));
}
other => panic!("expected Gguf, got {other:?}"),
}
}
#[test]
fn classify_dense_sharded() {
let files = [
"config.json",
"tokenizer.json",
"model.safetensors.index.json",
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
];
assert_eq!(
classify(&files),
SourceFormat::DenseSafetensors { sharded: true }
);
}
#[test]
fn classify_dense_single_file() {
let files = ["config.json", "tokenizer.json", "model.safetensors"];
assert_eq!(
classify(&files),
SourceFormat::DenseSafetensors { sharded: false }
);
}
#[test]
fn classify_mixed() {
let files = [
"config.json",
"tokenizer.json",
"model.safetensors",
"model-Q4_K_M.gguf",
];
match classify(&files) {
SourceFormat::Mixed { gguf_quants } => {
assert_eq!(gguf_quants, vec!["model-q4_k_m.gguf"]);
}
other => panic!("expected Mixed, got {other:?}"),
}
}
#[test]
fn classify_empty() {
let files = ["README.md", "tokenizer.json"];
assert_eq!(classify(&files), SourceFormat::Empty);
}
#[test]
fn pick_gguf_substring_match() {
let files = ["model-Q4_K_M.gguf", "model-Q6_K.gguf", "model-Q8_0.gguf"];
assert_eq!(
pick_gguf_file(&files, "q6_k"),
Some("model-Q6_K.gguf".into())
);
}
#[test]
fn pick_gguf_empty_returns_first() {
let files = ["model-Q4_K_M.gguf", "model-Q6_K.gguf"];
assert_eq!(pick_gguf_file(&files, ""), Some("model-Q4_K_M.gguf".into()));
}
#[test]
fn pick_gguf_no_match() {
let files = ["model-Q4_K_M.gguf", "model-Q6_K.gguf"];
assert_eq!(pick_gguf_file(&files, "iq2_xs"), None);
}
#[test]
fn nearest_quant_suggests_close_match() {
// Today's HauhauCS scenario: operator wrote "q6k", actual
// filename token is "q6_k_p". Should suggest the latter.
let candidates = vec![
"qwen-q4_k_p.gguf".to_string(),
"qwen-q5_k_p.gguf".to_string(),
"qwen-q6_k_p.gguf".to_string(),
"qwen-q8_k_p.gguf".to_string(),
];
assert_eq!(nearest_quant("q6k", &candidates), Some("q6_k_p".into()));
}
#[test]
fn nearest_quant_empty_input() {
assert_eq!(nearest_quant("", &[]), None);
assert_eq!(nearest_quant("q6k", &[]), None);
assert_eq!(nearest_quant("", &["model-q4.gguf".into()]), None);
}
#[test]
fn extract_quant_handles_typical_filenames() {
assert_eq!(extract_quant_token("Qwen3.6-27B-Q6_K_P.gguf"), "q6_k_p");
assert_eq!(extract_quant_token("model-IQ4_XS.gguf"), "iq4_xs");
assert_eq!(extract_quant_token("simple.gguf"), "simple");
}
#[test]
fn levenshtein_basics() {
assert_eq!(levenshtein("", ""), 0);
assert_eq!(levenshtein("abc", ""), 3);
assert_eq!(levenshtein("", "abc"), 3);
assert_eq!(levenshtein("kitten", "sitting"), 3);
assert_eq!(levenshtein("q6k", "q6_k_p"), 3);
assert_eq!(levenshtein("q6k", "q4_k_p"), 4);
}
// Higher-level preflight tests below exercise the full feasibility
// table via a thin wrapper that bypasses the network — we hand it
// a pre-built `SourceFormat` and request shape, then drive the
// same decision logic. The end-to-end test with a mock HTTP
// server lives in tests/preflight.rs (integration).
/// Mirror of the `match` in `preflight()` but takes a classified
/// `SourceFormat` directly. Lets us unit-test the feasibility
/// table without making the API trait object-safe / boxable.
fn decide(
spec: &ModelSpec,
format: &SourceFormat,
filenames: &[&str],
) -> Result<PlacementPlan, PreflightError> {
let tp_size = spec.tensor_parallel.unwrap_or(1);
match (format, tp_size, spec.quant.as_deref()) {
(SourceFormat::Empty, _, _) => Err(PreflightError::EmptyRepo {
model_id: spec.model_id.clone(),
}),
(SourceFormat::Gguf { quants }, tp, _) if tp > 1 => {
Err(PreflightError::TpRequiresSafetensors {
model_id: spec.model_id.clone(),
tp_size: tp,
gguf_quants: quants.clone(),
suggestion: format!(
"Set tensor_parallel=1 and pick a quant from {quants:?}, \
or use a dense safetensors release of this model."
),
})
}
(SourceFormat::Gguf { quants }, _, requested) => {
let picked = pick_gguf_file(filenames, requested.unwrap_or(""));
match picked {
Some(fname) => Ok(PlacementPlan {
model_id: spec.model_id.clone(),
format: format.clone(),
tp_size,
picked_quant_file: Some(fname),
}),
None => Err(PreflightError::QuantNotFound {
model_id: spec.model_id.clone(),
requested: requested.unwrap_or("").to_string(),
available: quants.clone(),
nearest: nearest_quant(requested.unwrap_or(""), quants),
}),
}
}
(SourceFormat::DenseSafetensors { .. } | SourceFormat::Mixed { .. }, _, _) => {
Ok(PlacementPlan {
model_id: spec.model_id.clone(),
format: format.clone(),
tp_size,
picked_quant_file: None,
})
}
}
}
#[test]
fn feasibility_gguf_tp_rejected() {
let files = ["Qwen-Q6_K_P.gguf", "Qwen-Q4_K_P.gguf"];
let fmt = classify(&files);
let s = spec("HauhauCS/Qwen3.6", Some(2), Some("q6k"));
match decide(&s, &fmt, &files).unwrap_err() {
PreflightError::TpRequiresSafetensors {
model_id,
tp_size,
gguf_quants,
..
} => {
assert_eq!(model_id, "HauhauCS/Qwen3.6");
assert_eq!(tp_size, 2);
assert_eq!(gguf_quants.len(), 2);
}
other => panic!("expected TpRequiresSafetensors, got {other:?}"),
}
}
#[test]
fn feasibility_gguf_single_gpu_bad_quant() {
let files = [
"Qwen-Q4_K_P.gguf",
"Qwen-Q5_K_P.gguf",
"Qwen-Q6_K_P.gguf",
"Qwen-Q8_K_P.gguf",
];
let fmt = classify(&files);
let s = spec("HauhauCS/Qwen3.6", Some(1), Some("q6k"));
match decide(&s, &fmt, &files).unwrap_err() {
PreflightError::QuantNotFound {
requested,
nearest,
available,
..
} => {
assert_eq!(requested, "q6k");
assert_eq!(nearest.as_deref(), Some("q6_k_p"));
assert_eq!(available.len(), 4);
}
other => panic!("expected QuantNotFound, got {other:?}"),
}
}
#[test]
fn feasibility_gguf_single_gpu_good_quant() {
let files = ["Qwen-Q4_K_M.gguf", "Qwen-Q6_K.gguf"];
let fmt = classify(&files);
let s = spec("Qwen/Q-GGUF", Some(1), Some("q6_k"));
let plan = decide(&s, &fmt, &files).unwrap();
assert_eq!(plan.picked_quant_file.as_deref(), Some("Qwen-Q6_K.gguf"));
}
#[test]
fn feasibility_dense_tp_ok() {
let files = [
"config.json",
"tokenizer.json",
"model.safetensors.index.json",
"model-00001-of-00002.safetensors",
];
let fmt = classify(&files);
let s = spec("Qwen/Q3-30B", Some(2), Some("q5k"));
let plan = decide(&s, &fmt, &files).unwrap();
assert_eq!(plan.tp_size, 2);
assert!(plan.picked_quant_file.is_none());
assert!(matches!(
plan.format,
SourceFormat::DenseSafetensors { sharded: true }
));
}
#[test]
fn feasibility_empty_rejected() {
let files = ["README.md", "tokenizer.json"];
let fmt = classify(&files);
let s = spec("Empty/Repo", Some(1), None);
match decide(&s, &fmt, &files).unwrap_err() {
PreflightError::EmptyRepo { model_id } => assert_eq!(model_id, "Empty/Repo"),
other => panic!("expected EmptyRepo, got {other:?}"),
}
}
#[test]
fn error_serialization_carries_kind_field() {
let err = PreflightError::TpRequiresSafetensors {
model_id: "x/y".into(),
tp_size: 2,
gguf_quants: vec!["q6_k_p".into()],
suggestion: "...".into(),
};
let v: serde_json::Value = serde_json::to_value(&err).unwrap();
assert_eq!(v["kind"], "tp_requires_safetensors");
assert_eq!(v["model_id"], "x/y");
assert_eq!(v["tp_size"], 2);
}
}

View File

@@ -1,119 +0,0 @@
//! `AllReduce` as a candle `CustomOp1` — the bridge between candle's
//! `Tensor` graph and `cudarc::nccl::Comm::all_reduce`.
//!
//! Ported from the canonical
//! `candle-examples/examples/llama_multiprocess/model.rs` pattern.
//! Row-parallel layers apply this op after their local matmul to sum
//! partial outputs across NCCL ranks.
//!
//! Available only under `--features cuda`; on CPU builds this module
//! is empty and row-parallel layers degenerate to local matmul only
//! (useful for compile-checking the model code; correctness requires
//! cuda).
//!
//! Thread-safety caveat: NCCL communicators are technically only
//! safe to use from a single thread at a time
//! (https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html).
//! We hold the `AllReduce` behind an `Arc<Comm>` and only issue ops
//! against it from the dedicated `spawn_blocking` thread the inference
//! pipeline already uses for candle's forward passes.
#![cfg(feature = "cuda")]
use candle_core::backend::BackendStorage;
use candle_core::{CpuStorage, CudaStorage, CustomOp1, DType, Layout, Result, Shape};
use cudarc::nccl::{Comm, ReduceOp};
use half::{bf16, f16};
use std::sync::Arc;
/// Wraps an NCCL `Comm` so it can be plugged into a candle forward
/// graph as a custom op. Each row-parallel layer holds one of these.
pub struct AllReduce {
comm: Arc<Comm>,
}
// SAFETY: `Comm` contains a raw `ncclComm_t` pointer; NCCL's docs note
// that issuing ops against one comm from multiple threads concurrently
// is unsafe. We serialise via the single spawn_blocking thread that
// drives the model's forward pass. The Send/Sync impl is necessary
// because candle's CustomOp1 trait bounds require it; the correctness
// invariant is enforced at the call site, not the type level.
unsafe impl Send for AllReduce {}
unsafe impl Sync for AllReduce {}
impl AllReduce {
pub fn new(comm: Arc<Comm>) -> Self {
Self { comm }
}
pub fn comm(&self) -> &Arc<Comm> {
&self.comm
}
}
impl CustomOp1 for AllReduce {
fn name(&self) -> &'static str {
"neuron.tp.all_reduce"
}
fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
candle_core::bail!("AllReduce custom-op invoked on CPU storage; TP requires CUDA")
}
fn cuda_fwd(&self, s: &CudaStorage, l: &Layout) -> Result<(CudaStorage, Shape)> {
// Reject non-contiguous inputs explicitly — copying them
// server-side would mask shape bugs (a TP layer feeding a
// strided activation into all_reduce is almost certainly a
// model construction error).
fn require_contiguous<T: cudarc::driver::DeviceRepr>(
slice: &cudarc::driver::CudaSlice<T>,
l: &Layout,
) -> Result<()> {
match l.contiguous_offsets() {
Some((0, n)) if n == slice.len() => Ok(()),
_ => candle_core::bail!(
"AllReduce input is non-contiguous: layout={:?}, slice_len={}",
l,
slice.len()
),
}
}
let elem_count = l.shape().elem_count();
let dev = s.device().clone();
let out = match s.dtype() {
DType::BF16 => {
let src = s.as_cuda_slice::<bf16>()?;
require_contiguous(src, l)?;
let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }?;
self.comm
.all_reduce(src, &mut dst, &ReduceOp::Sum)
.map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce bf16: {e:?}")))?;
CudaStorage::wrap_cuda_slice(dst, dev)
}
DType::F16 => {
let src = s.as_cuda_slice::<f16>()?;
require_contiguous(src, l)?;
let mut dst = unsafe { dev.alloc::<f16>(elem_count) }?;
self.comm
.all_reduce(src, &mut dst, &ReduceOp::Sum)
.map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce f16: {e:?}")))?;
CudaStorage::wrap_cuda_slice(dst, dev)
}
DType::F32 => {
let src = s.as_cuda_slice::<f32>()?;
require_contiguous(src, l)?;
let mut dst = unsafe { dev.alloc::<f32>(elem_count) }?;
self.comm
.all_reduce(src, &mut dst, &ReduceOp::Sum)
.map_err(|e| candle_core::Error::Msg(format!("nccl all_reduce f32: {e:?}")))?;
CudaStorage::wrap_cuda_slice(dst, dev)
}
dtype => candle_core::bail!(
"AllReduce: unsupported dtype {dtype:?}; TP path expects bf16/f16/f32"
),
};
Ok((out, l.shape().clone()))
}
}

View File

@@ -1,213 +0,0 @@
//! Direct safetensors readers for fused-region weight tensors.
//!
//! Qwen3-Next's `in_proj_qkv` and `conv1d` weights are *fused* —
//! three regions stored sequentially along dim 0 (`[key_q, key_k,
//! value]`). The per-rank shard for each region has unequal size
//! (`key_dim/ws` vs `value_dim/ws`), so candle's `ShardedSafeTensors`
//! built-in `Shard { dim, rank, world_size }` (uniform split) doesn't
//! map to the right slices.
//!
//! The previous approach loaded the full fused tensor onto the device,
//! `narrow`ed the three regions, and `Tensor::cat(...).contiguous()`'d
//! the per-rank slice. That left ~100 MB of transient device memory
//! per linear-attention layer — 48 layers × 100 MB = ~4.8 GB of
//! allocator pressure during load, enough to trigger fragmentation
//! OOM on tight-VRAM consumer GPUs.
//!
//! This module reads the three per-rank byte ranges *directly from
//! the safetensors mmap* (host-side), concatenates them into a single
//! contiguous byte buffer, and uploads as one device allocation. No
//! full-tensor device materialisation.
use anyhow::{Context, Result, bail};
use candle_core::safetensors::MmapedSafetensors;
use candle_core::{DType, Device, Tensor};
/// Read a 2D fused-QKV tensor `[conv_dim, hidden_size]` and return
/// this rank's per-region slice as a `[per_rank_conv_dim, hidden_size]`
/// device tensor.
///
/// `tensor_name` must be the fully-qualified safetensors key (e.g.
/// `"model.language_model.layers.5.linear_attn.in_proj_qkv.weight"`).
#[allow(clippy::too_many_arguments)]
pub fn load_fused_qkv_2d(
mmap: &MmapedSafetensors,
tensor_name: &str,
hidden_size: usize,
key_dim: usize,
value_dim: usize,
rank: u32,
world_size: u32,
target_dtype: DType,
device: &Device,
) -> Result<Tensor> {
let ws = world_size as usize;
let r = rank as usize;
if !key_dim.is_multiple_of(ws) || !value_dim.is_multiple_of(ws) {
bail!(
"fused qkv shard: key_dim ({key_dim}) and value_dim ({value_dim}) \
must each be divisible by world_size ({ws})"
);
}
let per_rank_key = key_dim / ws;
let per_rank_value = value_dim / ws;
let per_rank_conv_dim = per_rank_key * 2 + per_rank_value;
let view = mmap
.get(tensor_name)
.with_context(|| format!("mmap.get('{tensor_name}') for fused qkv 2D"))?;
let view_dtype: DType = view
.dtype()
.try_into()
.with_context(|| format!("safetensors dtype unsupported for '{tensor_name}'"))?;
let shape = view.shape();
if shape.len() != 2 {
bail!(
"fused qkv tensor '{tensor_name}' has shape {shape:?}, expected 2D \
[conv_dim, hidden_size]"
);
}
let conv_dim = key_dim * 2 + value_dim;
if shape[0] != conv_dim || shape[1] != hidden_size {
bail!(
"fused qkv tensor '{tensor_name}' shape {shape:?} \
doesn't match expected [{conv_dim}, {hidden_size}]"
);
}
let q_bytes = slice_dim0_bytes(&view, r * per_rank_key, per_rank_key, tensor_name, "q")?;
let k_bytes = slice_dim0_bytes(
&view,
key_dim + r * per_rank_key,
per_rank_key,
tensor_name,
"k",
)?;
let v_bytes = slice_dim0_bytes(
&view,
2 * key_dim + r * per_rank_value,
per_rank_value,
tensor_name,
"v",
)?;
let mut bytes = Vec::with_capacity(q_bytes.len() + k_bytes.len() + v_bytes.len());
bytes.extend_from_slice(&q_bytes);
bytes.extend_from_slice(&k_bytes);
bytes.extend_from_slice(&v_bytes);
let tensor = Tensor::from_raw_buffer(
&bytes,
view_dtype,
&[per_rank_conv_dim, hidden_size],
device,
)
.with_context(|| format!("Tensor::from_raw_buffer for per-rank fused qkv '{tensor_name}'"))?;
tensor
.to_dtype(target_dtype)
.with_context(|| format!("cast '{tensor_name}' to {target_dtype:?}"))
}
/// Read a 3D fused-QKV tensor `[conv_dim, 1, kernel_size]` (the
/// depthwise conv1d weight) and return this rank's per-region slice
/// as a `[per_rank_conv_dim, 1, kernel_size]` device tensor.
#[allow(clippy::too_many_arguments)]
pub fn load_fused_qkv_3d(
mmap: &MmapedSafetensors,
tensor_name: &str,
mid: usize,
kernel_size: usize,
key_dim: usize,
value_dim: usize,
rank: u32,
world_size: u32,
target_dtype: DType,
device: &Device,
) -> Result<Tensor> {
let ws = world_size as usize;
let r = rank as usize;
if !key_dim.is_multiple_of(ws) || !value_dim.is_multiple_of(ws) {
bail!(
"fused conv shard: key_dim ({key_dim}) and value_dim ({value_dim}) \
must each be divisible by world_size ({ws})"
);
}
let per_rank_key = key_dim / ws;
let per_rank_value = value_dim / ws;
let per_rank_conv_dim = per_rank_key * 2 + per_rank_value;
let view = mmap
.get(tensor_name)
.with_context(|| format!("mmap.get('{tensor_name}') for fused qkv 3D"))?;
let view_dtype: DType = view
.dtype()
.try_into()
.with_context(|| format!("safetensors dtype unsupported for '{tensor_name}'"))?;
let shape = view.shape();
if shape.len() != 3 {
bail!(
"fused conv tensor '{tensor_name}' has shape {shape:?}, expected 3D \
[conv_dim, mid, kernel_size]"
);
}
let conv_dim = key_dim * 2 + value_dim;
if shape[0] != conv_dim || shape[1] != mid || shape[2] != kernel_size {
bail!(
"fused conv tensor '{tensor_name}' shape {shape:?} \
doesn't match expected [{conv_dim}, {mid}, {kernel_size}]"
);
}
let q_bytes = slice_dim0_bytes(&view, r * per_rank_key, per_rank_key, tensor_name, "q")?;
let k_bytes = slice_dim0_bytes(
&view,
key_dim + r * per_rank_key,
per_rank_key,
tensor_name,
"k",
)?;
let v_bytes = slice_dim0_bytes(
&view,
2 * key_dim + r * per_rank_value,
per_rank_value,
tensor_name,
"v",
)?;
let mut bytes = Vec::with_capacity(q_bytes.len() + k_bytes.len() + v_bytes.len());
bytes.extend_from_slice(&q_bytes);
bytes.extend_from_slice(&k_bytes);
bytes.extend_from_slice(&v_bytes);
let tensor = Tensor::from_raw_buffer(
&bytes,
view_dtype,
&[per_rank_conv_dim, mid, kernel_size],
device,
)
.with_context(|| format!("Tensor::from_raw_buffer for per-rank fused conv '{tensor_name}'"))?;
tensor
.to_dtype(target_dtype)
.with_context(|| format!("cast '{tensor_name}' to {target_dtype:?}"))
}
/// Read `len` consecutive rows along dim 0 starting at `start` from
/// the safetensors view, returning the raw bytes. Wraps the same
/// `view.slice(start..stop)` machinery that candle's
/// `ShardedSafeTensors::get` uses internally.
fn slice_dim0_bytes(
view: &safetensors::tensor::TensorView<'_>,
start: usize,
len: usize,
tensor_name: &str,
region: &str,
) -> Result<Vec<u8>> {
use safetensors::slice::IndexOp;
let stop = start + len;
let iter = view.slice(start..stop).map_err(|e| {
anyhow::anyhow!("slice '{tensor_name}' region {region} ({start}..{stop}): {e:?}")
})?;
Ok(iter.into_iter().flatten().copied().collect())
}

View File

@@ -1,795 +0,0 @@
//! Tensor-parallel inference plumbing.
//!
//! The leader process (the neuron daemon proper) drives one
//! subprocess per non-zero NCCL rank — `tokio::process::Command` on
//! `/proc/self/exe --worker --rank N --tp-size N --cuda-device N` —
//! and talks to each over a newline-delimited JSON RPC channel on
//! the worker's stdin/stdout (see `rpc.rs`).
//!
//! Sub-staging:
//!
//! - **7a-i (this commit):** process lifecycle. `WorkerPool::spawn`
//! forks N workers; `ping` round-trips every worker to confirm
//! they're alive; `shutdown` cleanly drains and reaps. `Init` /
//! `NcclSanityCheck` are stubbed.
//! - **7a-ii:** real NCCL `Comm` setup via `Init`, sanity check via
//! `NcclSanityCheck`. CUDA-gated.
//! - **7b:** TP-aware Qwen3 inference dispatched through the pool.
//! - **7c:** crash detection, streaming SSE, graceful unload.
pub mod all_reduce;
pub mod fused_load;
pub mod nccl_state;
pub mod rpc;
pub mod tp_linear;
pub mod tp_qwen3;
pub mod tp_qwen3_5;
pub mod worker;
use anyhow::{Context, Result};
use std::path::{Path, PathBuf};
use std::process::Stdio;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, Lines};
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
use rpc::{WorkerRequest, WorkerResponse};
/// Leader-side handle for any TP-loaded model. The pool's
/// `load_dense_shard` dispatches on `config.json#/model_type` to build
/// the right variant; downstream callers (the harness's
/// `chat_completion_tp` path, `generate_step`, `clear_kv_cache`,
/// `unload_model`) all hold this enum and let the variant dispatch
/// determine the concrete forward.
///
/// Variants gated on `cuda` because the underlying TP models hold
/// `Arc<cudarc::nccl::Comm>` references — irrelevant on CPU builds.
#[cfg(feature = "cuda")]
pub enum TpLeaderModel {
Qwen3(tp_qwen3::TpQwen3ForCausalLM),
Qwen3_5(tp_qwen3_5::TpQwen3_5ForCausalLM),
}
#[cfg(feature = "cuda")]
impl TpLeaderModel {
pub fn forward(
&mut self,
input: &candle_core::Tensor,
offset: usize,
) -> candle_core::Result<candle_core::Tensor> {
match self {
TpLeaderModel::Qwen3(m) => m.forward(input, offset),
TpLeaderModel::Qwen3_5(m) => m.forward(input, offset),
}
}
pub fn clear_kv_cache(&mut self) {
match self {
TpLeaderModel::Qwen3(m) => m.clear_kv_cache(),
TpLeaderModel::Qwen3_5(m) => m.clear_kv_cache(),
}
}
pub fn device(&self) -> &candle_core::Device {
match self {
TpLeaderModel::Qwen3(m) => m.device(),
TpLeaderModel::Qwen3_5(m) => m.device(),
}
}
}
/// One worker subprocess plus its bidirectional stdio handles.
struct Worker {
rank: u32,
/// Captured so the leader can log "spawned rank N on device M" and
/// future stages can re-issue Init after a CUDA reset. Unused in
/// the Stage 7a-i RPC paths themselves.
#[allow(dead_code)]
cuda_device: u32,
child: Child,
stdin: ChildStdin,
stdout: Lines<BufReader<ChildStdout>>,
}
impl Worker {
/// Send a request and wait for the response. Used for sequenced
/// ops like `Ping` / `Shutdown` where the caller doesn't need to
/// overlap the worker's execution with the leader's.
async fn request(&mut self, req: &WorkerRequest) -> Result<WorkerResponse> {
self.send_only(req).await?;
self.recv_only().await
}
/// Write a request without awaiting its response. Pair with
/// `recv_only` from the caller when leader and worker need to do
/// work concurrently — e.g. during `Init`, where the leader
/// itself calls `Comm::from_rank` on rank 0 in parallel with the
/// workers, then collects `InitOk` after NCCL completes.
async fn send_only(&mut self, req: &WorkerRequest) -> Result<()> {
let mut line = serde_json::to_string(req).context("serialise WorkerRequest")?;
line.push('\n');
self.stdin
.write_all(line.as_bytes())
.await
.with_context(|| format!("write request to rank {}", self.rank))?;
self.stdin
.flush()
.await
.with_context(|| format!("flush stdin to rank {}", self.rank))?;
Ok(())
}
async fn recv_only(&mut self) -> Result<WorkerResponse> {
let reply = self
.stdout
.next_line()
.await
.with_context(|| format!("read reply from rank {}", self.rank))?
.ok_or_else(|| anyhow::anyhow!("rank {} stdout closed before reply", self.rank))?;
serde_json::from_str(&reply)
.with_context(|| format!("parse reply from rank {}: {reply:?}", self.rank))
}
}
/// Drain one response from every worker, classifying each via the
/// supplied checker. Always reads from every worker — even if some
/// fail — so the next call's recv doesn't pick up stale responses
/// from this one (pipe-poisoning was the cause of the
/// "ClearKvCache: expected KvCacheCleared, got GenerateStepOk" class
/// of bugs).
///
/// Returns a vector of `rank N: detail` strings for any worker that
/// errored, expected-mismatched, or failed to respond. Caller decides
/// how to combine these with the leader's outcome.
async fn drain_workers(
workers: &mut [Worker],
mut check: impl FnMut(WorkerResponse) -> std::result::Result<(), String>,
) -> Vec<String> {
let mut errs = Vec::new();
for w in workers {
match w.recv_only().await {
Ok(resp) => {
if let Err(detail) = check(resp) {
errs.push(format!("rank {} {detail}", w.rank));
}
}
Err(e) => errs.push(format!("rank {} recv: {e:#}", w.rank)),
}
}
errs
}
/// Combine a leader's `Result<Result<T>>` (the typical
/// `spawn_blocking → JoinHandle<Result<T>>` shape) with the worker
/// drain results into a single `Result<T>`. Leader failures take
/// precedence in the error message but worker errors get appended so
/// the operator sees both halves.
#[cfg(feature = "cuda")]
fn combine_leader_workers<T>(
leader: Result<Result<T>>,
worker_errors: Vec<String>,
op: &str,
) -> Result<T> {
match leader {
Ok(Ok(value)) => {
if worker_errors.is_empty() {
Ok(value)
} else {
anyhow::bail!(
"{op}: leader succeeded but workers failed: {}",
worker_errors.join("; ")
)
}
}
Ok(Err(e)) => {
if worker_errors.is_empty() {
Err(e.context(format!("{op}: leader forward failed")))
} else {
Err(e.context(format!(
"{op}: leader forward failed and workers also failed: {}",
worker_errors.join("; ")
)))
}
}
Err(panic_err) => {
if worker_errors.is_empty() {
Err(panic_err)
} else {
Err(panic_err.context(format!(
"{op}: leader task panicked and workers failed: {}",
worker_errors.join("; ")
)))
}
}
}
}
/// A live pool of worker subprocesses. Owns the `Child` handles so
/// dropping the pool kills the children; explicit `shutdown()` is
/// the graceful path.
pub struct WorkerPool {
world_size: u32,
workers: Vec<Worker>,
/// Path to the neuron binary used to launch workers.
#[allow(dead_code)]
exe: PathBuf,
/// The leader's per-device CUDA worker thread. Phase 3 moved the
/// leader's `NcclState` (rank-0 NCCL Comm) into this thread, so
/// every NCCL op (init, sanity, all_reduce inside forward) issues
/// from one OS thread for the daemon's lifetime. The handle is
/// also used by `load_dense_shard` to clone the leader's
/// `Arc<Comm>` for the row-parallel layers' AllReduce ops; in
/// Phase 4 the load itself moves onto the worker and that bridge
/// goes away.
pub(crate) leader_worker: std::sync::Arc<super::device_worker::DeviceWorkerHandle>,
}
impl WorkerPool {
/// Spawn `world_size - 1` worker subprocesses. Rank 0 is the
/// leader (in-process) and is *not* spawned here — the leader
/// holds rank 0's NCCL Comm and shard in its own address space.
///
/// `binary` is the path to the neuron executable to run for each
/// worker (production passes `/proc/self/exe`; tests pass the
/// sibling-binary path from `env!("CARGO_BIN_EXE_neuron")`).
/// `cuda_devices` is one entry per rank including rank 0. Worker
/// `i` (rank `i`) gets `cuda_devices[i]` as its `--cuda-device`.
pub async fn spawn(
binary: &Path,
world_size: u32,
cuda_devices: &[u32],
leader_worker: std::sync::Arc<super::device_worker::DeviceWorkerHandle>,
) -> Result<Self> {
if world_size < 2 {
anyhow::bail!(
"WorkerPool::spawn called with world_size={world_size}; \
use the single-process path for world_size < 2"
);
}
if cuda_devices.len() as u32 != world_size {
anyhow::bail!(
"expected {world_size} cuda_devices entries, got {}",
cuda_devices.len()
);
}
let exe = binary.to_path_buf();
let mut workers = Vec::with_capacity(world_size as usize - 1);
// Rank 0 stays in-process. Spawn ranks 1..world_size.
for rank in 1..world_size {
let cuda_device = cuda_devices[rank as usize];
let mut cmd = Command::new(&exe);
cmd.arg("--worker")
.arg("--rank")
.arg(rank.to_string())
.arg("--tp-size")
.arg(world_size.to_string())
.arg("--cuda-device")
.arg(cuda_device.to_string())
.stdin(Stdio::piped())
.stdout(Stdio::piped())
// Inherit stderr so worker tracing surfaces alongside
// the leader's journalctl stream.
.stderr(Stdio::inherit())
.kill_on_drop(true);
let mut child = cmd
.spawn()
.with_context(|| format!("spawn worker rank {rank}"))?;
let stdin = child
.stdin
.take()
.ok_or_else(|| anyhow::anyhow!("rank {rank}: no stdin handle"))?;
let stdout = child
.stdout
.take()
.ok_or_else(|| anyhow::anyhow!("rank {rank}: no stdout handle"))?;
let stdout = BufReader::new(stdout).lines();
workers.push(Worker {
rank,
cuda_device,
child,
stdin,
stdout,
});
tracing::info!(rank, cuda_device, "spawned tp worker");
}
Ok(Self {
world_size,
workers,
exe,
leader_worker,
})
}
/// Establish the NCCL communicator across the leader (rank 0) and
/// every worker subprocess. Rendezvous is via a freshly-generated
/// `Id` broadcast over the RPC stream; the actual handshake blocks
/// inside `Comm::from_rank` until all `world_size` ranks check in.
///
/// `leader_cuda_device` is the CUDA device the leader binds rank 0
/// to — typically the first entry of the `cuda_devices` slice
/// originally passed to `spawn()`.
///
/// On the non-cuda build this immediately fails because the leader
/// can't generate an `Id` without libnccl. The same call works in
/// the worker path (returning a no-cuda error response) so the
/// failure surface is uniform.
pub async fn init_nccl(&mut self, leader_cuda_device: u32) -> Result<()> {
let comm_id = nccl_state::generate_comm_id_hex()
.map_err(|m| anyhow::anyhow!("generate NCCL id: {m}"))?;
// 1. Write Init to every worker's stdin without awaiting the
// response. Workers will parse and call Comm::from_rank
// concurrently with the leader below.
for w in &mut self.workers {
let req = WorkerRequest::Init {
comm_id: comm_id.clone(),
};
w.send_only(&req).await?;
}
// 2. Leader rank 0 calls Comm::from_rank on its own device.
// Phase 3 moved this from spawn_blocking onto the leader's
// device worker thread (`Job::NcclInit`); the underlying
// `Comm` now lives on the same OS thread for its entire
// lifetime, including every later `Comm::all_reduce` issued
// by the row-parallel layers during forward.
//
// NCCL's init blocks until every rank has called in — the
// subprocess workers above and the leader's device worker
// here. The Job's reply unblocks when the leader's
// Comm::from_rank returns.
let leader_cfg = worker::WorkerConfig {
rank: 0,
world_size: self.world_size,
cuda_device: leader_cuda_device,
};
let leader_resp = self
.leader_worker
.nccl_init(leader_cfg, comm_id.clone())
.await
.map_err(|e| anyhow::anyhow!("leader NCCL init via device worker: {e}"))?;
match leader_resp {
rpc::WorkerResponse::InitOk => {}
rpc::WorkerResponse::Error { kind, message } => {
anyhow::bail!("leader rank 0 init failed [{kind}]: {message}");
}
other => anyhow::bail!("leader rank 0 init: unexpected {other:?}"),
}
// 3. Read InitOk from each worker. By now every worker has
// completed its Comm::from_rank call (NCCL released them
// when the leader joined the handshake) and is writing its
// response.
for w in &mut self.workers {
let resp = w.recv_only().await?;
match &resp {
rpc::WorkerResponse::InitOk => {}
rpc::WorkerResponse::Error { kind, message } => {
anyhow::bail!("worker rank {} init failed [{kind}]: {message}", w.rank);
}
other => anyhow::bail!(
"worker rank {} init: expected InitOk, got {other:?}",
w.rank
),
}
}
tracing::info!(
world_size = self.world_size,
"NCCL communicator established across all ranks"
);
Ok(())
}
/// Validate the NCCL communicator: every rank `all_reduce`s a
/// sentinel `1u32` with `ReduceOp::Sum`; the expected total is
/// `world_size`. Confirms the handshake is live, not just
/// configured.
///
/// Must be called after `init_nccl()`; before that the leader has
/// no Comm and the workers reply with `nccl_not_initialised`.
pub async fn nccl_sanity_check(&mut self) -> Result<()> {
// 1. Trigger the all_reduce on every worker (write-only).
for w in &mut self.workers {
w.send_only(&WorkerRequest::NcclSanityCheck).await?;
}
// 2. Leader's own all_reduce, on its device worker thread.
// NCCL operations block until every rank participates;
// Job::NcclSanity returns once the leader's side completes
// (which happens when every subprocess worker reaches its
// all_reduce call too).
let leader_resp = self
.leader_worker
.nccl_sanity()
.await
.map_err(|e| anyhow::anyhow!("leader NCCL sanity via device worker: {e}"))?;
let expected = self.world_size;
let leader_sum = match leader_resp {
rpc::WorkerResponse::NcclSanityResult { observed_sum } => observed_sum,
rpc::WorkerResponse::Error { kind, message } => {
anyhow::bail!("leader rank 0 sanity failed [{kind}]: {message}");
}
other => anyhow::bail!("leader rank 0 sanity: unexpected {other:?}"),
};
if leader_sum != expected {
anyhow::bail!("leader observed_sum={leader_sum}, expected {expected}");
}
// 3. Read sanity result from each worker. All must match
// world_size — anything else means the collective didn't
// complete consistently across ranks.
for w in &mut self.workers {
let resp = w.recv_only().await?;
match resp {
rpc::WorkerResponse::NcclSanityResult { observed_sum }
if observed_sum == expected => {}
rpc::WorkerResponse::NcclSanityResult { observed_sum } => {
anyhow::bail!(
"worker rank {} observed_sum={observed_sum}, expected {expected}",
w.rank
);
}
rpc::WorkerResponse::Error { kind, message } => {
anyhow::bail!("worker rank {} sanity failed [{kind}]: {message}", w.rank);
}
other => anyhow::bail!("worker rank {} sanity: unexpected {other:?}", w.rank),
}
}
tracing::info!(
world_size = expected,
"NCCL sanity check OK across all ranks"
);
Ok(())
}
/// Ping every worker and return their Pong payloads in rank order.
/// Useful right after `spawn` to confirm the lifecycle plumbing is
/// intact before kicking off any heavier work.
pub async fn ping_all(&mut self) -> Result<Vec<WorkerResponse>> {
let mut out = Vec::with_capacity(self.workers.len());
for w in &mut self.workers {
let resp = w.request(&WorkerRequest::Ping).await?;
match &resp {
WorkerResponse::Pong { rank, .. } if *rank == w.rank => {}
WorkerResponse::Pong { rank, .. } => {
anyhow::bail!("rank mismatch: expected {}, got {rank}", w.rank);
}
other => anyhow::bail!("expected Pong from rank {}, got {other:?}", w.rank),
}
out.push(resp);
}
Ok(out)
}
/// Load this rank's shard of a dense Qwen3 model on every rank.
///
/// The leader builds rank 0's `TpQwen3ForCausalLM` directly into
/// the returned `Arc<Mutex<_>>` — workers build their rank-local
/// shards in their own address spaces and confirm via
/// `LoadDenseShardOk`. All ranks see the same `safetensors_paths`;
/// `ShardedVarBuilder` slices each tensor by rank at materialisation
/// time, so the per-rank VRAM footprint is roughly `1/world_size`
/// of the full model (plus the replicated embedding/norm/lm_head).
///
/// `leader_device` is the candle `Device` the leader's shard lives
/// on — typically `Device::new_cuda(leader_cuda_device)` matching
/// the same index passed to `init_nccl`. `dtype` is the on-device
/// element type; bf16 is the canonical Qwen3 distribution dtype.
///
/// `init_nccl` must have completed first. Bails if the leader's
/// NCCL comm isn't set up yet.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub async fn load_dense_shard(
&mut self,
model_id: &str,
config_json: &str,
safetensors_paths: &[std::path::PathBuf],
_leader_device: &candle_core::Device,
dtype: candle_core::DType,
quant: Option<String>,
) -> Result<super::device_worker::TpHandle> {
let world_size = self.world_size;
let safetensors_str: Vec<String> = safetensors_paths
.iter()
.map(|p| p.to_string_lossy().into_owned())
.collect();
// 1. Fan out the LoadDenseShard request to every subprocess
// worker without awaiting their replies — they'll build
// their shards in parallel with the leader below.
for w in &mut self.workers {
w.send_only(&WorkerRequest::LoadDenseShard {
model_id: model_id.to_string(),
config_json: config_json.to_string(),
safetensors_paths: safetensors_str.clone(),
quant: quant.clone(),
})
.await?;
}
// 2. Build rank 0's shard on the leader's device worker
// thread. Phase 4 moved the load itself onto the worker —
// the dispatch handler reads `state.nccl.comm()` directly
// so the leader's `Arc<Comm>` clones embedded in the
// row-parallel layers are constructed and used on the same
// OS thread for the model's entire lifetime. No
// spawn_blocking, no SendComm bridge.
let handle = self
.leader_worker
.tp_load_shard(
model_id.to_string(),
config_json.to_string(),
safetensors_paths.to_vec(),
dtype,
quant.clone(),
world_size,
)
.await
.map_err(|e| anyhow::anyhow!("leader TP shard load via device worker: {e}"))?;
// 3. Collect worker confirmations. Anything other than
// LoadDenseShardOk aborts the whole load — the leader's
// already-inserted shard would leak in the worker slab
// until the daemon restarts; an explicit DropTp would be
// cleaner but the failure here is rare and the operator's
// next step is to restart anyway.
for w in &mut self.workers {
let resp = w.recv_only().await?;
match resp {
WorkerResponse::LoadDenseShardOk => {}
WorkerResponse::Error { kind, message } => {
anyhow::bail!("worker rank {} LoadDenseShard [{kind}]: {message}", w.rank)
}
other => anyhow::bail!(
"worker rank {} LoadDenseShard: expected LoadDenseShardOk, got {other:?}",
w.rank
),
}
}
Ok(handle)
}
/// Run one forward step across every rank. The leader's forward
/// runs on the device worker thread via `Job::TpForwardLogits` and
/// returns CPU-side `[vocab]` logits as `Vec<f32>`; the async
/// caller wraps them in a CPU tensor for `apply_repeat_penalty` +
/// sampling without holding a device-resident tensor on a tokio
/// thread.
///
/// Subprocess workers run their own forwards in parallel (the
/// AllReduce CustomOps inside row-parallel layers are what let
/// the leader's collective complete) and reply with
/// `GenerateStepOk` over the RPC stream — they do not ship logits.
///
/// `tokens` is the input for this step (prompt for prefill, the
/// previously-sampled token for decode). `offset` is the KV-cache
/// position before this step.
#[cfg(feature = "cuda")]
pub async fn generate_step(
&mut self,
model_id: &str,
leader_handle: super::device_worker::TpHandle,
tokens: Vec<u32>,
offset: usize,
) -> Result<Vec<f32>> {
let step_start = std::time::Instant::now();
let tokens_len = tokens.len();
tracing::debug!(
model = %model_id,
tokens = tokens_len,
offset,
"WorkerPool::generate_step: fan-out"
);
// 1. Fan-out to subprocess workers.
for w in &mut self.workers {
w.send_only(&WorkerRequest::GenerateStep {
model_id: model_id.to_string(),
tokens: tokens.clone(),
offset,
})
.await?;
}
// 2. Leader's forward on its device worker thread. The
// AllReduce CustomOps inside the row-parallel layers block
// until every subprocess worker's forward issues the
// matching collective. Returning CPU-side `Vec<f32>` keeps
// the device tensor from escaping the worker thread —
// that's the invariant the whole refactor exists to
// preserve.
let leader_start = std::time::Instant::now();
let leader_result = self
.leader_worker
.tp_forward_logits(leader_handle, tokens, offset)
.await;
let leader_ok = leader_result.is_ok();
let leader_ms = leader_start.elapsed().as_millis();
// Surface the leader's own error at WARN before draining
// workers so the operator can correlate it with whatever the
// subprocess workers logged. Previously this was silently
// coerced to a bool.
if !leader_ok {
let detail = leader_result
.as_ref()
.err()
.map(|e| format!("{e:#}"))
.unwrap_or_default();
tracing::warn!(
model = %model_id,
tokens = tokens_len,
offset,
leader_ms,
error = %detail,
"WorkerPool::generate_step: leader forward failed"
);
}
tracing::debug!(
model = %model_id,
tokens = tokens_len,
leader_ms,
leader_ok,
"WorkerPool::generate_step: leader forward returned"
);
// 3. ALWAYS drain worker responses, regardless of whether the
// leader succeeded. Skipping this on the leader's error
// path leaves stale GenerateStepOk replies in the worker
// pipes that poison the NEXT request's recv (was seeing
// "ClearKvCache: expected KvCacheCleared, got
// GenerateStepOk" the call after any forward-time failure).
let drain_start = std::time::Instant::now();
let worker_errors = drain_workers(&mut self.workers, |r| match r {
WorkerResponse::GenerateStepOk => Ok(()),
WorkerResponse::Error { kind, message } => Err(format!("[{kind}]: {message}")),
other => Err(format!("expected GenerateStepOk, got {other:?}")),
})
.await;
tracing::debug!(
model = %model_id,
drain_ms = drain_start.elapsed().as_millis(),
errors = worker_errors.len(),
total_ms = step_start.elapsed().as_millis(),
"WorkerPool::generate_step: workers drained"
);
// Combine the leader's Result + the workers' string-error
// list. Phase 3 inlines this because the upstream
// `combine_leader_workers` expects the spawn_blocking-shaped
// `Result<Result<T>>`; the new device-worker path produces a
// single `Result<T, WorkerError>` instead.
match leader_result {
Ok(values) => {
if worker_errors.is_empty() {
Ok(values)
} else {
anyhow::bail!(
"GenerateStep: leader succeeded but workers failed: {}",
worker_errors.join("; ")
)
}
}
Err(e) => {
if worker_errors.is_empty() {
Err(anyhow::Error::new(e).context("GenerateStep: leader forward failed"))
} else {
Err(anyhow::Error::new(e).context(format!(
"GenerateStep: leader forward failed and workers also failed: {}",
worker_errors.join("; ")
)))
}
}
}
}
/// Reset the KV cache for `model_id` on every rank. Called at the
/// start of every inference so a fresh request doesn't attend over
/// the previous one's tokens.
pub async fn clear_kv_cache(
&mut self,
model_id: &str,
#[cfg(feature = "cuda")] leader_handle: super::device_worker::TpHandle,
) -> Result<()> {
let start = std::time::Instant::now();
tracing::debug!(model = %model_id, "WorkerPool::clear_kv_cache: fan-out");
for w in &mut self.workers {
w.send_only(&WorkerRequest::ClearKvCache {
model_id: model_id.to_string(),
})
.await?;
}
#[cfg(feature = "cuda")]
{
// Leader-side clear on the device worker thread —
// `TpLeaderModel::clear_kv_cache` is infallible but still
// routes through Job::TpClearKv so the cache reset runs
// on the same thread that owns the model's CUDA tensors.
if let Err(e) = self.leader_worker.tp_clear_kv(leader_handle).await {
anyhow::bail!("leader TP clear_kv_cache via device worker: {e}");
}
}
// Drain workers — same rationale as `generate_step`. The
// leader's clear_kv_cache is now async-via-channel but still
// returns before the drain so the workers' KvCacheCleared
// replies are processed in order.
let worker_errors = drain_workers(&mut self.workers, |r| match r {
WorkerResponse::KvCacheCleared => Ok(()),
WorkerResponse::Error { kind, message } => Err(format!("[{kind}]: {message}")),
other => Err(format!("expected KvCacheCleared, got {other:?}")),
})
.await;
tracing::debug!(
model = %model_id,
elapsed_ms = start.elapsed().as_millis(),
errors = worker_errors.len(),
"WorkerPool::clear_kv_cache: workers drained"
);
if !worker_errors.is_empty() {
anyhow::bail!("ClearKvCache: {}", worker_errors.join("; "));
}
Ok(())
}
/// Drop this model's shards on every rank. The leader's shard is
/// expected to have been dropped by the caller (its `Arc` was held
/// in the TpLoadedModel and goes away when that's removed).
pub async fn unload_model(&mut self, model_id: &str) -> Result<()> {
for w in &mut self.workers {
w.send_only(&WorkerRequest::UnloadModel {
model_id: model_id.to_string(),
})
.await?;
}
for w in &mut self.workers {
let resp = w.recv_only().await?;
match resp {
WorkerResponse::Unloaded => {}
WorkerResponse::Error { kind, message } => {
anyhow::bail!("worker rank {} UnloadModel [{kind}]: {message}", w.rank)
}
other => anyhow::bail!(
"worker rank {} UnloadModel: expected Unloaded, got {other:?}",
w.rank
),
}
}
Ok(())
}
/// Send `Shutdown` to every worker, await each `Bye`, and reap the
/// children. Best-effort — individual worker failures are logged
/// but don't abort the rest of the sweep.
pub async fn shutdown(mut self) -> Result<()> {
for w in &mut self.workers {
match w.request(&WorkerRequest::Shutdown).await {
Ok(WorkerResponse::Bye) => {}
Ok(other) => tracing::warn!(
rank = w.rank,
response = ?other,
"expected Bye on shutdown"
),
Err(e) => tracing::warn!(rank = w.rank, error = %e, "shutdown request failed"),
}
}
for w in &mut self.workers {
match w.child.wait().await {
Ok(status) => tracing::info!(rank = w.rank, %status, "worker exited"),
Err(e) => tracing::warn!(rank = w.rank, error = %e, "wait on worker failed"),
}
}
Ok(())
}
pub fn world_size(&self) -> u32 {
self.world_size
}
pub fn binary_path(&self) -> &PathBuf {
&self.exe
}
}

View File

@@ -1,293 +0,0 @@
//! NCCL state held by both the worker process and the leader's pool.
//!
//! Split into its own module so the worker (`tp/worker.rs`) and the
//! leader (`tp/mod.rs`) share the same hex-encoding/decoding code and
//! the same shape of `Option<Comm>` state machine.
//!
//! When the `cuda` feature is off, `NcclState` is a zero-sized
//! placeholder that returns `Error{kind="cuda_feature_not_enabled"}`
//! from every operation. When it's on, the same struct holds the
//! actual `cudarc::nccl::Comm`.
use super::rpc::WorkerResponse;
use super::worker::WorkerConfig;
/// Encode bytes as lowercase hex. Used for ferrying NCCL `Id::internal()`
/// across the leader→worker RPC boundary inside a JSON string.
pub fn encode_hex(bytes: &[u8]) -> String {
let mut out = String::with_capacity(bytes.len() * 2);
for b in bytes {
use std::fmt::Write;
let _ = write!(out, "{b:02x}");
}
out
}
/// Decode lowercase-or-uppercase hex into bytes. Errors on odd length
/// or non-hex characters; the caller bubbles those up via the RPC's
/// `Error{kind="bad_request"}` variant.
pub fn decode_hex(s: &str) -> Result<Vec<u8>, String> {
if !s.len().is_multiple_of(2) {
return Err(format!("hex string has odd length {}", s.len()));
}
(0..s.len())
.step_by(2)
.map(|i| {
u8::from_str_radix(&s[i..i + 2], 16).map_err(|e| format!("bad hex byte at {i}: {e}"))
})
.collect()
}
#[cfg(not(feature = "cuda"))]
pub struct NcclState;
#[cfg(not(feature = "cuda"))]
impl Default for NcclState {
fn default() -> Self {
Self::new()
}
}
#[cfg(not(feature = "cuda"))]
impl NcclState {
pub fn new() -> Self {
Self
}
pub fn init(&mut self, _cfg: WorkerConfig, _comm_id_hex: &str) -> WorkerResponse {
WorkerResponse::Error {
kind: "cuda_feature_not_enabled".into(),
message: "this neuron binary was built without --features cuda; \
NCCL Init requires CUDA"
.into(),
}
}
pub fn sanity_check(&mut self) -> WorkerResponse {
WorkerResponse::Error {
kind: "cuda_feature_not_enabled".into(),
message: "NCCL sanity check requires --features cuda".into(),
}
}
}
#[cfg(feature = "cuda")]
mod cuda_impl {
use super::*;
use cudarc::driver::CudaContext;
use cudarc::nccl::{Comm, Id, ReduceOp};
use std::sync::Arc;
/// Number of bytes in NCCL's unique-id type; matches `Id::internal()`'s
/// `[c_char; 128]`. Wire-encoded as 256 lowercase hex chars.
const NCCL_ID_BYTES: usize = 128;
pub struct NcclState {
/// Wrapped in `Arc` so we can hand a clone to `TpQwen3ForCausalLM`
/// at load time (every row-parallel layer needs a reference to
/// run its trailing `AllReduce`). The `Arc` is the single source
/// of truth for the comm's lifetime — when the pool drops and
/// every layer that captured a clone drops, NCCL releases the
/// underlying `ncclComm_t`.
comm: Option<Arc<Comm>>,
/// Held alongside the Comm so the device isn't dropped
/// underneath the NCCL handle.
#[allow(dead_code)]
ctx: Option<Arc<CudaContext>>,
}
impl Default for NcclState {
fn default() -> Self {
Self::new()
}
}
impl NcclState {
pub fn new() -> Self {
Self {
comm: None,
ctx: None,
}
}
/// Clone the comm out as an `Arc` so callers (the leader-side
/// `TpQwen3ForCausalLM::load`, or the worker's own model load)
/// can hold a reference for the lifetime of the model. Returns
/// `None` before `init` has run.
pub fn comm(&self) -> Option<Arc<Comm>> {
self.comm.clone()
}
}
/// `Arc<Comm>` doesn't impl `Send` because `Comm` wraps a raw
/// `ncclComm_t` pointer. The NCCL contract is "operations against a
/// given comm must be serialised", not "the handle must stay on the
/// thread that created it" — so it's safe to move an `Arc<Comm>`
/// across threads as long as no concurrent ops are issued. The
/// pool's outer Mutex serialises us into `spawn_blocking`, so this
/// wrapper at the move boundary is the only thing missing.
///
/// `Sync` is also marked safe because the `Arc<Comm>` clones held
/// by the row-parallel layers are only used from the
/// `spawn_blocking` thread driving the forward pass; concurrent
/// access from another thread would still be a bug.
pub struct SendComm(pub Arc<Comm>);
// SAFETY: see the doc-comment above; the invariant is enforced at
// the call site (pool Mutex + single spawn_blocking thread), not at
// the type level.
unsafe impl Send for SendComm {}
unsafe impl Sync for SendComm {}
impl SendComm {
pub fn into_inner(self) -> Arc<Comm> {
self.0
}
}
// SAFETY: `cudarc::nccl::Comm` contains a raw `ncclComm_t` pointer
// (libnccl-allocated state). NCCL requires that operations against
// one Comm be issued one at a time; we serialise access by storing
// NcclState behind a Mutex in `WorkerPool`. The Comm itself is
// move-safe — NCCL doesn't track the calling OS thread, only the
// stream the operations are dispatched against.
unsafe impl Send for NcclState {}
unsafe impl Sync for NcclState {}
/// Generate a fresh NCCL `Id` and return it hex-encoded. Used by
/// the leader to mint the shared communicator id which is then
/// broadcast to every worker via the RPC `Init` message.
pub fn generate_comm_id_hex() -> Result<String, String> {
// NcclError lacks a Display impl in cudarc 0.19.x — surface
// via Debug throughout this module.
let id = Id::new().map_err(|e| format!("Id::new(): {e:?}"))?;
let bytes_u8: [u8; NCCL_ID_BYTES] = std::array::from_fn(|i| id.internal()[i] as u8);
Ok(encode_hex(&bytes_u8))
}
impl NcclState {
pub fn init(&mut self, cfg: WorkerConfig, comm_id_hex: &str) -> WorkerResponse {
match try_init(self, cfg, comm_id_hex) {
Ok(()) => WorkerResponse::InitOk,
Err(msg) => WorkerResponse::Error {
kind: "nccl_init_failed".into(),
message: msg,
},
}
}
pub fn sanity_check(&mut self) -> WorkerResponse {
let Some(comm) = self.comm.as_ref() else {
return WorkerResponse::Error {
kind: "nccl_not_initialised".into(),
message: "sanity_check requires Init to have completed first".into(),
};
};
match try_sanity_check(comm.as_ref()) {
Ok(sum) => WorkerResponse::NcclSanityResult { observed_sum: sum },
Err(msg) => WorkerResponse::Error {
kind: "nccl_sanity_failed".into(),
message: msg,
},
}
}
}
fn try_init(state: &mut NcclState, cfg: WorkerConfig, comm_id_hex: &str) -> Result<(), String> {
let bytes = decode_hex(comm_id_hex)?;
if bytes.len() != NCCL_ID_BYTES {
return Err(format!(
"comm_id is {} bytes, expected {NCCL_ID_BYTES}",
bytes.len()
));
}
let id_bytes: [std::ffi::c_char; NCCL_ID_BYTES] =
std::array::from_fn(|i| bytes[i] as std::ffi::c_char);
let id = Id::uninit(id_bytes);
let ctx = CudaContext::new(cfg.cuda_device as usize)
.map_err(|e| format!("CudaContext::new({}) failed: {e}", cfg.cuda_device))?;
let stream = ctx.default_stream();
let comm = Comm::from_rank(stream, cfg.rank as usize, cfg.world_size as usize, id)
.map_err(|e| {
format!(
"Comm::from_rank(rank={}, world={}) failed: {e:?}",
cfg.rank, cfg.world_size
)
})?;
state.ctx = Some(ctx);
// `Comm` is !Send + !Sync at the type level because it wraps a
// raw `ncclComm_t`. The `Arc` is fine in practice — we
// serialise operations through the pool's outer Mutex and the
// SendComm wrapper at thread-crossing boundaries enforces this
// at every move site. clippy's `arc_with_non_send_sync` lint
// can't see that invariant; allow once at the canonical
// construction site.
#[allow(clippy::arc_with_non_send_sync)]
{
state.comm = Some(Arc::new(comm));
}
Ok(())
}
fn try_sanity_check(comm: &Comm) -> Result<u32, String> {
let stream = comm.stream().clone();
let input = stream
.clone_htod(&[1u32])
.map_err(|e| format!("htod sentinel: {e}"))?;
let mut output = stream
.alloc_zeros::<u32>(1)
.map_err(|e| format!("alloc output: {e}"))?;
// cudarc::nccl::NcclError doesn't impl Display in 0.19.x —
// surface via Debug so we still see the variant + ncclResult
// code instead of a generic "{e}" failure.
comm.all_reduce(&input, &mut output, &ReduceOp::Sum)
.map_err(|e| format!("all_reduce: {e:?}"))?;
let result = stream
.clone_dtoh(&output)
.map_err(|e| format!("dtoh result: {e}"))?;
Ok(result[0])
}
}
#[cfg(feature = "cuda")]
pub use cuda_impl::{NcclState, SendComm, generate_comm_id_hex};
/// Non-cuda stub for the leader: returns a clear marker error rather
/// than letting `init_nccl` succeed vacuously.
#[cfg(not(feature = "cuda"))]
pub fn generate_comm_id_hex() -> Result<String, String> {
Err("cuda_feature_not_enabled: build with --features cuda".into())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn hex_roundtrip() {
let original: Vec<u8> = (0u8..=255).collect();
let encoded = encode_hex(&original);
assert_eq!(encoded.len(), 512);
let decoded = decode_hex(&encoded).expect("decode");
assert_eq!(decoded, original);
}
#[test]
fn hex_decode_rejects_odd_length() {
assert!(decode_hex("a").is_err());
assert!(decode_hex("abc").is_err());
}
#[test]
fn hex_decode_rejects_non_hex() {
assert!(decode_hex("zz").is_err());
assert!(decode_hex("ab_d").is_err());
}
#[test]
fn hex_encode_is_lowercase_padded() {
assert_eq!(encode_hex(&[0x0a, 0xff]), "0aff");
}
}

View File

@@ -1,257 +0,0 @@
//! Wire protocol between the neuron leader process and its
//! `--worker` subprocesses.
//!
//! Every frame is one newline-delimited JSON object on the worker's
//! stdin (request) or stdout (response). Both directions are tagged
//! sum types from the start so new ops in Stage 7b/7c slot in without
//! breaking compatibility — no "14 message types and a version field"
//! drift later. Adding a new variant is the canonical way to evolve
//! the protocol; existing peers that don't recognise an op return
//! `WorkerResponse::Error { kind: "unknown_op", .. }`.
//!
//! The serialised shape uses `tag = "op"` so a request looks like:
//! {"op":"ping"}
//! {"op":"init","comm_id":"a1b2..."}
//! and a response:
//! {"op":"pong","rank":0,"world_size":2,"cuda_device":0}
//! {"op":"error","kind":"nccl_init_failed","message":"..."}
use serde::{Deserialize, Serialize};
/// Leader → worker. Worker handles one at a time; replies with exactly
/// one `WorkerResponse` per request.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "op", rename_all = "snake_case")]
pub enum WorkerRequest {
/// Liveness probe. Worker replies with `Pong` containing its own
/// identity. Used by the leader to confirm the subprocess is up
/// and ready before kicking off any heavier work.
Ping,
/// One-shot NCCL communicator setup. The leader generates the
/// `comm_id` once (rank 0 of NCCL), broadcasts it to every worker
/// via this message, then every rank (leader included) calls
/// `Comm::from_rank` with the same id — NCCL blocks until all
/// `world_size` ranks check in. The hex-encoded bytes are the
/// canonical `cudarc::nccl::Id::internal()` content.
Init {
/// Hex-encoded NCCL id bytes (128 bytes → 256 hex chars).
comm_id: String,
},
/// Sanity check: after Init, every rank runs an `all_reduce` over
/// a sentinel value (`1u32`). The expected sum is `world_size`.
/// Worker replies with the observed value so the leader can verify
/// the NCCL handshake is genuinely live, not just configured.
NcclSanityCheck,
/// Load this rank's shard of a dense Qwen3 model from mmaped
/// safetensors. The same `safetensors_paths` list is sent to every
/// rank — the ShardedVarBuilder reads only the rank-local slice of
/// each tensor at materialisation time, so the worker's VRAM
/// footprint is `1 / world_size` of the full model (plus replicated
/// embedding/norm/lm_head).
LoadDenseShard {
/// Caller-supplied id for later `GenerateStep` / `UnloadModel`
/// lookups. Typically the HF model id verbatim.
model_id: String,
/// JSON-serialised `candle_transformers::models::qwen3::Config`
/// — the same blob the leader parsed from the HF cache's
/// `config.json`. Threaded through verbatim so the worker uses
/// identical hyperparameters.
config_json: String,
/// Absolute paths the worker should mmap. The same set on every
/// rank; ShardedVarBuilder slices into them per rank.
safetensors_paths: Vec<String>,
/// Optional in-situ quantization dtype (e.g. "q5k", "q8_0",
/// "q6k"). When set, each linear-layer weight is quantized
/// at load time to the named ggml format — saves ~3-5x vs
/// bf16/f16 at the cost of some accuracy. `None` keeps the
/// weights in the on-disk dtype (typically bf16).
#[serde(default)]
quant: Option<String>,
},
/// Run one forward step on this rank's loaded model. The worker
/// reaches into its NCCL Comm for the row-parallel `AllReduce`s
/// inside the model — and so blocks on every other rank issuing the
/// same op. The leader does *not* receive logits back over RPC; it
/// runs its own rank-0 forward in parallel and uses its own logits
/// for sampling.
GenerateStep {
model_id: String,
/// Input token ids for this step. For prefill, the whole prompt;
/// for decode, a single token. Identical on every rank.
tokens: Vec<u32>,
/// KV cache offset (count of tokens already in the cache before
/// this step).
offset: usize,
},
/// Reset the KV cache for this model on this rank. Sent at the
/// start of every inference so a fresh request doesn't accidentally
/// attend over the previous one's tokens.
ClearKvCache { model_id: String },
/// Drop this rank's shard for the given model. Releases the VRAM
/// the shard's weights occupied; subsequent `GenerateStep` calls
/// against the same `model_id` return an `Error`.
UnloadModel { model_id: String },
/// Worker should release resources and exit. Worker replies `Bye`
/// and then closes stdout / exits zero. The leader reaps the
/// child via the `tokio::process::Child` it kept.
Shutdown,
}
/// Worker → leader. Always exactly one of these per `WorkerRequest`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "op", rename_all = "snake_case")]
pub enum WorkerResponse {
/// Reply to `Ping`. Carries enough identity for the leader to log
/// what it actually got back.
Pong {
rank: u32,
world_size: u32,
cuda_device: u32,
},
/// Reply to `Init`. Empty payload — success is the absence of
/// `Error`. NCCL's internal blocking handshake means by the time
/// this comes back, every other rank has also reached
/// `Comm::from_rank`.
InitOk,
/// Reply to `NcclSanityCheck`. The observed sum after a single
/// `all_reduce(SUM, 1u32)` across all ranks. The leader checks
/// this matches `world_size`.
NcclSanityResult { observed_sum: u32 },
/// Reply to `LoadDenseShard`. Empty payload — success is the
/// absence of `Error`. By the time this comes back, the rank's
/// `TpQwen3ForCausalLM` is constructed in memory and ready for
/// `GenerateStep`.
LoadDenseShardOk,
/// Reply to `GenerateStep`. Empty payload — workers don't ship
/// logits over the wire. The leader uses its own rank-0 logits;
/// workers only need to confirm the collective completed.
GenerateStepOk,
/// Reply to `ClearKvCache`. Empty payload.
KvCacheCleared,
/// Reply to `UnloadModel`. Empty payload. The named model is no
/// longer present on this rank.
Unloaded,
/// Reply to `Shutdown`. Worker exits immediately after writing this.
Bye,
/// Any request can produce this instead of its dedicated success
/// variant. `kind` is a machine-readable category so the leader
/// can branch on failure mode without string-matching `message`.
Error {
/// Short tag — `nccl_init_failed`, `unknown_op`, etc.
kind: String,
/// Human-readable detail for logs.
message: String,
},
}
#[cfg(test)]
mod tests {
use super::*;
fn roundtrip<T>(value: &T) -> T
where
T: Serialize + for<'de> Deserialize<'de>,
{
serde_json::from_str(&serde_json::to_string(value).expect("serialise"))
.expect("deserialise")
}
#[test]
fn request_ping_round_trip() {
let req = WorkerRequest::Ping;
let wire = serde_json::to_string(&req).unwrap();
assert_eq!(wire, r#"{"op":"ping"}"#);
match roundtrip(&req) {
WorkerRequest::Ping => {}
other => panic!("expected Ping, got {other:?}"),
}
}
#[test]
fn request_init_carries_hex_id() {
let req = WorkerRequest::Init {
comm_id: "deadbeef".into(),
};
let wire = serde_json::to_string(&req).unwrap();
assert_eq!(wire, r#"{"op":"init","comm_id":"deadbeef"}"#);
}
#[test]
fn request_shutdown_round_trip() {
assert_eq!(
serde_json::to_string(&WorkerRequest::Shutdown).unwrap(),
r#"{"op":"shutdown"}"#
);
}
#[test]
fn response_pong_round_trip() {
let resp = WorkerResponse::Pong {
rank: 1,
world_size: 4,
cuda_device: 1,
};
let wire = serde_json::to_string(&resp).unwrap();
assert!(wire.contains(r#""op":"pong""#));
assert!(wire.contains(r#""rank":1"#));
assert!(wire.contains(r#""world_size":4"#));
match roundtrip(&resp) {
WorkerResponse::Pong {
rank,
world_size,
cuda_device,
} => {
assert_eq!(rank, 1);
assert_eq!(world_size, 4);
assert_eq!(cuda_device, 1);
}
other => panic!("expected Pong, got {other:?}"),
}
}
#[test]
fn response_error_carries_kind_and_message() {
let resp = WorkerResponse::Error {
kind: "nccl_init_failed".into(),
message: "could not bind device".into(),
};
let wire = serde_json::to_string(&resp).unwrap();
assert!(wire.contains(r#""op":"error""#));
assert!(wire.contains(r#""kind":"nccl_init_failed""#));
}
#[test]
fn response_sanity_result_round_trip() {
let resp = WorkerResponse::NcclSanityResult { observed_sum: 4 };
match roundtrip(&resp) {
WorkerResponse::NcclSanityResult { observed_sum } => {
assert_eq!(observed_sum, 4);
}
other => panic!("expected NcclSanityResult, got {other:?}"),
}
}
/// Unknown ops on the wire deserialise to an error rather than
/// silently mis-matching — confirms our `serde(tag = "op")`
/// configuration rejects unknowns instead of doing fuzzy matching.
#[test]
fn unknown_op_fails_to_parse() {
let result: Result<WorkerRequest, _> = serde_json::from_str(r#"{"op":"explode"}"#);
assert!(result.is_err(), "should reject unknown op, got {result:?}");
}
}

View File

@@ -1,283 +0,0 @@
//! Tensor-parallel linear layers built on candle's `ShardedVarBuilder`
//! and `Shard` sharding hints.
//!
//! candle reads only the rank's slice of each weight tensor from
//! safetensors via `view.slice(start..stop)` — no full-tensor host
//! materialisation. That's a memory-efficiency win over hand-rolled
//! "load full + narrow" sharding (which the earlier
//! `sharded_linear.rs` exploration demonstrated but didn't pay for).
//!
//! Two layer types:
//!
//! - [`ColumnParallelLinear`] — output-sharded; forward is a plain
//! local matmul. The downstream consumer either accepts a sharded
//! activation (next layer is also column-parallel) or all-gathers.
//! - [`RowParallelLinear`] — input-sharded; forward = local matmul
//! then `AllReduce` `CustomOp1` to sum partials across ranks.
//!
//! Both assume **no bias** — every Qwen3-family weight layout we
//! actually target (Qwen3, Qwen3-Coder, Qwen3.6 base, etc.) sets
//! `attention_bias=false` and the MLP layers are no-bias. Adding bias
//! support is mechanical when a future model needs it; the design
//! choice would be: column-parallel shards the bias along dim 0;
//! row-parallel holds the bias only on rank 0 so the post-`AllReduce`
//! sum carries it exactly once.
use anyhow::{Context, Result};
use candle_core::quantized::{GgmlDType, QMatMul, QTensor};
use candle_core::{Module, Tensor};
use candle_nn::Linear;
use candle_nn::var_builder::{Shard, ShardedVarBuilder};
use std::sync::Arc;
#[cfg(feature = "cuda")]
use super::all_reduce::AllReduce;
/// Linear primitive that holds either a plain `Linear` (bf16/f16/f32)
/// or a quantized `QMatMul` (Q4K/Q5K/Q6K/Q8_0/etc.).
///
/// Constructed via [`MaybeQuantLinear::from_weight`] — pass `None` to
/// keep the weight in its loaded dtype (no quantization), or
/// `Some(dtype)` to quantize at load time.
///
/// On the forward path the two arms dispatch identically: `Module::forward`
/// returns an output in the caller's input dtype (f32 fallback for the
/// quantized matmul). Subsequent ops don't need to know whether the
/// layer was quantized.
pub enum MaybeQuantLinear {
Plain(Linear),
Quant(QMatMul),
}
impl MaybeQuantLinear {
/// Build a linear from a loaded weight tensor. If `quant` is set,
/// the weight is quantized in-situ and stored as a `QMatMul`;
/// otherwise it's wrapped in a plain `Linear`.
pub fn from_weight(weight: Tensor, quant: Option<GgmlDType>) -> Result<Self> {
match quant {
Some(dtype) => {
let qt = QTensor::quantize(&weight, dtype).with_context(|| {
format!(
"QTensor::quantize to {dtype:?} for shape {:?}",
weight.shape()
)
})?;
let qmm = QMatMul::from_arc(Arc::new(qt))
.context("QMatMul::from_arc on freshly quantized weight")?;
Ok(Self::Quant(qmm))
}
None => Ok(Self::Plain(Linear::new(weight, None))),
}
}
}
/// Above this M (the product of all input dims except the last)
/// dispatch the quantized matmul through `QMatMul::forward_via_f16`,
/// which dequantizes the weight to f16 once and runs cuBLAS GEMM.
/// At or below this M the GGUF GEMV kernel inside
/// `QMatMul::forward` wins (it operates on quantized blocks directly
/// and accumulates in registers).
///
/// Empirical: at M=30 on Qwen3.6-27B / RTX 5090, forward_via_f16 was
/// slightly *slower* than the GGUF GEMV kernel — the per-call dequant
/// cost (~30 MB f16 written to global memory per linear × ~480 calls
/// per prefill) eats the cuBLAS GEMM speedup at small M. The
/// crossover where the GEMM scaling actually beats the fixed dequant
/// tax sits well above M=8.
///
/// 64 is a conservative crossover that keeps short-prompt prefills
/// on the GGUF kernel (where the per-call cost is comparable to the
/// f16 path but the dequant tax is zero) and only activates the
/// dequant-then-GEMM path for long prefills where the GEMM size
/// makes amortising worth it. A proper fix is either a dequant
/// cache or a fused dequant+gemm cuda kernel — both larger projects.
const QUANT_PREFILL_M_THRESHOLD: usize = 64;
impl Module for MaybeQuantLinear {
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
match self {
Self::Plain(l) => l.forward(x),
Self::Quant(qm) => {
// Decode vs prefill split. `M` is the "rows of x" the
// matmul will iterate over — every dim except the last
// (which is in_features). For decode (`seq_len == 1`
// with batch 1) M is 1; for prefill with L>>1 it's L
// (or B*L).
let dims = x.dims();
let m: usize = dims.iter().take(dims.len() - 1).product();
if m > QUANT_PREFILL_M_THRESHOLD {
// Prefill: dequantize the weight once into f16,
// then run a real cuBLAS-backed GEMM. The cost of
// the dequant is amortised across all M tokens.
// `forward_via_f16` handles the dtype round-trip
// internally (output matches input dtype).
return qm.forward_via_f16(x);
}
// Decode (M <= threshold): use the on-the-fly GGUF
// GEMV kernel via `QMatMul::forward`. That kernel
// requires f32 inputs (it accumulates in f32 from the
// dequantized quant blocks); cast in/out at the
// boundary.
let in_dtype = x.dtype();
let x_f32 = if in_dtype == candle_core::DType::F32 {
x.clone()
} else {
x.to_dtype(candle_core::DType::F32)?
};
let y = qm.forward(&x_f32)?;
if y.dtype() == in_dtype {
Ok(y)
} else {
y.to_dtype(in_dtype)
}
}
}
}
}
/// Helper to build a [`Shard`] hint for a given dimension.
pub(crate) fn shard(dim: usize, rank: u32, world_size: u32) -> Shard {
Shard {
dim,
rank: rank as usize,
world_size: world_size as usize,
}
}
/// Output-dim sharded linear (column-parallel). Holds a
/// [`MaybeQuantLinear`] whose underlying weight is this rank's slice
/// of the full `[out_features, in_features]` tensor along dim 0.
pub struct ColumnParallelLinear {
inner: MaybeQuantLinear,
}
impl ColumnParallelLinear {
/// Load this rank's column-parallel slice from a
/// `ShardedVarBuilder`. The provided `vb` must already be `pp`-ed
/// to the layer's path (e.g. `vb.pp("model.layers.0.self_attn.q_proj")`).
///
/// Backward-compatible variant — no in-situ quantization. For
/// quantized loads, use [`Self::load_with_quant`].
pub fn load(vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
Self::load_with_quant(vb, rank, world_size, None)
}
/// Like [`Self::load`] but quantizes the per-rank weight in-situ
/// when `quant` is `Some(dtype)`. Saves ~3-5x vs bf16/f16.
pub fn load_with_quant(
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
quant: Option<GgmlDType>,
) -> Result<Self> {
let weight = vb
.get_with_hints((), "weight", shard(0, rank, world_size))
.with_context(|| format!("load column-parallel '{}' weight", vb.prefix()))?;
let inner = MaybeQuantLinear::from_weight(weight, quant)
.with_context(|| format!("wrap column-parallel '{}'", vb.prefix()))?;
Ok(Self { inner })
}
}
impl Module for ColumnParallelLinear {
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
self.inner.forward(x)
}
}
/// Input-dim sharded linear (row-parallel).
///
/// Holds a sharded [`MaybeQuantLinear`] plus an `AllReduce` op the
/// forward chains after the local matmul to recover the full activation.
pub struct RowParallelLinear {
inner: MaybeQuantLinear,
#[cfg(feature = "cuda")]
all_reduce: AllReduce,
/// Whether the AllReduce should run. Column-parallel ↔ row-parallel
/// is a pair: the column output is sharded, the row input is
/// sharded, and the AllReduce gives back the full output. For
/// `world_size = 1` the AllReduce is a no-op so we skip it.
needs_reduce: bool,
}
impl RowParallelLinear {
/// Load this rank's row-parallel slice from a `ShardedVarBuilder`.
///
/// Under `cuda`, `comm` is the NCCL communicator the row-parallel
/// `AllReduce` runs against. On CPU builds the parameter is
/// elided — forward returns the partial sum, which is the *wrong*
/// answer for inference but lets us compile-check the model.
///
/// Backward-compatible variant — no in-situ quantization. For
/// quantized loads, use [`Self::load_with_quant`].
#[cfg(feature = "cuda")]
pub fn load(
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
comm: std::sync::Arc<cudarc::nccl::Comm>,
) -> Result<Self> {
Self::load_with_quant(vb, rank, world_size, comm, None)
}
/// Like [`Self::load`] but quantizes the per-rank weight in-situ.
#[cfg(feature = "cuda")]
pub fn load_with_quant(
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
comm: std::sync::Arc<cudarc::nccl::Comm>,
quant: Option<GgmlDType>,
) -> Result<Self> {
let weight = vb
.get_with_hints((), "weight", shard(1, rank, world_size))
.with_context(|| format!("load row-parallel '{}' weight", vb.prefix()))?;
let inner = MaybeQuantLinear::from_weight(weight, quant)
.with_context(|| format!("wrap row-parallel '{}'", vb.prefix()))?;
Ok(Self {
inner,
all_reduce: AllReduce::new(comm),
needs_reduce: world_size > 1,
})
}
#[cfg(not(feature = "cuda"))]
pub fn load(vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
Self::load_with_quant(vb, rank, world_size, None)
}
#[cfg(not(feature = "cuda"))]
pub fn load_with_quant(
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
quant: Option<GgmlDType>,
) -> Result<Self> {
let weight = vb
.get_with_hints((), "weight", shard(1, rank, world_size))
.with_context(|| format!("load row-parallel '{}' weight", vb.prefix()))?;
let inner = MaybeQuantLinear::from_weight(weight, quant)
.with_context(|| format!("wrap row-parallel '{}'", vb.prefix()))?;
Ok(Self {
inner,
needs_reduce: world_size > 1,
})
}
}
impl Module for RowParallelLinear {
/// Local matmul followed by an `AllReduce` (when `cuda` and
/// `world_size > 1`). On CPU or single-rank, returns the partial
/// output directly — which is *only* correct for `world_size == 1`.
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
let local = self.inner.forward(x)?;
#[cfg(feature = "cuda")]
if self.needs_reduce {
return local.apply_op1_no_bwd(&self.all_reduce);
}
let _ = self.needs_reduce;
Ok(local)
}
}

View File

@@ -1,678 +0,0 @@
//! Tensor-parallel Qwen3 dense model.
//!
//! Mirrors `candle_transformers::models::qwen3` structurally, but with:
//!
//! - Attention's `q_proj` / `k_proj` / `v_proj` as
//! [`ColumnParallelLinear`] (output sharded along the head dimension —
//! per-rank `num_heads = total/world_size`, ditto for kv heads).
//! - Attention's `o_proj` as [`RowParallelLinear`] (input sharded; the
//! trailing `AllReduce` recovers the full activation).
//! - MLP's `gate_proj` / `up_proj` as [`ColumnParallelLinear`] (sharded
//! along `intermediate_size`).
//! - MLP's `down_proj` as [`RowParallelLinear`].
//! - `embed_tokens`, all `RmsNorm`s, and `lm_head` **replicated** on
//! every rank. The per-rank duplicate weight is bounded and lets us
//! skip the embedding all-gather and the lm-head column-shard +
//! all-gather; both are pure latency optimisations that don't change
//! correctness.
//! - `kv_cache` holds the per-rank slice of K/V already (because they
//! came out of a column-parallel projection). No cache resharding
//! needed across ranks.
//!
//! Divisibility requirement, checked at load time:
//!
//! - `num_attention_heads % world_size == 0`
//! - `num_key_value_heads % world_size == 0`
//! - `intermediate_size % world_size == 0`
//!
//! Anything else bails — the safetensors slice would lose data otherwise.
//! This is the same divisibility-bail pattern that landed in
//! `EricLBuehler/mistral.rs` PR #2054.
//!
//! Replicated tensors (norms, embedding, lm_head) are loaded by asking
//! the `ShardedVarBuilder` for the full tensor via `vb.get(shape, name)`
//! — which defaults to `Shard { world_size: 1 }` and falls through to
//! the unsharded backend path.
use anyhow::{Context, Result, bail};
use candle_core::{DType, Device, IndexOp, Module, Tensor};
use candle_nn::var_builder::ShardedVarBuilder;
use candle_nn::{Activation, Embedding, Linear, RmsNorm, kv_cache::ConcatKvCache};
use candle_transformers::utils::repeat_kv;
use std::sync::Arc;
#[cfg(feature = "cuda")]
use cudarc::nccl::Comm;
use super::tp_linear::{ColumnParallelLinear, RowParallelLinear};
pub use candle_transformers::models::qwen3::Config;
/// Replicated rotary-embedding lookup. Re-implementation of the
/// `pub(crate)` candle equivalent — we can't reach into the upstream
/// type, so the inv-freq / sin / cos construction lives here.
pub(crate) struct Qwen3RotaryEmbedding {
sin: Tensor,
cos: Tensor,
}
impl Qwen3RotaryEmbedding {
pub(crate) fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
let dim = cfg.head_dim;
let max_seq_len = cfg.max_position_embeddings;
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(DType::F32)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
Ok(Self {
sin: freqs.sin()?.to_dtype(dtype)?,
cos: freqs.cos()?.to_dtype(dtype)?,
})
}
fn apply(
&self,
q: &Tensor,
k: &Tensor,
offset: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
let (_, _, seq_len, _) = q.dims4()?;
let cos = self.cos.narrow(0, offset, seq_len)?;
let sin = self.sin.narrow(0, offset, seq_len)?;
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
Ok((q_embed, k_embed))
}
}
/// Helper: load a replicated tensor by asking the ShardedVarBuilder for
/// the full tensor (world_size=1 hint falls through to SimpleBackend).
fn load_replicated<S: Into<candle_core::Shape>>(
vb: &ShardedVarBuilder,
shape: S,
name: &str,
) -> Result<Tensor> {
vb.get(shape, name)
.with_context(|| format!("load replicated '{}/{name}'", vb.prefix()))
}
fn load_rms_norm(vb: &ShardedVarBuilder, size: usize, eps: f64) -> Result<RmsNorm> {
let weight = load_replicated(vb, size, "weight")?;
Ok(RmsNorm::new(weight, eps))
}
/// TP MLP. SwiGLU = `down(silu(gate(x)) * up(x))`.
pub(crate) struct TpQwen3MLP {
gate_proj: ColumnParallelLinear,
up_proj: ColumnParallelLinear,
down_proj: RowParallelLinear,
act_fn: Activation,
}
impl TpQwen3MLP {
#[cfg(feature = "cuda")]
pub fn load(
cfg: &Config,
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
comm: Arc<Comm>,
) -> Result<Self> {
if !cfg.intermediate_size.is_multiple_of(world_size as usize) {
bail!(
"intermediate_size {} not divisible by world_size {}",
cfg.intermediate_size,
world_size
);
}
Ok(Self {
gate_proj: ColumnParallelLinear::load(&vb.pp("gate_proj"), rank, world_size)?,
up_proj: ColumnParallelLinear::load(&vb.pp("up_proj"), rank, world_size)?,
down_proj: RowParallelLinear::load(&vb.pp("down_proj"), rank, world_size, comm)?,
act_fn: cfg.hidden_act,
})
}
#[cfg(not(feature = "cuda"))]
pub fn load(cfg: &Config, vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
if !cfg.intermediate_size.is_multiple_of(world_size as usize) {
bail!(
"intermediate_size {} not divisible by world_size {}",
cfg.intermediate_size,
world_size
);
}
Ok(Self {
gate_proj: ColumnParallelLinear::load(&vb.pp("gate_proj"), rank, world_size)?,
up_proj: ColumnParallelLinear::load(&vb.pp("up_proj"), rank, world_size)?,
down_proj: RowParallelLinear::load(&vb.pp("down_proj"), rank, world_size)?,
act_fn: cfg.hidden_act,
})
}
}
impl Module for TpQwen3MLP {
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
let lhs = x.apply(&self.gate_proj)?.apply(&self.act_fn)?;
let rhs = x.apply(&self.up_proj)?;
(lhs * rhs)?.apply(&self.down_proj)
}
}
/// TP attention. Carries per-rank head counts and the q/k per-head
/// RmsNorms (which are replicated and operate on a flattened B*H*L
/// axis, so the same code path works irrespective of how H was split).
pub(crate) struct TpQwen3Attention {
q_proj: ColumnParallelLinear,
k_proj: ColumnParallelLinear,
v_proj: ColumnParallelLinear,
o_proj: RowParallelLinear,
q_norm: RmsNorm,
k_norm: RmsNorm,
local_num_heads: usize,
local_num_kv_heads: usize,
num_kv_groups: usize,
head_dim: usize,
local_hidden_size: usize,
rotary_emb: Arc<Qwen3RotaryEmbedding>,
kv_cache: ConcatKvCache,
}
impl TpQwen3Attention {
#[cfg(feature = "cuda")]
pub fn load(
cfg: &Config,
rotary_emb: Arc<Qwen3RotaryEmbedding>,
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
comm: Arc<Comm>,
) -> Result<Self> {
Self::load_inner(
cfg,
rotary_emb,
vb,
rank,
world_size,
#[cfg(feature = "cuda")]
comm,
)
}
#[cfg(not(feature = "cuda"))]
pub fn load(
cfg: &Config,
rotary_emb: Arc<Qwen3RotaryEmbedding>,
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
) -> Result<Self> {
Self::load_inner(cfg, rotary_emb, vb, rank, world_size)
}
fn load_inner(
cfg: &Config,
rotary_emb: Arc<Qwen3RotaryEmbedding>,
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
#[cfg(feature = "cuda")] comm: Arc<Comm>,
) -> Result<Self> {
if cfg.use_sliding_window {
bail!("sliding window is not yet supported in the TP path");
}
if cfg.attention_bias {
bail!("attention_bias=true is not supported by ColumnParallel/RowParallelLinear yet");
}
let ws = world_size as usize;
if !cfg.num_attention_heads.is_multiple_of(ws) {
bail!(
"num_attention_heads {} not divisible by world_size {}",
cfg.num_attention_heads,
world_size
);
}
if !cfg.num_key_value_heads.is_multiple_of(ws) {
bail!(
"num_key_value_heads {} not divisible by world_size {}",
cfg.num_key_value_heads,
world_size
);
}
let head_dim = cfg.head_dim;
let local_num_heads = cfg.num_attention_heads / ws;
let local_num_kv_heads = cfg.num_key_value_heads / ws;
let num_kv_groups = local_num_heads / local_num_kv_heads;
let local_hidden_size = head_dim * local_num_heads;
let q_proj = ColumnParallelLinear::load(&vb.pp("q_proj"), rank, world_size)?;
let k_proj = ColumnParallelLinear::load(&vb.pp("k_proj"), rank, world_size)?;
let v_proj = ColumnParallelLinear::load(&vb.pp("v_proj"), rank, world_size)?;
#[cfg(feature = "cuda")]
let o_proj = RowParallelLinear::load(&vb.pp("o_proj"), rank, world_size, comm)?;
#[cfg(not(feature = "cuda"))]
let o_proj = RowParallelLinear::load(&vb.pp("o_proj"), rank, world_size)?;
let q_norm = load_rms_norm(&vb.pp("q_norm"), head_dim, cfg.rms_norm_eps)?;
let k_norm = load_rms_norm(&vb.pp("k_norm"), head_dim, cfg.rms_norm_eps)?;
// dim=2 because we cat along the seq axis of (B, H, L, D) tensors.
let kv_cache = ConcatKvCache::new(2);
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
q_norm,
k_norm,
local_num_heads,
local_num_kv_heads,
num_kv_groups,
head_dim,
local_hidden_size,
rotary_emb,
kv_cache,
})
}
pub fn forward(
&mut self,
x: &Tensor,
attn_mask: Option<&Tensor>,
offset: usize,
) -> candle_core::Result<Tensor> {
let (b, l, _) = x.dims3()?;
// 1. Projections (column-parallel → output is sharded).
let q = self.q_proj.forward(x)?;
let k = self.k_proj.forward(x)?;
let v = self.v_proj.forward(x)?;
// 2. Reshape: (B, L, H, D) → (B, H, L, D).
let q = q
.reshape((b, l, self.local_num_heads, self.head_dim))?
.transpose(1, 2)?;
let k = k
.reshape((b, l, self.local_num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let v = v
.reshape((b, l, self.local_num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
// 3. Per-head RmsNorm (replicated weight, flat input).
let q_flat = q.flatten(0, 2)?;
let k_flat = k.flatten(0, 2)?;
let q_flat = self.q_norm.forward(&q_flat)?;
let k_flat = self.k_norm.forward(&k_flat)?;
let q = q_flat.reshape((b, self.local_num_heads, l, self.head_dim))?;
let k = k_flat.reshape((b, self.local_num_kv_heads, l, self.head_dim))?;
// 4. Rotary.
let (q, k) = self.rotary_emb.apply(&q, &k, offset)?;
// 5. Accumulate KV.
let (k, v) = self.kv_cache.append(&k, &v)?;
// 6. GQA repeat_kv on the rank-local K/V.
let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;
// 7. Attention scores.
let scale = 1.0 / (self.head_dim as f64).sqrt();
let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
if let Some(m) = attn_mask {
scores = scores.broadcast_add(m)?;
}
let probs = candle_nn::ops::softmax_last_dim(&scores)?;
let ctx = probs.matmul(&v)?;
// 8. Output projection (row-parallel → AllReduce inside).
ctx.transpose(1, 2)?
.reshape((b, l, self.local_hidden_size))?
.apply(&self.o_proj)
}
pub fn clear_kv_cache(&mut self) {
self.kv_cache.reset();
}
}
struct TpDecoderLayer {
self_attn: TpQwen3Attention,
mlp: TpQwen3MLP,
ln1: RmsNorm,
ln2: RmsNorm,
}
impl TpDecoderLayer {
#[cfg(feature = "cuda")]
fn load(
cfg: &Config,
rotary_emb: Arc<Qwen3RotaryEmbedding>,
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
comm: Arc<Comm>,
) -> Result<Self> {
let self_attn = TpQwen3Attention::load(
cfg,
rotary_emb,
&vb.pp("self_attn"),
rank,
world_size,
comm.clone(),
)?;
let mlp = TpQwen3MLP::load(cfg, &vb.pp("mlp"), rank, world_size, comm)?;
let ln1 = load_rms_norm(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?;
let ln2 = load_rms_norm(
&vb.pp("post_attention_layernorm"),
cfg.hidden_size,
cfg.rms_norm_eps,
)?;
Ok(Self {
self_attn,
mlp,
ln1,
ln2,
})
}
#[cfg(not(feature = "cuda"))]
fn load(
cfg: &Config,
rotary_emb: Arc<Qwen3RotaryEmbedding>,
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
) -> Result<Self> {
let self_attn =
TpQwen3Attention::load(cfg, rotary_emb, &vb.pp("self_attn"), rank, world_size)?;
let mlp = TpQwen3MLP::load(cfg, &vb.pp("mlp"), rank, world_size)?;
let ln1 = load_rms_norm(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?;
let ln2 = load_rms_norm(
&vb.pp("post_attention_layernorm"),
cfg.hidden_size,
cfg.rms_norm_eps,
)?;
Ok(Self {
self_attn,
mlp,
ln1,
ln2,
})
}
fn forward(
&mut self,
x: &Tensor,
mask: Option<&Tensor>,
offset: usize,
) -> candle_core::Result<Tensor> {
let h = self.ln1.forward(x)?;
let h = self.self_attn.forward(&h, mask, offset)?;
let x = (x + h)?;
let h2 = self.ln2.forward(&x)?;
let h2 = h2.apply(&self.mlp)?;
x + h2
}
fn clear_kv_cache(&mut self) {
self.self_attn.clear_kv_cache();
}
}
/// Base TP Qwen3 transformer — embedding, decoder stack, final norm.
/// The lm_head sits on top in [`TpQwen3ForCausalLM`].
pub struct TpQwen3Model {
embed_tokens: Embedding,
layers: Vec<TpDecoderLayer>,
norm: RmsNorm,
device: Device,
dtype: DType,
}
impl TpQwen3Model {
#[cfg(feature = "cuda")]
pub fn load(
cfg: &Config,
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
comm: Arc<Comm>,
) -> Result<Self> {
let dtype = vb.dtype();
let device = vb.device().clone();
let rotary = Arc::new(Qwen3RotaryEmbedding::new(dtype, cfg, &device)?);
let embed_vb = vb.pp("model.embed_tokens");
let embed_weight = load_replicated(&embed_vb, (cfg.vocab_size, cfg.hidden_size), "weight")?;
let embed_tokens = Embedding::new(embed_weight, cfg.hidden_size);
let vb_l = vb.pp("model.layers");
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
for i in 0..cfg.num_hidden_layers {
layers.push(TpDecoderLayer::load(
cfg,
rotary.clone(),
&vb_l.pp(i),
rank,
world_size,
comm.clone(),
)?);
}
let norm = load_rms_norm(&vb.pp("model.norm"), cfg.hidden_size, cfg.rms_norm_eps)?;
Ok(Self {
embed_tokens,
layers,
norm,
device,
dtype,
})
}
#[cfg(not(feature = "cuda"))]
pub fn load(cfg: &Config, vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
let dtype = vb.dtype();
let device = vb.device().clone();
let rotary = Arc::new(Qwen3RotaryEmbedding::new(dtype, cfg, &device)?);
let embed_vb = vb.pp("model.embed_tokens");
let embed_weight = load_replicated(&embed_vb, (cfg.vocab_size, cfg.hidden_size), "weight")?;
let embed_tokens = Embedding::new(embed_weight, cfg.hidden_size);
let vb_l = vb.pp("model.layers");
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
for i in 0..cfg.num_hidden_layers {
layers.push(TpDecoderLayer::load(
cfg,
rotary.clone(),
&vb_l.pp(i),
rank,
world_size,
)?);
}
let norm = load_rms_norm(&vb.pp("model.norm"), cfg.hidden_size, cfg.rms_norm_eps)?;
Ok(Self {
embed_tokens,
layers,
norm,
device,
dtype,
})
}
pub fn embed_weight(&self) -> &Tensor {
self.embed_tokens.embeddings()
}
pub fn clear_kv_cache(&mut self) {
for l in &mut self.layers {
l.clear_kv_cache();
}
}
fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> candle_core::Result<Tensor> {
let minf = f32::NEG_INFINITY;
let mask: Vec<_> = (0..tgt)
.flat_map(|i| (0..(tgt + offset)).map(move |j| if j <= i + offset { 0. } else { minf }))
.collect();
Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype)
}
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
let (b, l) = input.dims2()?;
let mut h = self.embed_tokens.forward(input)?;
let causal = if l == 1 {
None
} else {
Some(self.causal_mask(b, l, offset)?)
};
for layer in &mut self.layers {
h = layer.forward(&h, causal.as_ref(), offset)?;
}
self.norm.forward(&h)
}
}
/// TP Qwen3 with a (replicated) language-model head on top.
pub struct TpQwen3ForCausalLM {
base: TpQwen3Model,
lm_head: Linear,
}
impl TpQwen3ForCausalLM {
#[cfg(feature = "cuda")]
pub fn load(
cfg: &Config,
vb: &ShardedVarBuilder,
rank: u32,
world_size: u32,
comm: Arc<Comm>,
) -> Result<Self> {
let base = TpQwen3Model::load(cfg, vb, rank, world_size, comm)?;
let lm_head = build_lm_head(cfg, vb, &base)?;
let model = Self { base, lm_head };
log_construction_complete(cfg, rank, world_size, model.device());
Ok(model)
}
#[cfg(not(feature = "cuda"))]
pub fn load(cfg: &Config, vb: &ShardedVarBuilder, rank: u32, world_size: u32) -> Result<Self> {
let base = TpQwen3Model::load(cfg, vb, rank, world_size)?;
let lm_head = build_lm_head(cfg, vb, &base)?;
let model = Self { base, lm_head };
log_construction_complete(cfg, rank, world_size, model.device());
Ok(model)
}
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
let (_, l) = input.dims2()?;
let hidden = self.base.forward(input, offset)?;
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
}
pub fn clear_kv_cache(&mut self) {
self.base.clear_kv_cache();
}
pub fn device(&self) -> &Device {
&self.base.device
}
pub fn dtype(&self) -> DType {
self.base.dtype
}
}
fn build_lm_head(cfg: &Config, vb: &ShardedVarBuilder, base: &TpQwen3Model) -> Result<Linear> {
if cfg.tie_word_embeddings {
Ok(Linear::new(base.embed_weight().clone(), None))
} else {
let weight = load_replicated(
&vb.pp("lm_head"),
(cfg.vocab_size, cfg.hidden_size),
"weight",
)?;
Ok(Linear::new(weight, None))
}
}
/// VRAM accounting + config dump emitted at the end of
/// `TpQwen3ForCausalLM::load`. Same intent as the Qwen3-Next variant
/// in tp_qwen3_5.rs — surface the resolved hyperparameters and
/// per-rank free VRAM in one line so an operator chasing an OOM or a
/// numerical issue doesn't have to grep the per-layer load logs.
#[cfg(feature = "cuda")]
fn log_construction_complete(cfg: &Config, rank: u32, world_size: u32, device: &Device) {
use candle_core::cuda::cudarc::driver::result;
use candle_core::cuda_backend::WrapErr;
let (free_mb, total_mb) = if let Device::Cuda(dev) = device {
if dev.cuda_stream().context().bind_to_thread().w().is_ok() {
match result::mem_get_info() {
Ok((free, total)) => (free / (1024 * 1024), total / (1024 * 1024)),
Err(_) => (0, 0),
}
} else {
(0, 0)
}
} else {
(0, 0)
};
// Per-rank KV cache cost at one token: K + V × bf16. Vanilla
// Qwen3 is dense attention end-to-end, so every layer
// contributes. Knowing per-token bytes lets the operator estimate
// headroom for a given prompt length before hitting an edge.
let per_rank_num_kv_heads = (cfg.num_key_value_heads / world_size as usize).max(1);
let kv_bytes_per_token_per_layer = per_rank_num_kv_heads * cfg.head_dim * 2 * 2;
let kv_bytes_per_token = kv_bytes_per_token_per_layer * cfg.num_hidden_layers;
tracing::info!(
target: "neuron::tp::load",
rank,
world_size,
free_mb,
total_mb,
vocab_size = cfg.vocab_size,
hidden_size = cfg.hidden_size,
num_hidden_layers = cfg.num_hidden_layers,
num_attention_heads = cfg.num_attention_heads,
num_key_value_heads = cfg.num_key_value_heads,
head_dim = cfg.head_dim,
max_position_embeddings = cfg.max_position_embeddings,
per_rank_num_kv_heads,
kv_bytes_per_token,
"Qwen3 model construction complete"
);
}
#[cfg(not(feature = "cuda"))]
fn log_construction_complete(cfg: &Config, rank: u32, world_size: u32, _device: &Device) {
let per_rank_num_kv_heads = (cfg.num_key_value_heads / world_size as usize).max(1);
let kv_bytes_per_token_per_layer = per_rank_num_kv_heads * cfg.head_dim * 2 * 2;
let kv_bytes_per_token = kv_bytes_per_token_per_layer * cfg.num_hidden_layers;
tracing::info!(
target: "neuron::tp::load",
rank,
world_size,
vocab_size = cfg.vocab_size,
hidden_size = cfg.hidden_size,
num_hidden_layers = cfg.num_hidden_layers,
num_attention_heads = cfg.num_attention_heads,
num_key_value_heads = cfg.num_key_value_heads,
head_dim = cfg.head_dim,
max_position_embeddings = cfg.max_position_embeddings,
per_rank_num_kv_heads,
kv_bytes_per_token,
"Qwen3 model construction complete"
);
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,502 +0,0 @@
//! Entry point for `neuron --worker`.
//!
//! The worker reads one newline-delimited JSON `WorkerRequest` from
//! stdin per loop iteration, dispatches synchronously, and writes
//! exactly one `WorkerResponse` JSON line to stdout. tracing goes to
//! stderr so it doesn't collide with the RPC stream.
//!
//! NCCL operations (`Init`, `NcclSanityCheck`) and model lifecycle ops
//! (`LoadDenseShard`, `GenerateStep`, `ClearKvCache`, `UnloadModel`)
//! are real when built with the `cuda` feature; without it they reply
//! with `Error{kind="cuda_feature_not_enabled"}` so the leader can tell
//! the difference between a misconfigured build and a genuine NCCL or
//! model failure.
use anyhow::Result;
use std::collections::HashMap;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use super::nccl_state::NcclState;
use super::rpc::{WorkerRequest, WorkerResponse};
#[cfg(feature = "cuda")]
use super::tp_qwen3::TpQwen3ForCausalLM;
#[cfg(feature = "cuda")]
use super::tp_qwen3_5::TpQwen3_5ForCausalLM;
/// Worker-side discriminator over the architectures we can load via
/// `LoadDenseShard`. Mirrors `super::TpLeaderModel` on the leader
/// side — the dispatch happens on the `model_type` extracted from the
/// config JSON.
#[cfg(feature = "cuda")]
enum WorkerModel {
Qwen3(TpQwen3ForCausalLM),
Qwen3_5(TpQwen3_5ForCausalLM),
}
#[cfg(feature = "cuda")]
impl WorkerModel {
fn forward(
&mut self,
input: &candle_core::Tensor,
offset: usize,
) -> candle_core::Result<candle_core::Tensor> {
match self {
WorkerModel::Qwen3(m) => m.forward(input, offset),
WorkerModel::Qwen3_5(m) => m.forward(input, offset),
}
}
fn clear_kv_cache(&mut self) {
match self {
WorkerModel::Qwen3(m) => m.clear_kv_cache(),
WorkerModel::Qwen3_5(m) => m.clear_kv_cache(),
}
}
fn device(&self) -> &candle_core::Device {
match self {
WorkerModel::Qwen3(m) => m.device(),
WorkerModel::Qwen3_5(m) => m.device(),
}
}
}
#[derive(Debug, Clone, Copy)]
pub struct WorkerConfig {
pub rank: u32,
pub world_size: u32,
pub cuda_device: u32,
}
/// Drive the worker RPC loop until `Shutdown` or EOF on stdin.
pub async fn run(config: WorkerConfig) -> Result<()> {
tracing::info!(
rank = config.rank,
world_size = config.world_size,
cuda_device = config.cuda_device,
"tp worker starting"
);
let mut state = WorkerState::new(config);
let stdin = tokio::io::stdin();
let mut reader = BufReader::new(stdin).lines();
let mut stdout = tokio::io::stdout();
while let Some(line) = reader.next_line().await? {
if line.trim().is_empty() {
continue;
}
let req: WorkerRequest = match serde_json::from_str(&line) {
Ok(r) => r,
Err(e) => {
let resp = WorkerResponse::Error {
kind: "bad_request".into(),
message: format!("parse {line:?}: {e}"),
};
write_response(&mut stdout, &resp).await?;
continue;
}
};
let resp = state.handle(req).await;
let is_bye = matches!(resp, WorkerResponse::Bye);
write_response(&mut stdout, &resp).await?;
if is_bye {
break;
}
}
tracing::info!(rank = config.rank, "tp worker exiting");
Ok(())
}
async fn write_response(stdout: &mut tokio::io::Stdout, resp: &WorkerResponse) -> Result<()> {
let mut line = serde_json::to_string(resp)?;
line.push('\n');
stdout.write_all(line.as_bytes()).await?;
stdout.flush().await?;
Ok(())
}
/// One rank's local state. Owns the rank's NCCL communicator (via
/// `NcclState`) and the rank's shard of every loaded model.
struct WorkerState {
config: WorkerConfig,
nccl: NcclState,
/// Loaded model shards keyed by `model_id`. Each entry wraps the
/// rank's TP architecture handle (Qwen3 or Qwen3-Next) — the
/// column/row-parallel layers hold an `Arc<Comm>` cloned from
/// `nccl`. Cuda-only: the underlying types reference cudarc types
/// that don't exist without the cuda feature.
#[cfg(feature = "cuda")]
models: HashMap<String, WorkerModel>,
/// Placeholder so the non-cuda build keeps the same field name set
/// and `WorkerState::new` reads the same on both.
#[cfg(not(feature = "cuda"))]
#[allow(dead_code)]
models: HashMap<String, ()>,
}
impl WorkerState {
fn new(config: WorkerConfig) -> Self {
Self {
config,
nccl: NcclState::new(),
models: HashMap::new(),
}
}
async fn handle(&mut self, req: WorkerRequest) -> WorkerResponse {
match req {
WorkerRequest::Ping => WorkerResponse::Pong {
rank: self.config.rank,
world_size: self.config.world_size,
cuda_device: self.config.cuda_device,
},
WorkerRequest::Init { comm_id } => self.nccl.init(self.config, &comm_id),
WorkerRequest::NcclSanityCheck => self.nccl.sanity_check(),
WorkerRequest::LoadDenseShard {
model_id,
config_json,
safetensors_paths,
quant,
} => self.handle_load_dense_shard(model_id, config_json, safetensors_paths, quant),
WorkerRequest::GenerateStep {
model_id,
tokens,
offset,
} => self.handle_generate_step(&model_id, tokens, offset),
WorkerRequest::ClearKvCache { model_id } => self.handle_clear_kv_cache(&model_id),
WorkerRequest::UnloadModel { model_id } => self.handle_unload_model(&model_id),
WorkerRequest::Shutdown => WorkerResponse::Bye,
}
}
#[cfg(feature = "cuda")]
fn handle_load_dense_shard(
&mut self,
model_id: String,
config_json: String,
safetensors_paths: Vec<String>,
quant: Option<String>,
) -> WorkerResponse {
use crate::harness::arch::qwen3_5 as qwen3_5_arch;
use candle_core::{DType, Device};
use candle_nn::var_builder::ShardedSafeTensors;
use candle_transformers::models::qwen3 as qwen3_dense;
use std::path::PathBuf;
let quant_dtype = match parse_quant_string(quant.as_deref()) {
Ok(q) => q,
Err(e) => {
return WorkerResponse::Error {
kind: "bad_request".into(),
message: format!("parse quant: {e}"),
};
}
};
if self.models.contains_key(&model_id) {
return WorkerResponse::Error {
kind: "already_loaded".into(),
message: format!("model '{model_id}' already loaded on this rank"),
};
}
let comm = match self.nccl.comm() {
Some(c) => c,
None => {
return WorkerResponse::Error {
kind: "nccl_not_initialised".into(),
message: "LoadDenseShard requires Init to have completed first".into(),
};
}
};
// Peek at model_type so we know which architecture to build.
let model_type = serde_json::from_str::<serde_json::Value>(&config_json)
.ok()
.as_ref()
.and_then(|v| v.get("model_type"))
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let device = match Device::new_cuda(self.config.cuda_device as usize) {
Ok(d) => d,
Err(e) => {
return WorkerResponse::Error {
kind: "cuda_unavailable".into(),
message: format!("Device::new_cuda({}) failed: {e}", self.config.cuda_device),
};
}
};
let paths: Vec<PathBuf> = safetensors_paths.into_iter().map(PathBuf::from).collect();
// SAFETY: same invariant as the single-GPU dense path — the HF
// cache files are treated as immutable while the mmap is held.
let vb = match unsafe { ShardedSafeTensors::var_builder(&paths, DType::BF16, &device) } {
Ok(v) => v,
Err(e) => {
return WorkerResponse::Error {
kind: "load_failed".into(),
message: format!("ShardedSafeTensors::var_builder: {e}"),
};
}
};
// Separate mmap of the same paths for the direct fused-region
// loader in `fused_load`. Linux's page cache shares the
// underlying pages between the two mmaps; the cost is one
// extra set of safetensors-header parses.
let mmap = match unsafe { candle_core::safetensors::MmapedSafetensors::multi(&paths) } {
Ok(m) => m,
Err(e) => {
return WorkerResponse::Error {
kind: "load_failed".into(),
message: format!("MmapedSafetensors::multi: {e}"),
};
}
};
let loaded = match model_type.as_str() {
"qwen3" => {
let cfg: qwen3_dense::Config = match serde_json::from_str(&config_json) {
Ok(c) => c,
Err(e) => {
return WorkerResponse::Error {
kind: "bad_request".into(),
message: format!("parse Qwen3 Config JSON: {e}"),
};
}
};
match TpQwen3ForCausalLM::load(
&cfg,
&vb,
self.config.rank,
self.config.world_size,
comm,
) {
Ok(m) => WorkerModel::Qwen3(m),
Err(e) => {
return WorkerResponse::Error {
kind: "load_failed".into(),
message: format!("TpQwen3ForCausalLM::load: {e:#}"),
};
}
}
}
"qwen3_5" => {
let cfg: qwen3_5_arch::Config = match serde_json::from_str(&config_json) {
Ok(c) => c,
Err(e) => {
return WorkerResponse::Error {
kind: "bad_request".into(),
message: format!("parse Qwen3-Next Config JSON: {e}"),
};
}
};
match TpQwen3_5ForCausalLM::load(
cfg,
&vb,
&mmap,
self.config.rank,
self.config.world_size,
comm,
quant_dtype,
) {
Ok(m) => WorkerModel::Qwen3_5(m),
Err(e) => {
return WorkerResponse::Error {
kind: "load_failed".into(),
message: format!("TpQwen3_5ForCausalLM::load: {e:#}"),
};
}
}
}
other => {
return WorkerResponse::Error {
kind: "unsupported_arch".into(),
message: format!(
"worker: unsupported model_type '{other}' (supported: qwen3, qwen3_5)"
),
};
}
};
self.models.insert(model_id.clone(), loaded);
tracing::info!(
rank = self.config.rank,
model = %model_id,
model_type = %model_type,
"loaded TP shard"
);
WorkerResponse::LoadDenseShardOk
}
#[cfg(not(feature = "cuda"))]
fn handle_load_dense_shard(
&mut self,
_model_id: String,
_config_json: String,
_safetensors_paths: Vec<String>,
_quant: Option<String>,
) -> WorkerResponse {
WorkerResponse::Error {
kind: "cuda_feature_not_enabled".into(),
message: "LoadDenseShard requires --features cuda".into(),
}
}
#[cfg(feature = "cuda")]
fn handle_generate_step(
&mut self,
model_id: &str,
tokens: Vec<u32>,
offset: usize,
) -> WorkerResponse {
use candle_core::Tensor;
let Some(model) = self.models.get_mut(model_id) else {
return WorkerResponse::Error {
kind: "model_not_loaded".into(),
message: format!("model '{model_id}' not loaded on rank {}", self.config.rank),
};
};
let device = model.device().clone();
let input = match Tensor::new(tokens.as_slice(), &device).and_then(|t| t.unsqueeze(0)) {
Ok(t) => t,
Err(e) => {
return WorkerResponse::Error {
kind: "forward_failed".into(),
message: format!("build input tensor: {e}"),
};
}
};
let start = std::time::Instant::now();
tracing::debug!(
rank = self.config.rank,
model = %model_id,
tokens = tokens.len(),
offset,
"worker GenerateStep: forward starting"
);
// Drop the resulting logits — the leader uses its own copy from
// rank 0. The forward's value here is the NCCL collectives it
// issues, which let the leader's rank-0 forward make progress.
if let Err(e) = model.forward(&input, offset) {
tracing::warn!(
rank = self.config.rank,
model = %model_id,
elapsed_ms = start.elapsed().as_millis(),
error = %e,
"worker GenerateStep: forward failed"
);
return WorkerResponse::Error {
kind: "forward_failed".into(),
message: format!("TP forward: {e}"),
};
}
tracing::debug!(
rank = self.config.rank,
model = %model_id,
elapsed_ms = start.elapsed().as_millis(),
"worker GenerateStep: forward done"
);
WorkerResponse::GenerateStepOk
}
#[cfg(not(feature = "cuda"))]
fn handle_generate_step(
&mut self,
_model_id: &str,
_tokens: Vec<u32>,
_offset: usize,
) -> WorkerResponse {
WorkerResponse::Error {
kind: "cuda_feature_not_enabled".into(),
message: "GenerateStep requires --features cuda".into(),
}
}
#[cfg(feature = "cuda")]
fn handle_clear_kv_cache(&mut self, model_id: &str) -> WorkerResponse {
let Some(model) = self.models.get_mut(model_id) else {
return WorkerResponse::Error {
kind: "model_not_loaded".into(),
message: format!("model '{model_id}' not loaded on rank {}", self.config.rank),
};
};
model.clear_kv_cache();
WorkerResponse::KvCacheCleared
}
#[cfg(not(feature = "cuda"))]
fn handle_clear_kv_cache(&mut self, _model_id: &str) -> WorkerResponse {
WorkerResponse::Error {
kind: "cuda_feature_not_enabled".into(),
message: "ClearKvCache requires --features cuda".into(),
}
}
#[cfg(feature = "cuda")]
fn handle_unload_model(&mut self, model_id: &str) -> WorkerResponse {
if self.models.remove(model_id).is_none() {
return WorkerResponse::Error {
kind: "model_not_loaded".into(),
message: format!("model '{model_id}' not loaded on rank {}", self.config.rank),
};
}
tracing::info!(rank = self.config.rank, model = %model_id, "unloaded TP shard");
WorkerResponse::Unloaded
}
#[cfg(not(feature = "cuda"))]
fn handle_unload_model(&mut self, _model_id: &str) -> WorkerResponse {
WorkerResponse::Error {
kind: "cuda_feature_not_enabled".into(),
message: "UnloadModel requires --features cuda".into(),
}
}
}
/// Parse a `ModelSpec.quant` string into a `GgmlDType`. Accepts the
/// common ggml format names (case-insensitive). `None` and `Some("")`
/// both map to "no quantization".
///
/// Supported: `q4_0`, `q4_1`, `q5_0`, `q5_1`, `q8_0`, `q8_1`,
/// `q2k`/`q2_k`, `q3k`/`q3_k`, `q4k`/`q4_k`, `q5k`/`q5_k`,
/// `q6k`/`q6_k`, `q8k`/`q8_k`, `f16`, `bf16`, `f32`. The underscore
/// is optional and the prefix is case-insensitive.
#[cfg(feature = "cuda")]
pub(crate) fn parse_quant_string(
s: Option<&str>,
) -> anyhow::Result<Option<candle_core::quantized::GgmlDType>> {
use candle_core::quantized::GgmlDType;
let s = match s {
Some(s) if !s.is_empty() => s,
_ => return Ok(None),
};
let normalised = s.to_ascii_lowercase().replace('_', "");
let dtype = match normalised.as_str() {
"q40" => GgmlDType::Q4_0,
"q41" => GgmlDType::Q4_1,
"q50" => GgmlDType::Q5_0,
"q51" => GgmlDType::Q5_1,
"q80" => GgmlDType::Q8_0,
"q81" => GgmlDType::Q8_1,
"q2k" => GgmlDType::Q2K,
"q3k" => GgmlDType::Q3K,
"q4k" | "q4km" => GgmlDType::Q4K,
"q5k" | "q5km" => GgmlDType::Q5K,
"q6k" => GgmlDType::Q6K,
"q8k" => GgmlDType::Q8K,
"f16" => GgmlDType::F16,
"bf16" => GgmlDType::BF16,
"f32" => GgmlDType::F32,
other => anyhow::bail!(
"unknown quant '{other}' (expected one of: q4_0, q4_1, q5_0, q5_1, q8_0, \
q8_1, q2k, q3k, q4k, q5k, q6k, q8k, f16, bf16, f32)"
),
};
Ok(Some(dtype))
}

View File

@@ -24,12 +24,6 @@ impl HealthCache {
inner: RwLock::new(HealthResponse {
uptime_secs: 0,
devices: vec![],
// The cache only owns the device-state half of /health;
// the api handler overlays activation from the tracker.
// Initialise with the default (Ready, empty lists) so a
// direct read from the cache stays a well-typed
// HealthResponse on the wire.
activation: Default::default(),
}),
has_gpus: RwLock::new(false),
}

View File

@@ -1,9 +1,5 @@
pub mod activation;
pub mod api;
pub mod config;
pub mod cuda;
pub mod discovery;
pub mod harness;
pub mod health;
pub mod startup;
pub mod wire;

View File

@@ -1,66 +1,21 @@
use anyhow::{Context, Result};
use anyhow::Result;
use clap::Parser;
use neuron::{
activation, api,
config::NeuronConfig,
discovery,
harness::{HarnessRegistry, tp},
health, startup,
};
use neuron::{api, config::NeuronConfig, discovery, harness::HarnessRegistry, health};
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::RwLock;
use tracing_subscriber::EnvFilter;
/// Top-level CLI. The same binary runs as either the public neuron
/// daemon (default), a tensor-parallel worker subprocess (when
/// `--worker` is set, spawned by the leader on the same host), or a
/// one-shot TP NCCL handshake check (when `--tp-smoke` is set).
#[derive(Parser)]
#[command(name = "neuron")]
#[command(about = "Per-node daemon for cortex inference clusters")]
#[command(version)]
struct Args {
/// Run in tensor-parallel worker mode. The leader process spawns
/// one of these per non-zero NCCL rank and drives it over
/// newline-delimited JSON on stdin/stdout. Worker mode skips
/// discovery, the HTTP listener, and the health poller — it's a
/// pure RPC loop.
#[arg(long, default_value_t = false)]
worker: bool,
/// Run a one-shot TP smoke test: spawn `--tp-size - 1` worker
/// subprocesses on `--cuda-devices`, build the NCCL communicator,
/// run an `AllReduce` sanity check across every rank, and exit.
/// Used to validate the TP plumbing in isolation from model load
/// and inference. Diagnostic-only — not exposed through the daemon
/// HTTP API.
#[arg(long, default_value_t = false)]
tp_smoke: bool,
/// NCCL rank for worker mode. Ignored when `--worker` is not set.
#[arg(long, default_value_t = 0)]
rank: u32,
/// Total NCCL world size for worker mode or TP smoke mode.
#[arg(long, default_value_t = 1)]
tp_size: u32,
/// CUDA device index for worker mode. Ignored when `--worker` is
/// not set.
#[arg(long, default_value_t = 0)]
cuda_device: u32,
/// Comma-separated CUDA device indices for TP smoke mode (one per
/// rank, starting with rank 0). Must have `tp_size` entries.
#[arg(long, value_delimiter = ',')]
cuda_devices: Vec<u32>,
/// Port to listen on (overrides config file). Daemon mode only.
/// Port to listen on (overrides config file).
#[arg(short, long)]
port: Option<u16>,
/// Path to the neuron config file. Daemon mode only.
/// Path to the neuron config file.
#[arg(short, long, default_value = "neuron.toml")]
config: String,
}
@@ -68,99 +23,20 @@ struct Args {
#[tokio::main]
async fn main() -> Result<()> {
tracing_subscriber::fmt()
.with_writer(std::io::stderr)
.with_env_filter(
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")),
EnvFilter::try_from_default_env()
.unwrap_or_else(|_| EnvFilter::new("info,neuron=debug")),
)
.init();
let args = Args::parse();
if args.worker {
return tp::worker::run(tp::worker::WorkerConfig {
rank: args.rank,
world_size: args.tp_size,
cuda_device: args.cuda_device,
})
.await;
}
if args.tp_smoke {
return tp_smoke(args.tp_size, args.cuda_devices).await;
}
daemon(args).await
}
/// One-shot tensor-parallel handshake. Spawns N-1 worker subprocesses
/// (rank 0 stays in this process), builds the NCCL communicator across
/// the full world, runs an AllReduce sanity check, and shuts everyone
/// down. Output is plain log lines on stderr + a final summary on
/// stdout in `key=value` form so an outer script can parse it.
async fn tp_smoke(tp_size: u32, cuda_devices: Vec<u32>) -> Result<()> {
if tp_size < 2 {
anyhow::bail!("--tp-size must be at least 2 (got {tp_size})");
}
if cuda_devices.len() as u32 != tp_size {
anyhow::bail!(
"--cuda-devices must list exactly {tp_size} entries (got {})",
cuda_devices.len()
);
}
let exe = std::env::current_exe().context("resolve current_exe for worker spawn")?;
let leader_device = cuda_devices[0];
tracing::info!(
tp_size,
?cuda_devices,
binary = %exe.display(),
"tp-smoke: spawning worker pool"
);
// tp_smoke is a diagnostic tool; spawn the leader's device worker
// directly. (In the daemon path, CandleHarness::ensure_device_worker
// caches one per device.)
let leader_worker = neuron::harness::device_worker::DeviceWorkerHandle::spawn(leader_device)
.context("spawn leader device worker for tp-smoke")?;
let mut pool =
tp::WorkerPool::spawn(&exe, tp_size, &cuda_devices, leader_worker.clone()).await?;
tracing::info!("tp-smoke: pinging every worker");
let pongs = pool.ping_all().await?;
for p in &pongs {
tracing::info!(?p, "tp-smoke: pong");
}
tracing::info!(leader_device, "tp-smoke: initialising NCCL");
pool.init_nccl(leader_device).await?;
tracing::info!("tp-smoke: running AllReduce sanity check");
pool.nccl_sanity_check().await?;
tracing::info!("tp-smoke: shutting down pool");
pool.shutdown().await?;
println!("status=ok");
println!("tp_size={tp_size}");
println!(
"cuda_devices={}",
cuda_devices
.iter()
.map(|d| d.to_string())
.collect::<Vec<_>>()
.join(",")
);
Ok(())
}
async fn daemon(args: Args) -> Result<()> {
let cfg = NeuronConfig::load(&args.config).unwrap_or_else(|e| {
tracing::warn!(path = %args.config, error = %e, "config not found, using defaults");
NeuronConfig::default()
});
let port = args.port.unwrap_or(cfg.port);
let bind_url = format!("http://localhost:{port}");
let start_time = Instant::now();
tracing::info!("running hardware discovery");
@@ -171,12 +47,9 @@ async fn daemon(args: Args) -> Result<()> {
"discovery complete"
);
// Build harness registry from config. In-process harnesses (candle)
// need to know neuron's own bind URL so they can return it from
// inference_endpoint.
let registry = HarnessRegistry::from_configs(&cfg.harnesses, &bind_url, &cfg.harness);
// Build harness registry from config.
let registry = HarnessRegistry::from_configs(&cfg.harnesses);
discovery_result.harnesses = registry.names();
let candle = registry.candle();
let health_cache = Arc::new(health::HealthCache::new());
health_cache
@@ -188,64 +61,17 @@ async fn daemon(args: Args) -> Result<()> {
poller_cache.poll_loop(start_time).await;
});
// Track pre-warm progress so `/health` can tell callers whether
// configured default_models are still loading. Primed with the
// pending list now; the spawned task below flips entries through
// in_progress → completed/failed and finally toggles state=ready.
let activation = Arc::new(activation::ActivationTracker::new(&cfg.default_models));
let state = Arc::new(api::NeuronState {
discovery: discovery_result,
health_cache,
registry: RwLock::new(registry),
candle,
activation: Arc::clone(&activation),
});
// Bind the HTTP listener BEFORE kicking off default_models loading.
// Previously load_default_models ran synchronously on this task,
// which delayed the bind by minutes for big TP models and made the
// host look down to anything probing `/health` during pre-warm.
// The pre-warm task runs in the background instead — `/health`
// surfaces its progress via the activation field.
let app = api::neuron_routes().with_state(Arc::clone(&state));
let app = api::neuron_routes().with_state(state);
let addr: std::net::SocketAddr = format!("0.0.0.0:{port}").parse()?;
let listener = tokio::net::TcpListener::bind(addr).await?;
tracing::info!("neuron listening on {addr}");
let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(listener, app).await?;
if !cfg.default_models.is_empty() {
let state_for_prewarm = Arc::clone(&state);
let default_models = cfg.default_models.clone();
tokio::spawn(async move {
// Read lock held for the whole pre-warm run. The unload
// path takes the same read lock per call (no writers) and
// serialises through the candle harness's own internal
// mutex, so concurrent on-demand loads and pre-warm loads
// do not race on the same model.
let registry = state_for_prewarm.registry.read().await;
startup::load_default_models(&registry, &default_models, &state_for_prewarm.activation)
.await;
});
}
axum::serve(listener, app)
.with_graceful_shutdown(startup::shutdown_signal())
.await?;
// Deactivation: serve has returned (graceful shutdown signal
// received and connections drained). Release CUDA contexts / VRAM
// by unloading every model before exiting; systemd's TimeoutStopSec
// bounds how long this phase may take.
let registry = state.registry.read().await;
startup::unload_all_models(&registry).await;
tracing::info!("shutdown complete");
// Fast-exit instead of returning. Returning lets `#[tokio::main]`
// drop the runtime, which in turn waits on the blocking thread
// pool to drain. After a CUDA driver error (OOM → illegal address)
// a spawn_blocking thread can be wedged inside `cuCtxGetCurrent`,
// and tokio's drain has no timeout. systemd then SIGABRTs us and
// dumps core. Skipping the drain hands the OS a clean exit code;
// the OS reaps the stuck threads. See the 2026-05-26 incident
// captured under "Stack trace of thread 2951308" in the journal.
std::process::exit(0);
Ok(())
}

View File

@@ -1,176 +0,0 @@
//! Activation- and deactivation-time orchestration.
//!
//! Wired from `main.rs` around the HTTP listener — activation runs
//! before bind, deactivation runs after axum returns from its
//! graceful-shutdown future. Kept in its own module so the logic is
//! unit-testable without spinning up a full neuron process.
use crate::activation::ActivationTracker;
use crate::harness::HarnessRegistry;
use crate::harness::preflight::PreflightError;
use cortex_core::harness::ModelSpec;
use std::time::{Duration, Instant};
use tokio::signal;
/// Maximum time we wait on a single `unload_model` call during
/// shutdown. The TP unload path tries `Arc::try_unwrap`, which fails
/// fast when an inference is in flight, so a healthy unload returns
/// in milliseconds. The timeout exists to bound a *future* unload
/// path that might genuinely block on a stuck worker, so a single
/// wedged model can't burn the whole systemd TimeoutStopSec window.
const UNLOAD_TIMEOUT: Duration = Duration::from_secs(20);
/// Load each spec sequentially against the registry, treating
/// individual failures as warnings rather than fatal errors.
///
/// VRAM contention makes parallel loads risky; the sequential path is
/// boring but correct. The function logs elapsed time per load and
/// updates `activation` so the `/health` endpoint can tell callers
/// which models are still pre-warming. Caller is expected to run this
/// in a background `tokio::spawn` task — the HTTP listener binds
/// independently so the host is reachable during the pre-warm window.
pub async fn load_default_models(
registry: &HarnessRegistry,
specs: &[ModelSpec],
activation: &ActivationTracker,
) {
if specs.is_empty() {
activation.mark_ready().await;
return;
}
tracing::info!(count = specs.len(), "loading default models");
for spec in specs {
let start = Instant::now();
activation.start_loading(&spec.model_id).await;
match registry.load_model(spec).await {
Ok(()) => {
activation.complete_loading(&spec.model_id).await;
tracing::info!(
model = %spec.model_id,
elapsed_ms = start.elapsed().as_millis() as u64,
"loaded default model"
);
}
Err(e) => {
let rendered = format!("{e:#}");
activation.fail_loading(&spec.model_id, &rendered).await;
// When the underlying failure is a preflight rejection,
// pull the structured fields out so journalctl shows
// `reason=tp_requires_safetensors detail="..."` instead
// of an opaque "fetch config.json … 404". The operator
// can act on the structured form directly.
if let Some(pf) = e.downcast_ref::<PreflightError>() {
tracing::warn!(
model = %spec.model_id,
reason = preflight_kind(pf),
detail = %pf,
elapsed_ms = start.elapsed().as_millis() as u64,
"failed to load default model, continuing"
);
} else {
tracing::warn!(
model = %spec.model_id,
error = %rendered,
elapsed_ms = start.elapsed().as_millis() as u64,
"failed to load default model, continuing"
);
}
}
}
}
activation.mark_ready().await;
}
/// Short kebab-case tag for a preflight failure. Used as a structured
/// log field so journalctl filtering can match on the failure class
/// (`reason=tp_requires_safetensors`, `reason=quant_not_found`, etc.).
fn preflight_kind(err: &PreflightError) -> &'static str {
match err {
PreflightError::RepoFetchFailed { .. } => "repo_fetch_failed",
PreflightError::EmptyRepo { .. } => "empty_repo",
PreflightError::TpRequiresSafetensors { .. } => "tp_requires_safetensors",
PreflightError::QuantNotFound { .. } => "quant_not_found",
}
}
/// Future that resolves on SIGINT (Ctrl-C) or SIGTERM (systemd stop).
///
/// Wired into `axum::serve(...).with_graceful_shutdown(shutdown_signal())`
/// so the HTTP listener stops accepting new connections, lets in-flight
/// requests drain, and then yields control back to main for cleanup.
pub async fn shutdown_signal() {
let ctrl_c = async {
signal::ctrl_c().await.ok();
};
let terminate = async {
signal::unix::signal(signal::unix::SignalKind::terminate())
.expect("install SIGTERM handler")
.recv()
.await;
};
tokio::select! {
_ = ctrl_c => tracing::info!("received SIGINT, shutting down"),
_ = terminate => tracing::info!("received SIGTERM, shutting down"),
}
}
/// Unload every model currently registered. Called from `main.rs` after
/// axum's graceful shutdown future resolves, so CUDA contexts and VRAM
/// are released before the process exits rather than left to the OS to
/// reclaim. Per-model failures are logged and skipped — keep cleanup
/// going even when one harness is unhealthy.
pub async fn unload_all_models(registry: &HarnessRegistry) {
let listed = match registry.list_all_models().await {
Ok(m) => m,
Err(e) => {
tracing::warn!(error = %e, "failed to list models during shutdown");
return;
}
};
if listed.is_empty() {
return;
}
tracing::info!(count = listed.len(), "unloading models for shutdown");
let mut stuck = 0;
for model in listed {
let start = Instant::now();
match tokio::time::timeout(UNLOAD_TIMEOUT, registry.unload_model(&model.id)).await {
Ok(Ok(())) => tracing::info!(
model = %model.id,
elapsed_ms = start.elapsed().as_millis() as u64,
"unloaded"
),
// Most common shape today: TP unload bails because an
// inference is still mid-flight (the spawned task holds
// an `Arc<TpLoadedModel>` clone). Promoted from warn to
// error and tagged with the request-state so the operator
// can correlate with the chat_completion logs above.
Ok(Err(e)) => {
stuck += 1;
tracing::error!(
model = %model.id,
error = %e,
elapsed_ms = start.elapsed().as_millis() as u64,
"unload failed during shutdown"
);
}
Err(_) => {
stuck += 1;
tracing::error!(
model = %model.id,
timeout_secs = UNLOAD_TIMEOUT.as_secs(),
"unload timed out during shutdown, continuing"
);
}
}
}
if stuck > 0 {
tracing::error!(
stuck,
"shutdown leaving {stuck} model(s) loaded; VRAM will be \
reclaimed by the OS on process exit"
);
}
}

View File

@@ -1,306 +0,0 @@
//! Format-agnostic inference event stream.
//!
//! The candle harness emits a sequence of these for every streaming
//! request. Wire-format projections in sibling modules
//! ([`super::openai_chat`], the eventual `openai_responses` /
//! `anthropic_messages` projections) read this stream and produce
//! the chunks / events their HTTP clients expect.
//!
//! Design notes:
//!
//! - [`Start`] carries no token of its own. It only signals "the
//! model has accepted the prompt and is about to begin emitting
//! text". OpenAI chat materialises this as a `role: assistant`
//! chunk; OpenAI Responses as the `response.created` +
//! `response.output_item.added` pair; Anthropic as
//! `message_start`. All three of those would otherwise have to
//! peek at the *first* token to know when to emit, which couples
//! the wire layer to the producer's pacing.
//! - [`TextDelta`] is *visible* output. Reasoning / `<think>`
//! blocks go through a future [`ReasoningDelta`] variant once
//! the harness learns to split them (today they pass through as
//! plain text inside `TextDelta`; helexa-acp picks them apart on
//! the consumer side).
//! - [`Finish`] is the only place a stream is allowed to end
//! cleanly. Projections rely on this to emit final usage
//! bookkeeping; absence means the producer crashed and the
//! consumer should treat the stream as truncated.
//!
//! [`Start`]: InferenceEvent::Start
//! [`TextDelta`]: InferenceEvent::TextDelta
//! [`Finish`]: InferenceEvent::Finish
/// One unit of output from the inference loop.
///
/// Producers send these on an `mpsc::Sender<InferenceEvent>`;
/// projection layers in sibling modules consume them and emit
/// wire-format-specific frames downstream.
#[derive(Debug, Clone)]
pub enum InferenceEvent {
/// The producer has accepted the prompt and is about to emit
/// the first token. Sent at most once per stream.
Start,
/// A piece of visible assistant text. Multiple deltas
/// concatenate into the complete reply.
TextDelta(String),
/// Reasoning / scratchpad text the model emitted inside a
/// `<think>` block (or equivalent). The harness routes
/// content between marker tokens here so wire projectors can
/// decide what to do with it (chat completions drops by
/// default; Responses API has a dedicated event family).
ReasoningDelta(String),
/// A tool call has been parsed out of a `<tool_call>{json}</tool_call>`
/// block. Carries the parsed name + arguments JSON string
/// (Anthropic / OpenAI projectors emit their own wire shape
/// from this).
///
/// `index` is the call slot — incremented per tool call in a
/// turn so wire formats that order calls by index
/// (OpenAI chat completions) can correlate.
ToolCall {
index: usize,
id: String,
name: String,
/// Complete JSON arguments string. The model could in
/// principle stream these token-by-token, but our
/// extraction buffers the whole block until `</tool_call>`
/// arrives and emits exactly one event per call.
arguments: String,
},
/// The stream is complete. Carries the reason so wire formats
/// that use it (OpenAI's `finish_reason`, Anthropic's
/// `stop_reason`) can render it without re-parsing.
Finish { reason: FinishReason },
}
/// Why a stream stopped. Stays small on purpose — anything that
/// doesn't map cleanly to one of these collapses to [`Stop`].
///
/// Mappings to wire formats:
///
/// | variant | OpenAI `finish_reason` | OpenAI Responses `status` | Anthropic `stop_reason` |
/// |---------|------------------------|---------------------------|-------------------------|
/// | `Stop` | `"stop"` | `"completed"` | `"end_turn"` |
/// | `Length`| `"length"` | `"incomplete"` | `"max_tokens"` |
/// | `ToolCalls` | `"tool_calls"` | `"completed"` | `"tool_use"` |
///
/// [`Stop`]: FinishReason::Stop
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FinishReason {
/// Model emitted EOS naturally.
Stop,
/// Hit `max_tokens` before EOS.
Length,
/// Stopped because the model called a tool and is waiting for
/// the result. Not yet emitted by the candle harness —
/// reserved for the day tool-call extraction lands.
#[allow(dead_code)]
ToolCalls,
}
impl FinishReason {
/// String form used by OpenAI chat completions and OpenAI
/// completions. Wire modules can call this directly or do their
/// own mapping for non-string formats.
pub fn as_openai_str(self) -> &'static str {
match self {
FinishReason::Stop => "stop",
FinishReason::Length => "length",
FinishReason::ToolCalls => "tool_calls",
}
}
}
/// Open/close token IDs for the reasoning marker a loaded model uses
/// (or `None` for non-reasoning models). The harness reads this once
/// at load time from the tokenizer's added-tokens table, then the
/// inference loop checks `next_token` against the pair to flip
/// between [`InferenceEvent::TextDelta`] and
/// [`InferenceEvent::ReasoningDelta`].
///
/// `open` and `close` text are kept alongside the IDs so wire
/// projectors that want to re-emit the literal markers (the
/// opt-in `include_thinking` path on chat completions) don't have
/// to reach back into the tokenizer for the strings.
#[derive(Debug, Clone)]
pub struct ReasoningTokenPair {
pub open_id: u32,
pub close_id: u32,
pub open_text: String,
pub close_text: String,
}
/// Known reasoning-marker conventions. Each is a `(open, close)`
/// pair of literal token strings. Each modern reasoning model
/// declares its markers in the tokenizer's `added_tokens` table;
/// at load time we probe for whichever pair the loaded tokenizer
/// has and stash both IDs.
///
/// Ordering matters only for tie-breaking when a model declares
/// multiple pairs (shouldn't happen in practice); the first hit
/// wins.
const KNOWN_REASONING_MARKERS: &[(&str, &str)] = &[
// Qwen3, DeepSeek-R1, gpt-oss, and most other open-weight
// reasoning models.
("<think>", "</think>"),
// Mistral Magistral.
("[THINK]", "[/THINK]"),
// Some older derivatives; harmless to probe.
("<thought>", "</thought>"),
("<reasoning>", "</reasoning>"),
];
/// Open/close token IDs for the model's tool-call marker
/// convention (or `None` for models that don't emit structured
/// tool calls). Same shape as [`ReasoningTokenPair`]: probed once
/// at load time, consumed by the inference loop to switch between
/// "emit visible deltas" and "buffer JSON for the next tool
/// call".
#[derive(Debug, Clone)]
pub struct ToolCallTokenPair {
pub open_id: u32,
pub close_id: u32,
pub open_text: String,
pub close_text: String,
}
/// Tool-call marker conventions. Open-weight tool-use models
/// converged on `<tool_call>` / `</tool_call>` (Qwen3-Coder /
/// -Instruct, the Hermes function-call format, DeepSeek-Coder,
/// gpt-oss). The pair lives alongside the reasoning markers in
/// the same `added_tokens` table.
const KNOWN_TOOL_CALL_MARKERS: &[(&str, &str)] = &[("<tool_call>", "</tool_call>")];
/// Probe a tokenizer for known tool-call marker pairs. Mirrors
/// [`detect_reasoning_token_pair`] — both open AND close must
/// resolve for the pair to be returned. `None` means the model
/// doesn't emit structured tool calls (or its tokenizer split
/// the markers across tokens).
pub fn detect_tool_call_token_pair<F>(token_to_id: F) -> Option<ToolCallTokenPair>
where
F: Fn(&str) -> Option<u32>,
{
for (open_text, close_text) in KNOWN_TOOL_CALL_MARKERS {
let open_id = token_to_id(open_text);
let close_id = token_to_id(close_text);
if let (Some(open_id), Some(close_id)) = (open_id, close_id) {
return Some(ToolCallTokenPair {
open_id,
close_id,
open_text: (*open_text).into(),
close_text: (*close_text).into(),
});
}
}
None
}
/// Inspect a tokenizer for known reasoning-marker pairs and return
/// the first match. The tokenizer types this trait is defined over
/// just need to expose `token_to_id(&str) -> Option<u32>` so this
/// stays decoupled from the candle crate — the production caller
/// passes a `tokenizers::Tokenizer`, but tests can fake one.
///
/// Returns `None` when no known marker pair is fully declared
/// (both open AND close token ids must resolve). That's the
/// pass-through case — non-reasoning models, or reasoning models
/// whose tokenizer split the markers across multiple tokens (rare
/// in practice; modern reasoning tokenizers list them as
/// `added_tokens`).
pub fn detect_reasoning_token_pair<F>(token_to_id: F) -> Option<ReasoningTokenPair>
where
F: Fn(&str) -> Option<u32>,
{
for (open_text, close_text) in KNOWN_REASONING_MARKERS {
let open_id = token_to_id(open_text);
let close_id = token_to_id(close_text);
if let (Some(open_id), Some(close_id)) = (open_id, close_id) {
return Some(ReasoningTokenPair {
open_id,
close_id,
open_text: (*open_text).into(),
close_text: (*close_text).into(),
});
}
}
None
}
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashMap;
fn lookup<'a>(map: &'a HashMap<&'static str, u32>) -> impl Fn(&str) -> Option<u32> + 'a {
|s| map.get(s).copied()
}
#[test]
fn detects_qwen3_style_think_markers() {
let mut m = HashMap::new();
m.insert("<think>", 151648);
m.insert("</think>", 151649);
let pair = detect_reasoning_token_pair(lookup(&m)).expect("pair detected");
assert_eq!(pair.open_id, 151648);
assert_eq!(pair.close_id, 151649);
assert_eq!(pair.open_text, "<think>");
assert_eq!(pair.close_text, "</think>");
}
#[test]
fn detects_mistral_magistral_markers() {
let mut m = HashMap::new();
m.insert("[THINK]", 100);
m.insert("[/THINK]", 101);
let pair = detect_reasoning_token_pair(lookup(&m)).expect("pair detected");
assert_eq!(pair.open_text, "[THINK]");
}
#[test]
fn returns_none_when_only_open_marker_present() {
// A pathological tokenizer that has `<think>` but not
// `</think>` shouldn't half-detect. Pass-through.
let mut m = HashMap::new();
m.insert("<think>", 1);
assert!(detect_reasoning_token_pair(lookup(&m)).is_none());
}
#[test]
fn returns_none_for_non_reasoning_tokenizer() {
let m: HashMap<&'static str, u32> = HashMap::new();
assert!(detect_reasoning_token_pair(lookup(&m)).is_none());
}
#[test]
fn detects_tool_call_markers() {
let mut m = HashMap::new();
m.insert("<tool_call>", 151657);
m.insert("</tool_call>", 151658);
let pair = detect_tool_call_token_pair(lookup(&m)).expect("pair detected");
assert_eq!(pair.open_id, 151657);
assert_eq!(pair.close_id, 151658);
assert_eq!(pair.open_text, "<tool_call>");
assert_eq!(pair.close_text, "</tool_call>");
}
#[test]
fn returns_none_for_non_tool_use_tokenizer() {
let m: HashMap<&'static str, u32> = HashMap::new();
assert!(detect_tool_call_token_pair(lookup(&m)).is_none());
}
#[test]
fn first_match_wins_when_multiple_pairs_declared() {
// Hypothetical tokenizer with both Qwen-style AND Mistral-style
// markers — the `<think>` pair is earlier in the convention
// table so it wins.
let mut m = HashMap::new();
m.insert("<think>", 1);
m.insert("</think>", 2);
m.insert("[THINK]", 3);
m.insert("[/THINK]", 4);
let pair = detect_reasoning_token_pair(lookup(&m)).unwrap();
assert_eq!(pair.open_id, 1);
assert_eq!(pair.close_id, 2);
}
}

View File

@@ -1,27 +0,0 @@
//! Wire-format projection layer.
//!
//! The candle harness produces a single, format-agnostic stream of
//! [`InferenceEvent`]s. Each wire format (OpenAI chat completions,
//! OpenAI Responses, Anthropic messages, …) lives in its own module
//! under `wire::` and projects that event stream into the chunks /
//! events its HTTP clients expect.
//!
//! The benefit over translating *between* wire shapes (OpenAI chat
//! → Anthropic, etc.) is that we never have to reason about a
//! wire-N → wire-M conversion: every translation is wire-N ↔ the
//! internal event currency, and the projections are independent. A
//! new wire format adds a new file under `wire::`; nothing else
//! needs to know about it.
//!
//! Today: [`openai_chat`]. Stage 2 adds `openai_responses`. Stage 3
//! could add a native Anthropic projection that replaces the
//! gateway-side translation.
pub mod event;
pub mod openai_chat;
pub mod openai_responses;
pub use event::{
FinishReason, InferenceEvent, ReasoningTokenPair, ToolCallTokenPair,
detect_reasoning_token_pair, detect_tool_call_token_pair,
};

View File

@@ -1,558 +0,0 @@
//! OpenAI chat completions projection.
//!
//! Reads [`InferenceEvent`]s from a receiver and produces
//! [`ChatCompletionChunk`]s in the shape `POST /v1/chat/completions`
//! clients expect on its streaming SSE response. The HTTP handler in
//! [`crate::api`] wraps the resulting receiver in axum's
//! `Sse::new(...)` adapter; nothing in this module touches HTTP
//! framing or `data:` lines.
//!
//! Per the OpenAI streaming spec, three chunk shapes appear:
//!
//! 1. **Role chunk** — `delta: { "role": "assistant" }`, no content,
//! sent once at stream start. We emit this on [`InferenceEvent::Start`].
//! 2. **Content chunks** — `delta: { "content": "<text>" }`, one per
//! [`InferenceEvent::TextDelta`].
//! 3. **Final chunk** — empty `delta`, `finish_reason` populated.
//! Emitted on [`InferenceEvent::Finish`].
//!
//! `usage` stays `None` on every chunk; the legacy candle paths
//! never surfaced usage on the streaming endpoint and we keep that
//! behaviour bit-for-bit so existing clients see no diff.
//!
//! Back-pressure: the projection task awaits both `rx.recv()` and
//! `tx.send()`. A slow consumer fills the output channel → the
//! task blocks on send → it stops reading from the input → the
//! producer blocks on its own send. The bounded channels
//! propagate without us writing any logic.
use cortex_core::openai::{ChatCompletionChunk, ChunkChoice};
use serde_json::json;
use tokio::sync::mpsc;
use super::event::{FinishReason, InferenceEvent, ReasoningTokenPair};
/// Output channel buffer size. Mirrors the input side's bound; one
/// event maps to at most one chunk, so equal capacity keeps the
/// two ends in sync without surprising memory growth.
const CHUNK_CHANNEL_CAPACITY: usize = 32;
/// Per-stream config for the chat projector. Used by the
/// production handler to thread per-request choices (currently:
/// whether to surface reasoning content) into the projection
/// without bloating the function signature.
#[derive(Debug, Clone, Default)]
pub struct ChatProjectionConfig {
/// When `true`, reasoning content is re-wrapped with the
/// model's literal open/close markers and emitted as content
/// deltas — preserving the on-the-wire shape that
/// reasoning-aware clients like helexa-acp's `ThinkParser`
/// expect.
///
/// When `false` (the default), [`InferenceEvent::ReasoningDelta`]s
/// are dropped entirely so consumers that don't know about
/// reasoning (Zed's commit-message generator, any vanilla
/// OpenAI client) don't have model-internal scratchpad
/// material leaking into their UI. The chat-completions wire
/// format has no slot for reasoning, so the default chooses
/// the safer-for-naïve-clients behaviour.
pub include_thinking: bool,
/// Open/close marker strings to re-emit when `include_thinking`
/// is set. Sourced from the loaded model's
/// [`ReasoningTokenPair`]; `None` for non-reasoning models or
/// when the caller doesn't have the pair handy (in which case
/// `include_thinking` becomes equivalent to dropping reasoning
/// because there's nothing to wrap).
pub reasoning_markers: Option<ReasoningTokenPair>,
}
/// Project an [`InferenceEvent`] receiver into a
/// [`ChatCompletionChunk`] receiver. Spawns one tokio task that
/// owns the input receiver for the stream's lifetime and exits
/// when either side closes.
///
/// `id`, `created`, and `model_id` are stamped into every emitted
/// chunk so the receiver can stay generic (decoupled from
/// per-request metadata).
pub fn project_chat_stream(
rx: mpsc::Receiver<InferenceEvent>,
id: String,
created: u64,
model_id: String,
) -> mpsc::Receiver<ChatCompletionChunk> {
// Default config: include_thinking off, no marker rewrap.
project_chat_stream_with(rx, id, created, model_id, ChatProjectionConfig::default())
}
/// Same as [`project_chat_stream`] but with a per-stream config
/// (currently controlling reasoning surfacing). Production
/// callers that need the opt-in path call this directly; the
/// shorter wrapper above stays as the no-config convenience.
pub fn project_chat_stream_with(
mut rx: mpsc::Receiver<InferenceEvent>,
id: String,
created: u64,
model_id: String,
config: ChatProjectionConfig,
) -> mpsc::Receiver<ChatCompletionChunk> {
let (tx, out_rx) = mpsc::channel::<ChatCompletionChunk>(CHUNK_CHANNEL_CAPACITY);
tokio::spawn(async move {
// Track whether the previous event was inside a reasoning
// block — used to decide when to emit the literal close
// marker on the include_thinking re-wrap path. When this
// flips from true → false (a TextDelta or Finish lands
// after one or more ReasoningDeltas), we emit the close
// marker exactly once.
let mut was_in_reasoning = false;
while let Some(event) = rx.recv().await {
// Close-marker insertion: if we're leaving a reasoning
// chain, emit the literal close marker before the
// current event.
if was_in_reasoning && !matches!(event, InferenceEvent::ReasoningDelta(_)) {
if let Some(marker) = config
.include_thinking
.then_some(())
.and(config.reasoning_markers.as_ref())
{
let chunk = content_chunk(&id, created, &model_id, &marker.close_text);
if tx.send(chunk).await.is_err() {
return;
}
}
was_in_reasoning = false;
}
let chunks = match event {
InferenceEvent::Start => vec![role_chunk(&id, created, &model_id)],
InferenceEvent::TextDelta(text) => {
if text.is_empty() {
// DecodeStream is buffering a multi-byte
// codepoint; don't bother sending an empty
// chunk downstream.
continue;
}
vec![content_chunk(&id, created, &model_id, &text)]
}
InferenceEvent::ReasoningDelta(text) => {
if !config.include_thinking {
// Default path — reasoning has no slot in
// chat completions, so it's dropped. Naïve
// clients (Zed commit-message generator,
// any vanilla OpenAI client) get clean
// output.
continue;
}
let Some(markers) = config.reasoning_markers.as_ref() else {
// Caller asked to include thinking but
// didn't supply markers — best we can do
// is emit the content as visible text.
// Skip the wrap entirely.
if text.is_empty() {
continue;
}
let chunk = content_chunk(&id, created, &model_id, &text);
if tx.send(chunk).await.is_err() {
return;
}
continue;
};
// First chunk of a reasoning block → open
// marker prelude. Subsequent reasoning deltas
// in the same block reuse `was_in_reasoning`
// to skip the prelude.
let mut chunks = Vec::new();
if !was_in_reasoning {
chunks.push(content_chunk(&id, created, &model_id, &markers.open_text));
}
if !text.is_empty() {
chunks.push(content_chunk(&id, created, &model_id, &text));
}
was_in_reasoning = true;
chunks
}
InferenceEvent::ToolCall {
index,
id: call_id,
name,
arguments,
} => {
// OpenAI streaming shape for tool calls:
// `delta.tool_calls[]` with id + function.name
// on the first chunk per index, then
// function.arguments deltas. We have the
// complete arguments buffered already, so one
// delta carries everything.
vec![tool_call_chunk(
&id, created, &model_id, index, &call_id, &name, &arguments,
)]
}
InferenceEvent::Finish { reason } => {
vec![final_chunk(&id, created, &model_id, reason)]
}
};
for chunk in chunks {
if tx.send(chunk).await.is_err() {
// Consumer hung up; nothing more to do.
return;
}
}
}
});
out_rx
}
fn role_chunk(id: &str, created: u64, model_id: &str) -> ChatCompletionChunk {
ChatCompletionChunk {
id: id.into(),
object: "chat.completion.chunk".into(),
created,
model: model_id.into(),
choices: vec![ChunkChoice {
index: 0,
delta: json!({ "role": "assistant" }),
finish_reason: None,
extra: serde_json::Value::Object(Default::default()),
}],
usage: None,
extra: serde_json::Value::Object(Default::default()),
}
}
fn content_chunk(id: &str, created: u64, model_id: &str, text: &str) -> ChatCompletionChunk {
ChatCompletionChunk {
id: id.into(),
object: "chat.completion.chunk".into(),
created,
model: model_id.into(),
choices: vec![ChunkChoice {
index: 0,
delta: json!({ "content": text }),
finish_reason: None,
extra: serde_json::Value::Object(Default::default()),
}],
usage: None,
extra: serde_json::Value::Object(Default::default()),
}
}
/// OpenAI chat streaming shape for a tool call. One chunk per
/// call slot, carrying id + name + the complete arguments JSON.
/// Mirrors the format real OpenAI emits on the streaming path,
/// minus the per-token arguments-streaming complication (we have
/// the whole buffer already after the model finishes the
/// `<tool_call>...</tool_call>` block).
fn tool_call_chunk(
id: &str,
created: u64,
model_id: &str,
index: usize,
call_id: &str,
name: &str,
arguments: &str,
) -> ChatCompletionChunk {
ChatCompletionChunk {
id: id.into(),
object: "chat.completion.chunk".into(),
created,
model: model_id.into(),
choices: vec![ChunkChoice {
index: 0,
delta: json!({
"tool_calls": [{
"index": index,
"id": call_id,
"type": "function",
"function": {
"name": name,
"arguments": arguments,
}
}],
}),
finish_reason: None,
extra: serde_json::Value::Object(Default::default()),
}],
usage: None,
extra: serde_json::Value::Object(Default::default()),
}
}
fn final_chunk(
id: &str,
created: u64,
model_id: &str,
reason: FinishReason,
) -> ChatCompletionChunk {
ChatCompletionChunk {
id: id.into(),
object: "chat.completion.chunk".into(),
created,
model: model_id.into(),
choices: vec![ChunkChoice {
index: 0,
delta: serde_json::Value::Object(Default::default()),
finish_reason: Some(reason.as_openai_str().to_string()),
extra: serde_json::Value::Object(Default::default()),
}],
usage: None,
extra: serde_json::Value::Object(Default::default()),
}
}
#[cfg(test)]
mod tests {
use super::*;
/// Drain the projection's output into a Vec for assertion.
async fn collect(mut rx: mpsc::Receiver<ChatCompletionChunk>) -> Vec<ChatCompletionChunk> {
let mut out = Vec::new();
while let Some(chunk) = rx.recv().await {
out.push(chunk);
}
out
}
#[tokio::test]
async fn empty_event_stream_yields_no_chunks() {
let (tx, rx) = mpsc::channel::<InferenceEvent>(4);
drop(tx);
let out = collect(project_chat_stream(rx, "id-1".into(), 1700, "m".into())).await;
assert!(out.is_empty());
}
#[tokio::test]
async fn start_text_finish_produces_three_chunks() {
let (tx, rx) = mpsc::channel::<InferenceEvent>(4);
let out_rx = project_chat_stream(rx, "id-1".into(), 1700, "m".into());
tx.send(InferenceEvent::Start).await.unwrap();
tx.send(InferenceEvent::TextDelta("hello".into()))
.await
.unwrap();
tx.send(InferenceEvent::Finish {
reason: FinishReason::Stop,
})
.await
.unwrap();
drop(tx);
let out = collect(out_rx).await;
assert_eq!(out.len(), 3);
assert_eq!(out[0].choices[0].delta["role"], "assistant");
assert_eq!(out[1].choices[0].delta["content"], "hello");
assert_eq!(out[2].choices[0].finish_reason.as_deref(), Some("stop"));
// Every chunk carries the stamped metadata.
for chunk in &out {
assert_eq!(chunk.id, "id-1");
assert_eq!(chunk.created, 1700);
assert_eq!(chunk.model, "m");
assert_eq!(chunk.object, "chat.completion.chunk");
}
}
#[tokio::test]
async fn empty_text_delta_is_dropped() {
let (tx, rx) = mpsc::channel::<InferenceEvent>(4);
let out_rx = project_chat_stream(rx, "id".into(), 1, "m".into());
tx.send(InferenceEvent::TextDelta(String::new()))
.await
.unwrap();
drop(tx);
let out = collect(out_rx).await;
assert!(out.is_empty(), "empty deltas must not produce chunks");
}
#[tokio::test]
async fn finish_length_maps_to_openai_string() {
let (tx, rx) = mpsc::channel::<InferenceEvent>(4);
let out_rx = project_chat_stream(rx, "id".into(), 1, "m".into());
tx.send(InferenceEvent::Finish {
reason: FinishReason::Length,
})
.await
.unwrap();
drop(tx);
let out = collect(out_rx).await;
assert_eq!(out.len(), 1);
assert_eq!(out[0].choices[0].finish_reason.as_deref(), Some("length"));
}
#[tokio::test]
async fn reasoning_delta_is_dropped_in_chat_projection() {
let (tx, rx) = mpsc::channel::<InferenceEvent>(4);
let out_rx = project_chat_stream(rx, "id".into(), 1, "m".into());
tx.send(InferenceEvent::ReasoningDelta("<think>".into()))
.await
.unwrap();
tx.send(InferenceEvent::TextDelta("real".into()))
.await
.unwrap();
drop(tx);
let out = collect(out_rx).await;
assert_eq!(out.len(), 1);
assert_eq!(out[0].choices[0].delta["content"], "real");
}
fn pair() -> ReasoningTokenPair {
ReasoningTokenPair {
open_id: 0,
close_id: 1,
open_text: "<think>".into(),
close_text: "</think>".into(),
}
}
#[tokio::test]
async fn include_thinking_rewraps_reasoning_with_literal_markers() {
let (tx, rx) = mpsc::channel::<InferenceEvent>(8);
let out_rx = project_chat_stream_with(
rx,
"id".into(),
1,
"m".into(),
ChatProjectionConfig {
include_thinking: true,
reasoning_markers: Some(pair()),
},
);
tx.send(InferenceEvent::ReasoningDelta("first ".into()))
.await
.unwrap();
tx.send(InferenceEvent::ReasoningDelta("second".into()))
.await
.unwrap();
tx.send(InferenceEvent::TextDelta("answer".into()))
.await
.unwrap();
tx.send(InferenceEvent::Finish {
reason: FinishReason::Stop,
})
.await
.unwrap();
drop(tx);
let out = collect(out_rx).await;
// Expected sequence: open marker → reasoning content (2 chunks)
// → close marker → visible answer → final chunk.
let contents: Vec<&str> = out
.iter()
.filter_map(|c| c.choices[0].delta["content"].as_str())
.collect();
assert_eq!(
contents,
vec!["<think>", "first ", "second", "</think>", "answer"]
);
assert_eq!(
out.last().unwrap().choices[0].finish_reason.as_deref(),
Some("stop")
);
}
#[tokio::test]
async fn include_thinking_closes_marker_at_finish_when_no_trailing_text() {
// Edge case: stream ends inside a reasoning block (model
// hit max_tokens mid-thought, no visible answer ever).
// The Finish event still triggers the close marker so the
// stream is balanced.
let (tx, rx) = mpsc::channel::<InferenceEvent>(4);
let out_rx = project_chat_stream_with(
rx,
"id".into(),
1,
"m".into(),
ChatProjectionConfig {
include_thinking: true,
reasoning_markers: Some(pair()),
},
);
tx.send(InferenceEvent::ReasoningDelta("thinking...".into()))
.await
.unwrap();
tx.send(InferenceEvent::Finish {
reason: FinishReason::Length,
})
.await
.unwrap();
drop(tx);
let out = collect(out_rx).await;
let contents: Vec<&str> = out
.iter()
.filter_map(|c| c.choices[0].delta["content"].as_str())
.collect();
assert_eq!(contents, vec!["<think>", "thinking...", "</think>"]);
assert_eq!(
out.last().unwrap().choices[0].finish_reason.as_deref(),
Some("length")
);
}
#[tokio::test]
async fn include_thinking_without_markers_emits_content_directly() {
// Defensive: if the caller asks for thinking but the
// model declared no markers, we still emit the content
// rather than dropping it. Better to leak than to lose.
let (tx, rx) = mpsc::channel::<InferenceEvent>(4);
let out_rx = project_chat_stream_with(
rx,
"id".into(),
1,
"m".into(),
ChatProjectionConfig {
include_thinking: true,
reasoning_markers: None,
},
);
tx.send(InferenceEvent::ReasoningDelta("raw".into()))
.await
.unwrap();
tx.send(InferenceEvent::Finish {
reason: FinishReason::Stop,
})
.await
.unwrap();
drop(tx);
let out = collect(out_rx).await;
let contents: Vec<&str> = out
.iter()
.filter_map(|c| c.choices[0].delta["content"].as_str())
.collect();
assert_eq!(contents, vec!["raw"]);
}
#[tokio::test]
async fn include_thinking_off_drops_reasoning_even_with_markers() {
// Default behaviour even when markers happen to be
// configured. The flag is the gate, not the marker
// presence.
let (tx, rx) = mpsc::channel::<InferenceEvent>(4);
let out_rx = project_chat_stream_with(
rx,
"id".into(),
1,
"m".into(),
ChatProjectionConfig {
include_thinking: false,
reasoning_markers: Some(pair()),
},
);
tx.send(InferenceEvent::ReasoningDelta("hidden".into()))
.await
.unwrap();
tx.send(InferenceEvent::TextDelta("visible".into()))
.await
.unwrap();
tx.send(InferenceEvent::Finish {
reason: FinishReason::Stop,
})
.await
.unwrap();
drop(tx);
let out = collect(out_rx).await;
let contents: Vec<&str> = out
.iter()
.filter_map(|c| c.choices[0].delta["content"].as_str())
.collect();
assert_eq!(contents, vec!["visible"]);
}
}

View File

@@ -1,870 +0,0 @@
//! OpenAI Responses API projection.
//!
//! Two responsibilities:
//!
//! 1. **Translate request shape**: [`request_to_chat`] flattens
//! [`ResponsesRequest`]'s typed `input` items + `instructions`
//! into the [`ChatCompletionRequest`] the candle harness already
//! knows how to run. The Responses-specific shape stops at this
//! function — everything downstream is the same chat path the
//! `/v1/chat/completions` route exercises.
//!
//! 2. **Project event stream**: [`project_responses_stream`] reads
//! [`InferenceEvent`]s from the harness and emits the named SSE
//! events the Responses API client expects
//! (`response.created`, `response.output_text.delta`,
//! `response.completed`, …) along with their JSON payloads.
//! The HTTP handler in [`crate::api`] reads
//! `(event_name, data)` tuples off the receiver and stamps them
//! onto axum SSE frames.
//!
//! Scope cuts (carried over from [`cortex_core::responses`]):
//!
//! - `previous_response_id` is rejected by [`request_to_chat`]
//! with [`TranslateError::ChainedConversationNotSupported`].
//! - `Reasoning` input items are dropped (no equivalent in chat).
//! - `FunctionCall` / `FunctionCallOutput` items round-trip but the
//! harness never emits tool calls today; the synthesis paths are
//! in place so the surface is ready when it does.
use cortex_core::openai::{ChatCompletionRequest, ChatMessage, MessageContent};
use cortex_core::responses::{
ResponsesContentPart, ResponsesInput, ResponsesInputItem, ResponsesMessageContent,
ResponsesOutputContent, ResponsesOutputItem, ResponsesRequest, ResponsesResponse,
ResponsesUsage, events,
};
use serde_json::{Value, json};
use tokio::sync::mpsc;
use super::event::{FinishReason, InferenceEvent};
/// Per-request metadata that has to be stamped into every emitted
/// event. The projector spawns a task that owns one of these.
#[derive(Debug, Clone)]
pub struct ResponseMeta {
pub response_id: String,
pub created_at: u64,
pub model_id: String,
/// Item id used inside `output[0]` (the message). All
/// `content_part.*` and `output_text.*` events reference this
/// so the consumer knows which item the delta belongs to.
pub message_item_id: String,
}
/// Reasons [`request_to_chat`] refuses a request.
#[derive(Debug, thiserror::Error)]
pub enum TranslateError {
#[error(
"previous_response_id is not supported on this neuron; chained \
conversations require server-side state we don't store yet"
)]
ChainedConversationNotSupported,
}
/// Flatten a [`ResponsesRequest`] into the chat-completions shape
/// the candle harness already knows how to drive. Keeps the
/// Responses-specific machinery contained to a single function so
/// the harness stays format-agnostic.
///
/// Semantics:
///
/// - `instructions` (if set) becomes a leading `system` message.
/// - `input: "<string>"` becomes a single `user` message.
/// - `input: [items]` flattens each item:
/// - `Message { role, content }` → one `ChatMessage`.
/// - `FunctionCall` → an `assistant` turn whose `extra.tool_calls`
/// carries the call (chat-completions-shaped). The harness
/// doesn't act on tool_calls today, but the shape stays
/// consistent with what chat would expect.
/// - `FunctionCallOutput` → a `tool` role message with the
/// output text. Matches OpenAI's chat convention.
/// - `Reasoning` items are dropped (no equivalent in chat).
/// - Text parts within an array `content` collapse to a single
/// string; image parts get rendered as a chat-style content
/// array `[{type:"text"}, {type:"image_url"}]` so the chat
/// handler's existing vision path applies.
pub fn request_to_chat(req: ResponsesRequest) -> Result<ChatCompletionRequest, TranslateError> {
if req.previous_response_id.is_some() {
return Err(TranslateError::ChainedConversationNotSupported);
}
let mut messages: Vec<ChatMessage> = Vec::new();
if let Some(instructions) = req.instructions
&& !instructions.is_empty()
{
messages.push(ChatMessage {
role: "system".into(),
content: MessageContent::Text(instructions),
extra: Value::Object(Default::default()),
});
}
match req.input {
ResponsesInput::Text(text) => {
messages.push(ChatMessage {
role: "user".into(),
content: MessageContent::Text(text),
extra: Value::Object(Default::default()),
});
}
ResponsesInput::Items(items) => {
for item in items {
if let Some(msg) = input_item_to_chat(item) {
messages.push(msg);
}
}
}
}
Ok(ChatCompletionRequest {
model: req.model,
messages,
temperature: req.temperature,
top_p: req.top_p,
max_tokens: req.max_output_tokens,
stream: Some(req.stream),
extra: Value::Object(Default::default()),
})
}
fn input_item_to_chat(item: ResponsesInputItem) -> Option<ChatMessage> {
match item {
ResponsesInputItem::Message { role, content } => Some(ChatMessage {
role,
content: message_content_to_chat(content),
extra: Value::Object(Default::default()),
}),
ResponsesInputItem::FunctionCall {
call_id,
name,
arguments,
} => {
// Express the call in chat-completions shape via
// `extra.tool_calls`. The harness ignores it today but
// the shape is consistent for the day it doesn't.
let mut extra = serde_json::Map::new();
extra.insert(
"tool_calls".into(),
json!([{
"id": call_id,
"type": "function",
"function": { "name": name, "arguments": arguments },
}]),
);
Some(ChatMessage {
role: "assistant".into(),
content: MessageContent::Text(String::new()),
extra: Value::Object(extra),
})
}
ResponsesInputItem::FunctionCallOutput { call_id, output } => {
let mut extra = serde_json::Map::new();
extra.insert("tool_call_id".into(), Value::String(call_id));
Some(ChatMessage {
role: "tool".into(),
content: MessageContent::Text(output),
extra: Value::Object(extra),
})
}
// Reasoning items don't have a chat-completions equivalent
// we can faithfully forward. Silently drop — the alternative
// is rejecting a well-formed request, which is worse UX.
ResponsesInputItem::Reasoning { .. } => None,
}
}
fn message_content_to_chat(content: ResponsesMessageContent) -> MessageContent {
match content {
ResponsesMessageContent::Text(s) => MessageContent::Text(s),
ResponsesMessageContent::Parts(parts) => {
// Collapse to a string when every part is text; emit
// the chat content-array shape only when an image is
// present (some upstreams treat the array form as a
// vision-only signal and reject it for text-only
// models).
let has_image = parts
.iter()
.any(|p| matches!(p, ResponsesContentPart::InputImage { .. }));
if !has_image {
let joined = parts
.into_iter()
.filter_map(|p| match p {
ResponsesContentPart::InputText { text }
| ResponsesContentPart::OutputText { text, .. } => Some(text),
ResponsesContentPart::InputImage { .. } => None,
})
.collect::<Vec<_>>()
.join("\n\n");
return MessageContent::Text(joined);
}
let mut out: Vec<Value> = Vec::with_capacity(parts.len());
for p in parts {
match p {
ResponsesContentPart::InputText { text }
| ResponsesContentPart::OutputText { text, .. } => {
out.push(json!({ "type": "text", "text": text }));
}
ResponsesContentPart::InputImage { image_url, .. } => {
out.push(json!({
"type": "image_url",
"image_url": { "url": image_url },
}));
}
}
}
MessageContent::Parts(out)
}
}
}
// ── Streaming projection ─────────────────────────────────────────────
/// One frame the projector emits. The HTTP handler maps each into
/// an axum `Sse::Event` with both an `event:` name and a `data:`
/// JSON payload — Responses, unlike chat completions, uses named
/// SSE events.
#[derive(Debug, Clone)]
pub struct ResponseStreamFrame {
pub event_name: &'static str,
pub data: Value,
}
/// Project an [`InferenceEvent`] receiver into a stream of
/// [`ResponseStreamFrame`]s. The emitted sequence per stream is:
///
/// 1. `response.created` — shell with `status: "in_progress"`.
/// 2. `response.output_item.added` — empty message item.
/// 3. `response.content_part.added` — empty `output_text` part.
/// 4. `response.output_text.delta` × N — token-by-token text.
/// 5. `response.output_text.done` — full accumulated text.
/// 6. `response.content_part.done` — full part payload.
/// 7. `response.output_item.done` — full message item.
/// 8. `response.completed` — final response with `status:"completed"`.
///
/// Empty TextDeltas (the harness's incomplete-UTF-8 buffering) are
/// dropped. `ReasoningDelta`s have no representation in the
/// Responses API spec we model yet, so they're dropped too.
pub fn project_responses_stream(
rx: mpsc::Receiver<InferenceEvent>,
meta: ResponseMeta,
) -> mpsc::Receiver<ResponseStreamFrame> {
let (tx, out_rx) = mpsc::channel::<ResponseStreamFrame>(64);
tokio::spawn(async move {
run_projection(rx, meta, tx).await;
});
out_rx
}
async fn run_projection(
mut rx: mpsc::Receiver<InferenceEvent>,
meta: ResponseMeta,
tx: mpsc::Sender<ResponseStreamFrame>,
) {
let mut accumulated = String::new();
let mut finish: Option<FinishReason> = None;
let mut emitted_start = false;
while let Some(event) = rx.recv().await {
match event {
InferenceEvent::Start => {
emitted_start = true;
if !emit_start_frames(&tx, &meta).await {
return;
}
}
InferenceEvent::TextDelta(text) => {
if text.is_empty() {
continue;
}
accumulated.push_str(&text);
let frame = ResponseStreamFrame {
event_name: events::OUTPUT_TEXT_DELTA,
data: json!({
"item_id": meta.message_item_id,
"output_index": 0,
"content_index": 0,
"delta": text,
}),
};
if tx.send(frame).await.is_err() {
return;
}
}
InferenceEvent::ReasoningDelta(_) => {
// No representation in our Responses model yet.
// Stage where it'd land: a `response.reasoning_*`
// event family alongside `response.output_text.*`.
}
InferenceEvent::ToolCall { .. } => {
// Responses-side tool-call routing not wired yet
// (would emit response.function_call_arguments.*
// events). Drop for now; the chat-completions
// projector handles tool calls. Future work
// tracked in #7 alongside the in_progress event.
}
InferenceEvent::Finish { reason } => {
finish = Some(reason);
}
}
}
// Producers can drop without ever sending Start (e.g. early
// poisoned-model error). Synthesize the open frames so the
// consumer at least sees a coherent shell before completed.
if !emitted_start && !emit_start_frames(&tx, &meta).await {
return;
}
let reason = finish.unwrap_or(FinishReason::Stop);
let _ = emit_finish_frames(&tx, &meta, &accumulated, reason).await;
}
async fn emit_start_frames(tx: &mpsc::Sender<ResponseStreamFrame>, meta: &ResponseMeta) -> bool {
let shell = response_shell(meta, "in_progress", &[], None);
let frames = [
ResponseStreamFrame {
event_name: events::CREATED,
data: json!({ "response": shell.clone() }),
},
// `response.in_progress` carries the same shell as
// `response.created` — both report the "in_progress"
// status and both are payload-light bookkeeping events.
// The distinction is meaningful to clients that
// differentiate "request validated" from "model is
// generating" in their UI (loading spinner vs streaming
// spinner). OpenAI's own Responses SSE emits them as a
// pair; matching the wire shape avoids subtle client
// breakage.
ResponseStreamFrame {
event_name: events::IN_PROGRESS,
data: json!({ "response": shell }),
},
ResponseStreamFrame {
event_name: events::OUTPUT_ITEM_ADDED,
data: json!({
"output_index": 0,
"item": empty_message_item(&meta.message_item_id),
}),
},
ResponseStreamFrame {
event_name: events::CONTENT_PART_ADDED,
data: json!({
"item_id": meta.message_item_id,
"output_index": 0,
"content_index": 0,
"part": { "type": "output_text", "text": "", "annotations": [] },
}),
},
];
for frame in frames {
if tx.send(frame).await.is_err() {
return false;
}
}
true
}
async fn emit_finish_frames(
tx: &mpsc::Sender<ResponseStreamFrame>,
meta: &ResponseMeta,
full_text: &str,
reason: FinishReason,
) -> bool {
let status = finish_to_status(reason);
let full_part = json!({
"type": "output_text",
"text": full_text,
"annotations": [],
});
let full_item = json!({
"type": "message",
"id": meta.message_item_id,
"role": "assistant",
"content": [full_part.clone()],
"status": status,
});
let frames = [
ResponseStreamFrame {
event_name: events::OUTPUT_TEXT_DONE,
data: json!({
"item_id": meta.message_item_id,
"output_index": 0,
"content_index": 0,
"text": full_text,
}),
},
ResponseStreamFrame {
event_name: events::CONTENT_PART_DONE,
data: json!({
"item_id": meta.message_item_id,
"output_index": 0,
"content_index": 0,
"part": full_part,
}),
},
ResponseStreamFrame {
event_name: events::OUTPUT_ITEM_DONE,
data: json!({
"output_index": 0,
"item": full_item.clone(),
}),
},
ResponseStreamFrame {
event_name: events::COMPLETED,
data: json!({
"response": response_shell(meta, status, &[full_item], None)
}),
},
];
for frame in frames {
if tx.send(frame).await.is_err() {
return false;
}
}
true
}
fn response_shell(
meta: &ResponseMeta,
status: &str,
output: &[Value],
usage: Option<&ResponsesUsage>,
) -> Value {
let mut obj = serde_json::Map::new();
obj.insert("id".into(), Value::String(meta.response_id.clone()));
obj.insert("object".into(), Value::String("response".into()));
obj.insert("created_at".into(), json!(meta.created_at));
obj.insert("status".into(), Value::String(status.into()));
obj.insert("model".into(), Value::String(meta.model_id.clone()));
obj.insert("output".into(), Value::Array(output.to_vec()));
if let Some(u) = usage {
obj.insert(
"usage".into(),
json!({
"input_tokens": u.input_tokens,
"output_tokens": u.output_tokens,
"total_tokens": u.total_tokens,
}),
);
}
Value::Object(obj)
}
fn empty_message_item(item_id: &str) -> Value {
json!({
"type": "message",
"id": item_id,
"role": "assistant",
"content": [],
"status": "in_progress",
})
}
fn finish_to_status(reason: FinishReason) -> &'static str {
match reason {
FinishReason::Stop | FinishReason::ToolCalls => "completed",
FinishReason::Length => "incomplete",
}
}
// ── Non-streaming helpers ────────────────────────────────────────────
/// Collect a chat-completions response into a non-streaming
/// [`ResponsesResponse`]. Used by the `/v1/responses` handler when
/// the request doesn't set `stream: true`.
pub fn build_response(
meta: &ResponseMeta,
full_text: String,
reason: FinishReason,
usage: Option<ResponsesUsage>,
) -> ResponsesResponse {
let status = finish_to_status(reason).to_string();
ResponsesResponse {
id: meta.response_id.clone(),
object: "response".into(),
created_at: meta.created_at,
status: status.clone(),
model: meta.model_id.clone(),
output: vec![ResponsesOutputItem::Message {
id: meta.message_item_id.clone(),
role: "assistant".into(),
content: vec![ResponsesOutputContent::OutputText {
text: full_text,
annotations: vec![],
}],
status,
}],
usage,
}
}
#[cfg(test)]
mod tests {
use super::*;
use cortex_core::openai::MessageContent;
fn meta() -> ResponseMeta {
ResponseMeta {
response_id: "resp_1".into(),
created_at: 1700,
model_id: "m".into(),
message_item_id: "msg_1".into(),
}
}
// ── request translator ──────────────────────────────────────────
#[test]
fn translates_text_input_to_single_user_message() {
let req = ResponsesRequest {
model: "m".into(),
input: ResponsesInput::Text("hi".into()),
instructions: None,
stream: false,
max_output_tokens: None,
temperature: None,
top_p: None,
previous_response_id: None,
extra: Value::Object(Default::default()),
};
let chat = request_to_chat(req).unwrap();
assert_eq!(chat.messages.len(), 1);
assert_eq!(chat.messages[0].role, "user");
assert!(matches!(
&chat.messages[0].content,
MessageContent::Text(t) if t == "hi"
));
}
#[test]
fn instructions_become_leading_system_message() {
let req = ResponsesRequest {
model: "m".into(),
input: ResponsesInput::Text("hi".into()),
instructions: Some("you are helpful".into()),
stream: false,
max_output_tokens: None,
temperature: None,
top_p: None,
previous_response_id: None,
extra: Value::Object(Default::default()),
};
let chat = request_to_chat(req).unwrap();
assert_eq!(chat.messages.len(), 2);
assert_eq!(chat.messages[0].role, "system");
assert!(matches!(
&chat.messages[0].content,
MessageContent::Text(t) if t == "you are helpful"
));
assert_eq!(chat.messages[1].role, "user");
}
#[test]
fn rejects_previous_response_id() {
let req = ResponsesRequest {
model: "m".into(),
input: ResponsesInput::Text("hi".into()),
instructions: None,
stream: false,
max_output_tokens: None,
temperature: None,
top_p: None,
previous_response_id: Some("resp_prev".into()),
extra: Value::Object(Default::default()),
};
assert!(matches!(
request_to_chat(req),
Err(TranslateError::ChainedConversationNotSupported)
));
}
#[test]
fn translates_input_items_to_chat_messages() {
let req = ResponsesRequest {
model: "m".into(),
input: ResponsesInput::Items(vec![
ResponsesInputItem::Message {
role: "user".into(),
content: ResponsesMessageContent::Text("first".into()),
},
ResponsesInputItem::Message {
role: "assistant".into(),
content: ResponsesMessageContent::Text("reply".into()),
},
ResponsesInputItem::Message {
role: "user".into(),
content: ResponsesMessageContent::Text("second".into()),
},
]),
instructions: None,
stream: false,
max_output_tokens: None,
temperature: None,
top_p: None,
previous_response_id: None,
extra: Value::Object(Default::default()),
};
let chat = request_to_chat(req).unwrap();
assert_eq!(chat.messages.len(), 3);
let roles: Vec<&str> = chat.messages.iter().map(|m| m.role.as_str()).collect();
assert_eq!(roles, vec!["user", "assistant", "user"]);
}
#[test]
fn image_input_translates_to_chat_parts_array() {
let req = ResponsesRequest {
model: "m".into(),
input: ResponsesInput::Items(vec![ResponsesInputItem::Message {
role: "user".into(),
content: ResponsesMessageContent::Parts(vec![
ResponsesContentPart::InputText {
text: "what is this?".into(),
},
ResponsesContentPart::InputImage {
image_url: "data:image/png;base64,AAA=".into(),
detail: None,
},
]),
}]),
instructions: None,
stream: false,
max_output_tokens: None,
temperature: None,
top_p: None,
previous_response_id: None,
extra: Value::Object(Default::default()),
};
let chat = request_to_chat(req).unwrap();
let parts = match &chat.messages[0].content {
MessageContent::Parts(p) => p.clone(),
other => panic!("expected Parts, got {other:?}"),
};
assert_eq!(parts.len(), 2);
assert_eq!(parts[0]["type"], "text");
assert_eq!(parts[1]["type"], "image_url");
assert_eq!(parts[1]["image_url"]["url"], "data:image/png;base64,AAA=");
}
#[test]
fn text_only_parts_collapse_to_string() {
let req = ResponsesRequest {
model: "m".into(),
input: ResponsesInput::Items(vec![ResponsesInputItem::Message {
role: "user".into(),
content: ResponsesMessageContent::Parts(vec![
ResponsesContentPart::InputText {
text: "first".into(),
},
ResponsesContentPart::InputText {
text: "second".into(),
},
]),
}]),
instructions: None,
stream: false,
max_output_tokens: None,
temperature: None,
top_p: None,
previous_response_id: None,
extra: Value::Object(Default::default()),
};
let chat = request_to_chat(req).unwrap();
assert!(matches!(
&chat.messages[0].content,
MessageContent::Text(t) if t == "first\n\nsecond"
));
}
#[test]
fn reasoning_items_are_silently_dropped() {
let req = ResponsesRequest {
model: "m".into(),
input: ResponsesInput::Items(vec![
ResponsesInputItem::Reasoning { content: vec![] },
ResponsesInputItem::Message {
role: "user".into(),
content: ResponsesMessageContent::Text("hi".into()),
},
]),
instructions: None,
stream: false,
max_output_tokens: None,
temperature: None,
top_p: None,
previous_response_id: None,
extra: Value::Object(Default::default()),
};
let chat = request_to_chat(req).unwrap();
assert_eq!(chat.messages.len(), 1);
assert_eq!(chat.messages[0].role, "user");
}
// ── streaming projector ─────────────────────────────────────────
async fn collect(mut rx: mpsc::Receiver<ResponseStreamFrame>) -> Vec<ResponseStreamFrame> {
let mut out = Vec::new();
while let Some(f) = rx.recv().await {
out.push(f);
}
out
}
#[tokio::test]
async fn full_stream_emits_expected_event_sequence() {
let (tx, rx) = mpsc::channel::<InferenceEvent>(8);
let out = project_responses_stream(rx, meta());
tx.send(InferenceEvent::Start).await.unwrap();
tx.send(InferenceEvent::TextDelta("hel".into()))
.await
.unwrap();
tx.send(InferenceEvent::TextDelta("lo".into()))
.await
.unwrap();
tx.send(InferenceEvent::Finish {
reason: FinishReason::Stop,
})
.await
.unwrap();
drop(tx);
let frames = collect(out).await;
let names: Vec<&str> = frames.iter().map(|f| f.event_name).collect();
assert_eq!(
names,
vec![
events::CREATED,
events::IN_PROGRESS,
events::OUTPUT_ITEM_ADDED,
events::CONTENT_PART_ADDED,
events::OUTPUT_TEXT_DELTA,
events::OUTPUT_TEXT_DELTA,
events::OUTPUT_TEXT_DONE,
events::CONTENT_PART_DONE,
events::OUTPUT_ITEM_DONE,
events::COMPLETED,
]
);
// The two deltas should carry the right text. Indices
// shifted by one after IN_PROGRESS inserted between
// CREATED and OUTPUT_ITEM_ADDED.
assert_eq!(frames[4].data["delta"], "hel");
assert_eq!(frames[5].data["delta"], "lo");
// The done event has the full accumulated text.
assert_eq!(frames[6].data["text"], "hello");
// Completed event carries the full message item.
let completed = &frames[9].data["response"];
assert_eq!(completed["status"], "completed");
let output = completed["output"].as_array().unwrap();
assert_eq!(output.len(), 1);
assert_eq!(output[0]["content"][0]["text"], "hello");
}
#[tokio::test]
async fn length_finish_maps_to_incomplete_status() {
let (tx, rx) = mpsc::channel::<InferenceEvent>(8);
let out = project_responses_stream(rx, meta());
tx.send(InferenceEvent::Start).await.unwrap();
tx.send(InferenceEvent::Finish {
reason: FinishReason::Length,
})
.await
.unwrap();
drop(tx);
let frames = collect(out).await;
let completed = frames
.iter()
.find(|f| f.event_name == events::COMPLETED)
.unwrap();
assert_eq!(completed.data["response"]["status"], "incomplete");
}
#[tokio::test]
async fn synthesises_start_frames_when_producer_skips_start() {
// A producer that drops without sending Start (poisoned
// model, immediate disconnect, …) should still produce a
// coherent stream — the projector synthesises the
// mandatory header frames before COMPLETED so the
// consumer never sees an output_text.done without a
// matching content_part.added.
let (tx, rx) = mpsc::channel::<InferenceEvent>(8);
let out = project_responses_stream(rx, meta());
drop(tx);
let frames = collect(out).await;
let names: Vec<&str> = frames.iter().map(|f| f.event_name).collect();
assert!(names.contains(&events::CREATED));
assert!(names.contains(&events::COMPLETED));
assert!(names.contains(&events::OUTPUT_TEXT_DONE));
}
#[tokio::test]
async fn empty_text_deltas_are_dropped() {
let (tx, rx) = mpsc::channel::<InferenceEvent>(8);
let out = project_responses_stream(rx, meta());
tx.send(InferenceEvent::Start).await.unwrap();
tx.send(InferenceEvent::TextDelta(String::new()))
.await
.unwrap();
tx.send(InferenceEvent::TextDelta("real".into()))
.await
.unwrap();
tx.send(InferenceEvent::Finish {
reason: FinishReason::Stop,
})
.await
.unwrap();
drop(tx);
let frames = collect(out).await;
let delta_count = frames
.iter()
.filter(|f| f.event_name == events::OUTPUT_TEXT_DELTA)
.count();
assert_eq!(delta_count, 1, "empty delta must not produce a frame");
}
// ── non-streaming builder ───────────────────────────────────────
#[test]
fn build_response_produces_completed_message_with_usage() {
let r = build_response(
&meta(),
"hello".into(),
FinishReason::Stop,
Some(ResponsesUsage {
input_tokens: 5,
output_tokens: 1,
total_tokens: 6,
}),
);
assert_eq!(r.status, "completed");
match &r.output[0] {
ResponsesOutputItem::Message {
role,
content,
status,
..
} => {
assert_eq!(role, "assistant");
assert_eq!(status, "completed");
match &content[0] {
ResponsesOutputContent::OutputText { text, .. } => {
assert_eq!(text, "hello");
}
}
}
other => panic!("expected Message, got {other:?}"),
}
let u = r.usage.unwrap();
assert_eq!(u.total_tokens, 6);
}
#[test]
fn build_response_length_yields_incomplete_status() {
let r = build_response(&meta(), "trunc".into(), FinishReason::Length, None);
assert_eq!(r.status, "incomplete");
}
}

View File

@@ -1,77 +0,0 @@
//! Activation-time behaviour: load_default_models continues past
//! individual failures so a single broken catalogue entry doesn't
//! prevent the rest of the fleet from starting.
use cortex_core::discovery::ActivationState;
use cortex_core::harness::{HarnessConfig, ModelSpec};
use neuron::activation::ActivationTracker;
use neuron::config::HarnessSettings;
use neuron::harness::HarnessRegistry;
use neuron::startup;
#[tokio::test]
async fn test_load_default_models_skips_unknown_harness() {
let registry = HarnessRegistry::from_configs(
&[HarnessConfig {
name: "candle".into(),
}],
"http://localhost:0",
&HarnessSettings::default(),
);
// Both entries fail synchronously inside the registry — no network
// call escapes (the harness lookup mismatches before hf-hub is
// touched). The function should still return cleanly.
let specs = vec![
ModelSpec {
model_id: "model-a".into(),
harness: "no-such-harness".into(),
quant: None,
tensor_parallel: None,
devices: None,
},
ModelSpec {
model_id: "model-b".into(),
harness: "no-such-harness".into(),
quant: None,
tensor_parallel: None,
devices: None,
},
];
let activation = ActivationTracker::new(&specs);
startup::load_default_models(&registry, &specs, &activation).await;
let listed = registry
.list_all_models()
.await
.expect("list_all_models should succeed");
assert!(
listed.is_empty(),
"no models should be loaded after failed entries"
);
// Both specs should land in `failed`; tracker should flip to ready.
let snapshot = activation.snapshot().await;
assert_eq!(snapshot.state, ActivationState::Ready);
assert!(snapshot.pending.is_empty());
assert!(snapshot.in_progress.is_none());
assert!(snapshot.completed.is_empty());
assert_eq!(snapshot.failed.len(), 2);
let failed_ids: Vec<&str> = snapshot
.failed
.iter()
.map(|f| f.model_id.as_str())
.collect();
assert!(failed_ids.contains(&"model-a"));
assert!(failed_ids.contains(&"model-b"));
}
#[tokio::test]
async fn test_load_default_models_empty_is_noop() {
let registry = HarnessRegistry::new();
let activation = ActivationTracker::new(&[]);
startup::load_default_models(&registry, &[], &activation).await;
let snapshot = activation.snapshot().await;
assert_eq!(snapshot.state, ActivationState::Ready);
}

View File

@@ -1,5 +1,4 @@
use cortex_core::discovery::{DeviceInfo, DiscoveryResponse};
use neuron::activation::ActivationTracker;
use neuron::api::{self, NeuronState};
use neuron::harness::HarnessRegistry;
use neuron::health::HealthCache;
@@ -15,8 +14,6 @@ async fn spawn_neuron(discovery: DiscoveryResponse) -> String {
discovery,
health_cache,
registry: RwLock::new(registry),
candle: None,
activation: Arc::new(ActivationTracker::new(&[])),
});
let app = api::neuron_routes().with_state(state);
@@ -138,31 +135,56 @@ async fn test_models_empty_registry() {
assert!(body.as_array().unwrap().is_empty());
}
/// Verify the candle harness registers, list is empty by default, and a
/// load attempt for an obviously-bogus model id returns a 4xx error
/// without crashing the daemon. Real load/unload exercising actual GGUF
/// download is covered by `tests/candle_lifecycle.rs` (cuda-integration).
/// Spawn a mock mistral.rs backend and a neuron with the mistralrs harness
/// pointing at it, then test the full model lifecycle through neuron's API.
#[tokio::test]
async fn test_candle_harness_registers_and_rejects_bogus_model() {
async fn test_models_via_mistralrs_harness() {
use axum::routing::{get, post};
use axum::{Json, Router};
use cortex_core::harness::HarnessConfig;
use neuron::config::HarnessSettings;
use serde_json::Value;
let registry = HarnessRegistry::from_configs(
&[HarnessConfig {
name: "candle".into(),
}],
"http://localhost:13131",
&HarnessSettings::default(),
);
// Mock mistral.rs backend.
let mock_app = Router::new()
.route(
"/v1/models",
get(|| async {
Json(json!({
"data": [
{"id": "test-model", "status": "loaded"},
{"id": "other-model", "status": "unloaded"}
]
}))
}),
)
.route(
"/v1/models/unload",
post(|Json(_body): Json<Value>| async { Json(json!({"status": "ok"})) }),
)
.route(
"/v1/models/reload",
post(|Json(_body): Json<Value>| async { Json(json!({"status": "ok"})) }),
);
let mock_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let mock_addr = mock_listener.local_addr().unwrap();
tokio::spawn(async move {
axum::serve(mock_listener, mock_app).await.unwrap();
});
let mock_url = format!("http://{mock_addr}");
// Build neuron with mistralrs harness pointing at mock.
let registry = HarnessRegistry::from_configs(&[HarnessConfig {
name: "mistralrs".into(),
endpoint: Some(mock_url.clone()),
systemd_unit: None,
}]);
let candle = registry.candle();
let health_cache = Arc::new(HealthCache::new());
let state = Arc::new(NeuronState {
discovery: fake_discovery(),
health_cache,
registry: RwLock::new(registry),
candle,
activation: Arc::new(ActivationTracker::new(&[])),
});
let app = api::neuron_routes().with_state(state);
@@ -175,6 +197,7 @@ async fn test_candle_harness_registers_and_rejects_bogus_model() {
let client = reqwest::Client::new();
// GET /models — should return models from mock mistralrs.
let resp = client
.get(format!("{neuron_url}/models"))
.send()
@@ -182,308 +205,45 @@ async fn test_candle_harness_registers_and_rejects_bogus_model() {
.unwrap();
assert_eq!(resp.status(), 200);
let models: Vec<serde_json::Value> = resp.json().await.unwrap();
assert!(models.is_empty());
assert_eq!(models.len(), 2);
assert_eq!(models[0]["id"], "test-model");
assert_eq!(models[0]["harness"], "mistralrs");
assert_eq!(models[0]["status"], "loaded");
assert_eq!(models[1]["id"], "other-model");
assert_eq!(models[1]["status"], "unloaded");
// Sending a wrong-harness spec should be rejected synchronously
// without touching the network or the model registry.
// GET /models/test-model/endpoint — should return mock URL.
let resp = client
.get(format!("{neuron_url}/models/test-model/endpoint"))
.send()
.await
.unwrap();
assert_eq!(resp.status(), 200);
let body: serde_json::Value = resp.json().await.unwrap();
assert_eq!(body["url"], mock_url);
// POST /models/unload — should succeed.
let resp = client
.post(format!("{neuron_url}/models/unload"))
.json(&json!({"model_id": "test-model"}))
.send()
.await
.unwrap();
assert_eq!(resp.status(), 200);
let body: serde_json::Value = resp.json().await.unwrap();
assert_eq!(body["status"], "unloaded");
// POST /models/load — should succeed.
let resp = client
.post(format!("{neuron_url}/models/load"))
.json(&json!({"model_id": "definitely/not-real", "harness": "not-candle"}))
.send()
.await
.unwrap();
assert_eq!(resp.status(), 400);
// Registry still empty.
let resp = client
.get(format!("{neuron_url}/models"))
.send()
.await
.unwrap();
let models: Vec<serde_json::Value> = resp.json().await.unwrap();
assert!(models.is_empty());
}
/// `/v1/chat/completions` returns 503 when no candle harness is registered.
#[tokio::test]
async fn test_chat_completions_no_candle_harness() {
let registry = HarnessRegistry::new();
let health_cache = Arc::new(HealthCache::new());
let state = Arc::new(NeuronState {
discovery: fake_discovery(),
health_cache,
registry: RwLock::new(registry),
candle: None,
activation: Arc::new(ActivationTracker::new(&[])),
});
let app = api::neuron_routes().with_state(state);
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});
let url = format!("http://{addr}");
let resp = reqwest::Client::new()
.post(format!("{url}/v1/chat/completions"))
.json(&json!({
"model": "anything",
"messages": [{"role": "user", "content": "hi"}]
"model_id": "test-model",
"harness": "mistralrs"
}))
.send()
.await
.unwrap();
assert_eq!(resp.status(), 503);
}
/// `/v1/chat/completions` returns 404 when the requested model isn't loaded.
#[tokio::test]
async fn test_chat_completions_model_not_loaded() {
use cortex_core::harness::HarnessConfig;
use neuron::config::HarnessSettings;
let registry = HarnessRegistry::from_configs(
&[HarnessConfig {
name: "candle".into(),
}],
"http://localhost:0",
&HarnessSettings::default(),
);
let candle = registry.candle();
let health_cache = Arc::new(HealthCache::new());
let state = Arc::new(NeuronState {
discovery: fake_discovery(),
health_cache,
registry: RwLock::new(registry),
candle,
activation: Arc::new(ActivationTracker::new(&[])),
});
let app = api::neuron_routes().with_state(state);
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});
let url = format!("http://{addr}");
let resp = reqwest::Client::new()
.post(format!("{url}/v1/chat/completions"))
.json(&json!({
"model": "definitely/not-loaded",
"messages": [{"role": "user", "content": "hi"}]
}))
.send()
.await
.unwrap();
assert_eq!(resp.status(), 404);
}
/// `/v1/chat/completions` with `stream: true` returns 404 when the
/// model isn't loaded — same surface as the non-streaming path. The
/// streaming code only kicks in once the model lookup succeeds.
#[tokio::test]
async fn test_chat_completions_streaming_model_not_loaded() {
use cortex_core::harness::HarnessConfig;
use neuron::config::HarnessSettings;
let registry = HarnessRegistry::from_configs(
&[HarnessConfig {
name: "candle".into(),
}],
"http://localhost:0",
&HarnessSettings::default(),
);
let candle = registry.candle();
let health_cache = Arc::new(HealthCache::new());
let state = Arc::new(NeuronState {
discovery: fake_discovery(),
health_cache,
registry: RwLock::new(registry),
candle,
activation: Arc::new(ActivationTracker::new(&[])),
});
let app = api::neuron_routes().with_state(state);
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});
let url = format!("http://{addr}");
let resp = reqwest::Client::new()
.post(format!("{url}/v1/chat/completions"))
.json(&json!({
"model": "definitely/not-loaded",
"messages": [{"role": "user", "content": "hi"}],
"stream": true
}))
.send()
.await
.unwrap();
assert_eq!(resp.status(), 404);
}
// ── /v1/responses ────────────────────────────────────────────────────
/// `/v1/responses` returns 503 when no candle harness is registered —
/// matches the chat-completions error shape so a client can swap
/// endpoints without re-handling 503s.
#[tokio::test]
async fn test_responses_no_candle_harness() {
let registry = HarnessRegistry::new();
let health_cache = Arc::new(HealthCache::new());
let state = Arc::new(NeuronState {
discovery: fake_discovery(),
health_cache,
registry: RwLock::new(registry),
candle: None,
activation: Arc::new(ActivationTracker::new(&[])),
});
let app = api::neuron_routes().with_state(state);
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});
let url = format!("http://{addr}");
let resp = reqwest::Client::new()
.post(format!("{url}/v1/responses"))
.json(&json!({"model": "anything", "input": "hi"}))
.send()
.await
.unwrap();
assert_eq!(resp.status(), 503);
}
/// `previous_response_id` is rejected at translate time with 400 —
/// we don't store responses server-side yet, so chained
/// conversations can't be honoured.
#[tokio::test]
async fn test_responses_rejects_previous_response_id() {
use cortex_core::harness::HarnessConfig;
use neuron::config::HarnessSettings;
let registry = HarnessRegistry::from_configs(
&[HarnessConfig {
name: "candle".into(),
}],
"http://localhost:0",
&HarnessSettings::default(),
);
let candle = registry.candle();
let health_cache = Arc::new(HealthCache::new());
let state = Arc::new(NeuronState {
discovery: fake_discovery(),
health_cache,
registry: RwLock::new(registry),
candle,
activation: Arc::new(ActivationTracker::new(&[])),
});
let app = api::neuron_routes().with_state(state);
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});
let url = format!("http://{addr}");
let resp = reqwest::Client::new()
.post(format!("{url}/v1/responses"))
.json(&json!({
"model": "anything",
"input": "hi",
"previous_response_id": "resp_prev_42"
}))
.send()
.await
.unwrap();
assert_eq!(resp.status(), 400);
assert_eq!(resp.status(), 200);
let body: serde_json::Value = resp.json().await.unwrap();
assert_eq!(body["code"], "chained_conversation_not_supported");
}
/// `/v1/responses` returns 404 when the model isn't loaded — same
/// surface as chat completions.
#[tokio::test]
async fn test_responses_model_not_loaded() {
use cortex_core::harness::HarnessConfig;
use neuron::config::HarnessSettings;
let registry = HarnessRegistry::from_configs(
&[HarnessConfig {
name: "candle".into(),
}],
"http://localhost:0",
&HarnessSettings::default(),
);
let candle = registry.candle();
let health_cache = Arc::new(HealthCache::new());
let state = Arc::new(NeuronState {
discovery: fake_discovery(),
health_cache,
registry: RwLock::new(registry),
candle,
activation: Arc::new(ActivationTracker::new(&[])),
});
let app = api::neuron_routes().with_state(state);
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});
let url = format!("http://{addr}");
let resp = reqwest::Client::new()
.post(format!("{url}/v1/responses"))
.json(&json!({"model": "not-loaded", "input": "hi"}))
.send()
.await
.unwrap();
assert_eq!(resp.status(), 404);
}
/// Same model-not-loaded surface on the streaming path. The
/// stream is opened only after model lookup succeeds, so a
/// missing model fails fast with a non-SSE 404 response.
#[tokio::test]
async fn test_responses_streaming_model_not_loaded() {
use cortex_core::harness::HarnessConfig;
use neuron::config::HarnessSettings;
let registry = HarnessRegistry::from_configs(
&[HarnessConfig {
name: "candle".into(),
}],
"http://localhost:0",
&HarnessSettings::default(),
);
let candle = registry.candle();
let health_cache = Arc::new(HealthCache::new());
let state = Arc::new(NeuronState {
discovery: fake_discovery(),
health_cache,
registry: RwLock::new(registry),
candle,
activation: Arc::new(ActivationTracker::new(&[])),
});
let app = api::neuron_routes().with_state(state);
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});
let url = format!("http://{addr}");
let resp = reqwest::Client::new()
.post(format!("{url}/v1/responses"))
.json(&json!({
"model": "not-loaded",
"input": "hi",
"stream": true
}))
.send()
.await
.unwrap();
assert_eq!(resp.status(), 404);
assert_eq!(body["status"], "loaded");
}

View File

@@ -1,87 +0,0 @@
//! Real model load/unload lifecycle through the candle harness.
//!
//! Gated behind the `cuda-integration` feature because it downloads a
//! real (small) GGUF from HuggingFace and materialises tensors on the
//! configured device. Run on a host with network access and either a
//! CUDA GPU (when built with `--features cuda`) or enough CPU RAM to
//! hold the model.
//!
//! Usage:
//! cargo test -p neuron --features cuda-integration --test candle_lifecycle
//!
//! Optional environment variables:
//! NEURON_TEST_MODEL_ID — HuggingFace repo to load (default: a small
//! public Qwen3 GGUF repo).
//! NEURON_TEST_QUANT — quant substring matched against GGUF
//! filenames (default: "Q4_K_M").
//! HF_HOME — HuggingFace cache directory.
#![cfg(feature = "cuda-integration")]
use cortex_core::harness::{HarnessConfig, ModelSpec};
use neuron::config::HarnessSettings;
use neuron::harness::HarnessRegistry;
use std::path::PathBuf;
#[tokio::test]
async fn test_candle_qwen3_load_unload_lifecycle() {
let _ = tracing_subscriber::fmt()
.with_test_writer()
.with_env_filter("info,neuron=debug")
.try_init();
let model_id = std::env::var("NEURON_TEST_MODEL_ID")
.unwrap_or_else(|_| "Qwen/Qwen3-0.6B-GGUF".to_string());
let quant = std::env::var("NEURON_TEST_QUANT").unwrap_or_else(|_| "Q4_K_M".to_string());
let mut settings = HarnessSettings::default();
if let Ok(home) = std::env::var("HF_HOME") {
settings.candle.hf_cache = Some(PathBuf::from(home));
}
let registry = HarnessRegistry::from_configs(
&[HarnessConfig {
name: "candle".into(),
}],
"http://localhost:13131",
&settings,
);
let spec = ModelSpec {
model_id: model_id.clone(),
harness: "candle".into(),
quant: Some(quant),
tensor_parallel: None,
devices: Some(vec![0]),
};
registry
.load_model(&spec)
.await
.expect("load_model should succeed");
let models = registry.list_all_models().await.expect("list_all_models");
assert_eq!(models.len(), 1, "expected exactly one loaded model");
assert_eq!(models[0].id, model_id);
assert_eq!(models[0].harness, "candle");
assert_eq!(models[0].status, "loaded");
let url = registry.inference_endpoint(&model_id).await;
assert_eq!(url, Some("http://localhost:13131".into()));
// Re-loading the same model should be rejected.
let again = registry.load_model(&spec).await;
assert!(again.is_err(), "second load should error");
registry
.unload_model(&model_id)
.await
.expect("unload_model should succeed");
let models = registry.list_all_models().await.expect("list_all_models");
assert!(models.is_empty(), "registry should be empty after unload");
// Unloading a model that isn't loaded should error.
let err = registry.unload_model(&model_id).await;
assert!(err.is_err(), "unload of missing model should error");
}

View File

@@ -1,269 +0,0 @@
//! End-to-end preflight tests against a mock HF-compatible server.
//!
//! Unit tests in `harness/preflight.rs` exercise the classifier and
//! feasibility table against synthetic file lists. These tests close
//! the loop: spawn an axum server that returns a `RepoInfo`-shaped
//! JSON payload at `/api/models/{org}/{name}`, point `hf_hub::Api` at
//! it, and assert `preflight()` returns the expected outcome.
use axum::Router;
use axum::extract::Path;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Json};
use axum::routing::get;
use cortex_core::harness::ModelSpec;
use neuron::harness::preflight::{PreflightError, SourceFormat, preflight};
use serde_json::{Value, json};
use std::sync::Arc;
use std::sync::Mutex;
/// Per-test mock state: a map from `{org}/{name}` to the JSON body the
/// mock server returns at the corresponding `/api/models/{org}/{name}`
/// endpoint. `None` means "respond 404".
type MockBodies = Arc<Mutex<std::collections::HashMap<String, Option<Value>>>>;
async fn spawn_mock(bodies: MockBodies) -> String {
// hf-hub 0.4 calls /api/models/{org}/{name}/revision/main for
// `repo.info()`. We route both shapes so the test stays robust
// to a future hf-hub upgrade that drops the `/revision/main`
// suffix.
let app = Router::new()
.route("/api/models/{org}/{name}", get(model_info))
.route(
"/api/models/{org}/{name}/revision/{rev}",
get(model_info_rev),
)
.with_state(bodies);
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});
format!("http://{addr}")
}
async fn model_info(
Path((org, name)): Path<(String, String)>,
axum::extract::State(bodies): axum::extract::State<MockBodies>,
) -> impl IntoResponse {
respond(&format!("{org}/{name}"), &bodies)
}
async fn model_info_rev(
Path((org, name, _rev)): Path<(String, String, String)>,
axum::extract::State(bodies): axum::extract::State<MockBodies>,
) -> impl IntoResponse {
respond(&format!("{org}/{name}"), &bodies)
}
fn respond(key: &str, bodies: &MockBodies) -> axum::response::Response {
let entry = bodies.lock().unwrap().get(key).cloned();
match entry {
Some(Some(body)) => Json(body).into_response(),
Some(None) | None => (StatusCode::NOT_FOUND, "not found").into_response(),
}
}
fn build_api(endpoint: &str, cache_dir: &std::path::Path) -> hf_hub::api::tokio::Api {
hf_hub::api::tokio::ApiBuilder::new()
.with_endpoint(endpoint.to_string())
.with_cache_dir(cache_dir.to_path_buf())
.build()
.expect("build hf-hub Api")
}
fn siblings(filenames: &[&str]) -> Value {
json!({
"sha": "0000000000000000000000000000000000000000",
"siblings": filenames.iter().map(|f| json!({ "rfilename": f })).collect::<Vec<_>>(),
})
}
fn spec(model_id: &str, tp: Option<u32>, quant: Option<&str>) -> ModelSpec {
ModelSpec {
model_id: model_id.into(),
harness: "candle".into(),
quant: quant.map(String::from),
tensor_parallel: tp,
devices: None,
}
}
#[tokio::test]
async fn preflight_gguf_tp_rejected_over_http() {
let cache = tempfile::tempdir().expect("tempdir");
let bodies: MockBodies = Arc::new(Mutex::new(Default::default()));
bodies.lock().unwrap().insert(
"HauhauCS/Qwen3.6".to_string(),
Some(siblings(&[
"README.md",
".gitattributes",
"Qwen3.6-Q4_K_P.gguf",
"Qwen3.6-Q6_K_P.gguf",
"Qwen3.6-Q8_K_P.gguf",
])),
);
let endpoint = spawn_mock(bodies).await;
let api = build_api(&endpoint, cache.path());
let s = spec("HauhauCS/Qwen3.6", Some(2), Some("q6k"));
let err = preflight(&api, &s).await.unwrap_err();
match err {
PreflightError::TpRequiresSafetensors {
model_id,
tp_size,
gguf_quants,
..
} => {
assert_eq!(model_id, "HauhauCS/Qwen3.6");
assert_eq!(tp_size, 2);
assert_eq!(gguf_quants.len(), 3);
}
other => panic!("expected TpRequiresSafetensors, got {other:?}"),
}
}
#[tokio::test]
async fn preflight_gguf_quant_suggestion_over_http() {
let cache = tempfile::tempdir().expect("tempdir");
let bodies: MockBodies = Arc::new(Mutex::new(Default::default()));
bodies.lock().unwrap().insert(
"HauhauCS/Qwen3.6".to_string(),
Some(siblings(&[
"Qwen3.6-Q4_K_P.gguf",
"Qwen3.6-Q5_K_P.gguf",
"Qwen3.6-Q6_K_P.gguf",
"Qwen3.6-Q8_K_P.gguf",
])),
);
let endpoint = spawn_mock(bodies).await;
let api = build_api(&endpoint, cache.path());
let s = spec("HauhauCS/Qwen3.6", Some(1), Some("q6k"));
let err = preflight(&api, &s).await.unwrap_err();
match err {
PreflightError::QuantNotFound {
requested,
nearest,
available,
..
} => {
assert_eq!(requested, "q6k");
assert_eq!(nearest.as_deref(), Some("q6_k_p"));
assert_eq!(available.len(), 4);
}
other => panic!("expected QuantNotFound, got {other:?}"),
}
}
#[tokio::test]
async fn preflight_dense_safetensors_tp_ok() {
let cache = tempfile::tempdir().expect("tempdir");
let bodies: MockBodies = Arc::new(Mutex::new(Default::default()));
bodies.lock().unwrap().insert(
"Qwen/Q3-30B".to_string(),
Some(siblings(&[
"config.json",
"tokenizer.json",
"tokenizer_config.json",
"model.safetensors.index.json",
"model-00001-of-00006.safetensors",
"model-00002-of-00006.safetensors",
"model-00003-of-00006.safetensors",
])),
);
let endpoint = spawn_mock(bodies).await;
let api = build_api(&endpoint, cache.path());
let s = spec("Qwen/Q3-30B", Some(2), Some("q5k"));
let plan = preflight(&api, &s).await.expect("dense+tp should succeed");
assert_eq!(plan.tp_size, 2);
assert!(plan.picked_quant_file.is_none());
assert!(matches!(
plan.format,
SourceFormat::DenseSafetensors { sharded: true }
));
}
#[tokio::test]
async fn preflight_gguf_single_gpu_good_quant() {
let cache = tempfile::tempdir().expect("tempdir");
let bodies: MockBodies = Arc::new(Mutex::new(Default::default()));
bodies.lock().unwrap().insert(
"HauhauCS/Qwen3.6".to_string(),
Some(siblings(&["Qwen3.6-Q4_K_P.gguf", "Qwen3.6-Q6_K_P.gguf"])),
);
let endpoint = spawn_mock(bodies).await;
let api = build_api(&endpoint, cache.path());
let s = spec("HauhauCS/Qwen3.6", Some(1), Some("q6_k_p"));
let plan = preflight(&api, &s)
.await
.expect("good quant should succeed");
assert_eq!(plan.tp_size, 1);
assert_eq!(
plan.picked_quant_file.as_deref(),
Some("Qwen3.6-Q6_K_P.gguf")
);
}
#[tokio::test]
async fn preflight_repo_fetch_failed_on_404() {
// Mock server has no entry for this id → 404, exercising the
// RepoFetchFailed path (the same shape today's HauhauCS scenario
// would have produced if we'd added preflight before the cache
// download was attempted).
let cache = tempfile::tempdir().expect("tempdir");
let bodies: MockBodies = Arc::new(Mutex::new(Default::default()));
let endpoint = spawn_mock(bodies).await;
let api = build_api(&endpoint, cache.path());
let s = spec("DoesNot/Exist", Some(1), None);
let err = preflight(&api, &s).await.unwrap_err();
assert!(
matches!(err, PreflightError::RepoFetchFailed { .. }),
"expected RepoFetchFailed, got {err:?}"
);
}
#[tokio::test]
async fn preflight_empty_repo_rejected() {
let cache = tempfile::tempdir().expect("tempdir");
let bodies: MockBodies = Arc::new(Mutex::new(Default::default()));
bodies.lock().unwrap().insert(
"Empty/Repo".to_string(),
Some(siblings(&["README.md", "tokenizer.json"])),
);
let endpoint = spawn_mock(bodies).await;
let api = build_api(&endpoint, cache.path());
let s = spec("Empty/Repo", Some(1), None);
let err = preflight(&api, &s).await.unwrap_err();
assert!(
matches!(err, PreflightError::EmptyRepo { .. }),
"expected EmptyRepo, got {err:?}"
);
}
#[tokio::test]
async fn preflight_mixed_repo_prefers_safetensors() {
let cache = tempfile::tempdir().expect("tempdir");
let bodies: MockBodies = Arc::new(Mutex::new(Default::default()));
bodies.lock().unwrap().insert(
"Mixed/Repo".to_string(),
Some(siblings(&[
"config.json",
"tokenizer.json",
"model.safetensors",
"model-Q4_K_M.gguf",
])),
);
let endpoint = spawn_mock(bodies).await;
let api = build_api(&endpoint, cache.path());
// TP=2 + quant should succeed via the dense path even though a
// GGUF is present — the dense path handles ISQ.
let s = spec("Mixed/Repo", Some(2), Some("q5k"));
let plan = preflight(&api, &s).await.expect("mixed should succeed");
assert!(matches!(plan.format, SourceFormat::Mixed { .. }));
}

View File

@@ -1,32 +0,0 @@
//! Deactivation behaviour: unload_all_models tolerates an empty
//! registry and continues past per-model unload failures.
use cortex_core::harness::HarnessConfig;
use neuron::config::HarnessSettings;
use neuron::harness::HarnessRegistry;
use neuron::startup;
#[tokio::test]
async fn test_unload_all_models_empty_registry_is_noop() {
let registry = HarnessRegistry::new();
startup::unload_all_models(&registry).await;
}
#[tokio::test]
async fn test_unload_all_models_with_no_loaded_models() {
let registry = HarnessRegistry::from_configs(
&[HarnessConfig {
name: "candle".into(),
}],
"http://localhost:0",
&HarnessSettings::default(),
);
startup::unload_all_models(&registry).await;
let listed = registry
.list_all_models()
.await
.expect("list_all_models should still succeed after shutdown cleanup");
assert!(listed.is_empty());
}

View File

@@ -1,148 +0,0 @@
//! Stage 7a-i: confirm the TP worker subprocess lifecycle round-trips.
//!
//! Spawns two worker subprocesses via the leader→worker stdio RPC,
//! pings each, and cleanly shuts them down. No CUDA required —
//! `Init` and `NcclSanityCheck` are stubbed in 7a-i, so this test
//! runs on any host the workspace builds on.
use neuron::harness::device_worker::DeviceWorkerHandle;
use neuron::harness::tp::{WorkerPool, rpc::WorkerResponse};
/// Path to the neuron binary built by cargo for this test process.
/// cargo populates `CARGO_BIN_EXE_neuron` at compile time for sibling-
/// binary tests; production paths in main.rs use `/proc/self/exe`.
const NEURON_BIN: &str = env!("CARGO_BIN_EXE_neuron");
/// Two workers (so we spawn one subprocess: rank 0 is in-process,
/// rank 1 is the child). Verify the spawned worker responds to Ping
/// with its own identity, then shut it down cleanly.
#[tokio::test]
async fn test_spawn_ping_shutdown() {
// cuda_devices: rank 0 → device 0 (leader, unused here),
// rank 1 → device 1 (worker; not actually opened in 7a-i).
let leader_worker = DeviceWorkerHandle::spawn(0).expect("spawn device worker");
let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 2, &[0, 1], leader_worker)
.await
.expect("spawn worker pool");
let pongs = pool.ping_all().await.expect("ping all workers");
assert_eq!(pongs.len(), 1, "expected one Pong (rank 1 only)");
match &pongs[0] {
WorkerResponse::Pong {
rank,
world_size,
cuda_device,
} => {
assert_eq!(*rank, 1);
assert_eq!(*world_size, 2);
assert_eq!(*cuda_device, 1);
}
other => panic!("expected Pong, got {other:?}"),
}
pool.shutdown().await.expect("clean shutdown");
}
/// Three workers — exercise the loop in `ping_all` / `shutdown`.
#[tokio::test]
async fn test_spawn_three_workers() {
let leader_worker = DeviceWorkerHandle::spawn(0).expect("spawn device worker");
let mut pool = WorkerPool::spawn(NEURON_BIN.as_ref(), 3, &[0, 1, 2], leader_worker)
.await
.expect("spawn worker pool");
let pongs = pool.ping_all().await.expect("ping all workers");
assert_eq!(pongs.len(), 2, "expected two Pongs (ranks 1 and 2)");
for (i, resp) in pongs.iter().enumerate() {
match resp {
WorkerResponse::Pong {
rank,
world_size,
cuda_device,
} => {
let expected_rank = (i + 1) as u32;
assert_eq!(*rank, expected_rank);
assert_eq!(*world_size, 3);
assert_eq!(*cuda_device, expected_rank);
}
other => panic!("expected Pong, got {other:?}"),
}
}
pool.shutdown().await.expect("clean shutdown");
}
/// 7a-ii: without the cuda feature, Init must fail with a clear
/// `cuda_feature_not_enabled` marker rather than silently succeeding.
/// This is the local-dev-box test; the real NCCL handshake is exercised
/// by `tp_worker_lifecycle_cuda.rs` (gated on `cuda-integration`).
#[tokio::test]
async fn test_init_returns_cuda_feature_not_enabled_without_cuda() {
use neuron::harness::tp::rpc::WorkerRequest;
use std::process::Stdio;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::Command;
// Spawn a single worker by hand to send Init directly (the pool's
// public API doesn't expose Init yet — that lands in 7a-ii).
let mut child = Command::new(NEURON_BIN)
.arg("--worker")
.arg("--rank")
.arg("1")
.arg("--tp-size")
.arg("2")
.arg("--cuda-device")
.arg("1")
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::null())
.kill_on_drop(true)
.spawn()
.expect("spawn worker");
let mut stdin = child.stdin.take().expect("stdin");
let stdout = child.stdout.take().expect("stdout");
let mut lines = BufReader::new(stdout).lines();
let req = WorkerRequest::Init {
comm_id: "ff".repeat(128),
};
let mut payload = serde_json::to_string(&req).unwrap();
payload.push('\n');
stdin.write_all(payload.as_bytes()).await.unwrap();
stdin.flush().await.unwrap();
let reply = lines
.next_line()
.await
.expect("read line")
.expect("got line");
let resp: WorkerResponse = serde_json::from_str(&reply).expect("parse reply");
match resp {
WorkerResponse::Error { kind, .. } => {
#[cfg(feature = "cuda")]
{
// With cuda enabled the response depends on whether
// CUDA hardware is actually present. Accept either
// the success contract or a real NCCL failure.
let _ = kind;
}
#[cfg(not(feature = "cuda"))]
assert_eq!(kind, "cuda_feature_not_enabled");
}
WorkerResponse::InitOk => {
// Real NCCL succeeded — only possible with cuda feature
// AND a working NCCL stack AND another rank actually
// joining. Don't fail; just acknowledge.
#[cfg(not(feature = "cuda"))]
panic!("InitOk without cuda feature is impossible");
}
other => panic!("expected Error or InitOk, got {other:?}"),
}
// Clean shutdown.
stdin.write_all(b"{\"op\":\"shutdown\"}\n").await.unwrap();
stdin.flush().await.unwrap();
let _ = lines.next_line().await; // Bye
let _ = child.wait().await;
}

Some files were not shown because too many files have changed in this diff Show More