2 Commits

Author SHA1 Message Date
0184ccab28 chore: move default ports out of common-collision ranges
Some checks failed
CI / Build cortex SRPM (push) Has been cancelled
CI / Build neuron SRPM (push) Has been cancelled
CI / Publish cortex to COPR (push) Has been cancelled
CI / Publish neuron to COPR (push) Has been cancelled
CI / Bump version in source (push) Has been cancelled
CI / Format, lint, build, test (push) Has been cancelled
Previous defaults collided with well-trodden infra services and with
the Linux ephemeral port range:

- cortex API     8000 — common dev-server default (Django, minio UI)
- cortex metrics 9100 — Prometheus node_exporter default
- neuron API     9090 — Cockpit default on Fedora, and Prometheus's own default

Move to helexa-themed palindromic ports, all below Linux's
32768-60999 ephemeral range and not registered to any well-known
service:

- cortex API     31313
- cortex metrics 31314
- neuron API     13131

Updated places:
- cortex.example.toml, neuron.example.toml defaults
- default impls in cortex-core and neuron config
- cortex-cli --endpoint default for the status subcommand
- doc comments citing example URLs
- README.md and CLAUDE.md snippets

Consumers already on the old ports need a one-line edit in their
/etc/cortex/cortex.toml or /etc/neuron/neuron.toml to match;
firewall rules and prometheus scrape configs will also need
updating.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-16 17:35:09 +03:00
471b9b7629 ci: drop actions/cache for cargo registry and target
The cache round-trip (download + unpack) was consistently taking
around 6 minutes, noticeably longer than the ~3 minute cold build
it was meant to accelerate. Net-negative on CI time — remove it.

sccache with the S3 backend still provides dep-level caching at a
much lower overhead, so we keep the majority of the cache benefit
without paying the actions/cache tarball cost.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-16 16:47:32 +03:00
166 changed files with 642 additions and 51117 deletions

View File

@@ -1,726 +0,0 @@
name: build-prerelease
# Builds CUDA-flavoured neuron binaries (and a single cortex binary),
# packages each as a Fedora RPM, signs them, and publishes to the
# `unstable` channel at rpm.lair.cafe.
#
# Change-aware: the `prepare` job diffs HEAD against the git sha
# embedded in the most recently *published* unstable RPM (per package)
# and skips builds whose inputs didn't change. Docs-only commits build
# nothing; gateway-only commits skip the 3 CUDA builds (and, via
# deploy.yml's own check-update gate, the neuron restarts + model
# cold-loads). Diffing against the published sha — not the previous
# push — means a failed run can never cause a change to be missed.
#
# Lint (fmt+clippy) and test run here as parallel jobs and gate
# `publish`; ci.yml no longer runs on pushes to main (see its trigger
# comment), so the two workflows stop competing for the same runners.
#
# The published packages are versioned as e.g.
# helexa-neuron-blackwell-0.1.16-0.1.20260518140530.gitabcdef0.fc43.x86_64
# ^^^^^^^^^^^^^^^^^^ ^^^^^^^^
# commit time (s) commit sha
# so they sort BELOW the eventual 0.1.16-1 stable release, and so two
# commits on the same day are still strictly ordered by their commit
# timestamps (rather than by RPM-vercmp's alpha-vs-digit precedence
# on the SHA fragment).
on:
# Auto-build on every push to main so the unstable channel tracks
# head without a manual dispatch step.
push:
branches: [main]
# Manual dispatch still available to build from a non-main ref.
# Dispatched runs skip change detection and build everything.
workflow_dispatch:
inputs:
ref:
description: "Git ref to build (branch / tag / commit). Defaults to the workflow's branch."
required: false
default: ""
# Coalesce same-ref pushes: a newer push cancels the older in-flight
# run — the newest commit is the one we want on the fleet. The publish
# job keeps its own `rpm-publish` group (cancel=false) so an in-flight
# repo update is never interrupted. Runners are ephemeral (one VM per
# job) so concurrent runs no longer race on a shared workspace; the
# old shared `cortex-runner-pool` group with ci.yml is gone.
concurrency:
group: build-prerelease-${{ github.ref }}
cancel-in-progress: true
env:
CARGO_INCREMENTAL: "0"
CARGO_TERM_COLOR: "always"
jobs:
prepare:
name: Resolve version stamps + change detection
runs-on: rust
outputs:
version: ${{ steps.info.outputs.version }}
release: ${{ steps.info.outputs.release }}
short_sha: ${{ steps.info.outputs.short_sha }}
commit_timestamp: ${{ steps.info.outputs.commit_timestamp }}
build_cortex: ${{ steps.changes.outputs.build_cortex }}
build_neuron: ${{ steps.changes.outputs.build_neuron }}
build_bench: ${{ steps.changes.outputs.build_bench }}
check_rust: ${{ steps.changes.outputs.check_rust }}
steps:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
fetch-depth: 0
- id: info
run: |
set -eux
VERSION=$(awk -F\" '/^version[[:space:]]*=/ { print $2; exit }' Cargo.toml)
SHORT_SHA=$(git rev-parse --short=7 HEAD)
# Second-precise commit timestamp gives the release stamp a
# strictly monotonic numeric prefix. The earlier %Y%m%d-only
# form let same-day builds be ordered by RPM's rpmvercmp
# rules over the SHA, which is non-chronological — e.g.
# "git602e8e1" sorts newer than "gitf9f5fa4" purely because
# rpmvercmp ranks digit-prefixed segments above alpha ones.
# The SHA stays only as a debug identifier; sort order is
# decided entirely by the timestamp.
COMMIT_TIMESTAMP=$(git log -1 --format=%cd --date=format:%Y%m%d%H%M%S HEAD)
RELEASE="0.1.${COMMIT_TIMESTAMP}.git${SHORT_SHA}"
echo "version=${VERSION}" >> "$GITHUB_OUTPUT"
echo "release=${RELEASE}" >> "$GITHUB_OUTPUT"
echo "short_sha=${SHORT_SHA}" >> "$GITHUB_OUTPUT"
echo "commit_timestamp=${COMMIT_TIMESTAMP}" >> "$GITHUB_OUTPUT"
- id: changes
run: |
set -ux
# Default: build everything. Detection only ever narrows
# this, and any failure along the way (manifest unreachable,
# unparsable, sha not in history after a force-push) leaves
# the full build in place. Manual dispatches always build
# everything — predictable when building odd refs.
BUILD_CORTEX=true
BUILD_NEURON=true
BUILD_BENCH=true
CHECK_RUST=true
if [ "${GITHUB_EVENT_NAME}" = "push" ]; then
MANIFEST_URL="https://rpm.lair.cafe/fedora/43/x86_64/unstable/packages.json"
if curl -fsS --max-time 20 -o /tmp/packages.json "$MANIFEST_URL"; then
# Latest published sha per package, by buildTime.
#   $1  package name to look up in /tmp/packages.json
# Prints the hex sha captured from the newest matching entry's
# `release` field, or nothing when there is no matching entry, no sha
# in the release string, or the manifest cannot be read/parsed. Empty
# output is what `decide` treats as "no usable base".
base_for() {
python3 - "$1" <<'PY'
import json, re, sys
name = sys.argv[1]
try:
with open("/tmp/packages.json") as f:
pkgs = json.load(f)["packages"]
cands = [p for p in pkgs if p.get("name") == name]
if cands:
# Newest entry wins; entries without a buildTime compare as 0.
latest = max(cands, key=lambda p: p.get("buildTime", 0))
# The release stamp embeds the commit as "git<sha>"; a "git.<sha>"
# spelling is accepted as well.
m = re.search(r"git\.?([0-9a-f]{7,40})", latest.get("release", ""))
if m:
print(m.group(1))
except Exception:
# Deliberate best-effort: any error prints nothing, which keeps the
# default full build in place.
pass
PY
}
# Prints "true" when a build is needed and "false" when it can be skipped.
#   $1 (base)     sha taken from the most recently published RPM; may be empty
#   $2 (pattern)  extended regex matched against the changed paths
# true if no usable base, else true iff the diff since the published
# sha touches the given path pattern. Fails open: an empty or unknown
# base, a base that is not an ancestor of HEAD, or a failing `git diff`
# all print "true", so detection can only ever narrow the full build.
decide() {
  local base="$1" pattern="$2" changed
  if [ -z "$base" ] \
    || ! git cat-file -e "${base}^{commit}" 2>/dev/null \
    || ! git merge-base --is-ancestor "$base" HEAD 2>/dev/null; then
    echo true; return
  fi
  # Capture the file list before matching. Piping `git diff` straight
  # into `grep -q` folds a git failure into "no match", and under
  # `pipefail` grep's early exit (SIGPIPE on git) also reads as
  # "no match" — either way a needed build would be skipped.
  if ! changed=$(git diff --name-only "${base}..HEAD"); then
    echo true; return
  fi
  if grep -qE "$pattern" <<<"$changed"; then
    echo true
  else
    echo false
  fi
}
# cortex-core is shared by both binaries; Cargo.{toml,lock}
# affect both; this workflow file affects both.
NEURON_RE='^crates/neuron/|^crates/cortex-core/|^Cargo\.toml$|^Cargo\.lock$|^rpm/helexa-neuron-prerelease\.spec$|^data/neuron|^neuron\.example\.toml$|^\.gitea/workflows/build-prerelease\.yml$'
CORTEX_RE='^crates/cortex-gateway/|^crates/cortex-cli/|^crates/cortex-core/|^Cargo\.toml$|^Cargo\.lock$|^rpm/cortex-prerelease\.spec$|^data/cortex|^cortex\.example\.toml$|^models\.example\.toml$|^\.gitea/workflows/build-prerelease\.yml$'
BENCH_RE='^crates/helexa-bench/|^crates/cortex-core/|^Cargo\.toml$|^Cargo\.lock$|^rpm/helexa-bench-prerelease\.spec$|^data/helexa-bench|^helexa-bench\.example\.toml$|^\.gitea/workflows/build-prerelease\.yml$'
# Any Rust change (incl. crates not packaged here, e.g.
# helexa-acp) still needs lint+test on main.
RUST_RE='\.rs$|^crates/|Cargo\.toml$|^Cargo\.lock$'
CORTEX_BASE=$(base_for cortex)
NEURON_BASE=$(base_for helexa-neuron-blackwell)
BENCH_BASE=$(base_for helexa-bench)
BUILD_CORTEX=$(decide "$CORTEX_BASE" "$CORTEX_RE")
BUILD_NEURON=$(decide "$NEURON_BASE" "$NEURON_RE")
BUILD_BENCH=$(decide "$BENCH_BASE" "$BENCH_RE")
if [ "$BUILD_CORTEX" = "true" ] || [ "$BUILD_NEURON" = "true" ] || [ "$BUILD_BENCH" = "true" ]; then
CHECK_RUST=true
else
CHECK_RUST=$(decide "$CORTEX_BASE" "$RUST_RE")
fi
fi
fi
echo "build_cortex=${BUILD_CORTEX}" >> "$GITHUB_OUTPUT"
echo "build_neuron=${BUILD_NEURON}" >> "$GITHUB_OUTPUT"
echo "build_bench=${BUILD_BENCH}" >> "$GITHUB_OUTPUT"
echo "check_rust=${CHECK_RUST}" >> "$GITHUB_OUTPUT"
echo "### change detection: build_cortex=${BUILD_CORTEX} build_neuron=${BUILD_NEURON} build_bench=${BUILD_BENCH} check_rust=${CHECK_RUST}"
# fmt + clippy + test moved here from ci.yml for main pushes so the
# two workflows stop queueing against each other (ci.yml's checks
# used to delay build-cortex by ~12 minutes on the shared runner
# pool). They run in parallel with the builds and gate `publish`,
# not the builds themselves — a clippy warning still can't reach the
# fleet, but it also doesn't serialize the pipeline.
lint:
name: Lint (fmt + clippy)
needs: prepare
if: needs.prepare.outputs.check_rust == 'true'
runs-on: rust
env:
RUSTC_WRAPPER: sccache
SCCACHE_BUCKET: sccache
SCCACHE_ENDPOINT: http://caveman.kosherinata.internal:9000
SCCACHE_REGION: auto
SCCACHE_S3_USE_SSL: "false"
AWS_ACCESS_KEY_ID: ${{ secrets.SCCACHE_S3_ACCESS_KEY }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.SCCACHE_S3_SECRET_KEY }}
steps:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
- run: cargo fmt --check --all
# sccache failures come in two modes: transient races (a plain
# retry clears them) and a wedged/dead server, where every
# same-VM retry fails identically (sccache fatal error, ENOENT
# on its own tmp files). Escalate accordingly: first failure →
# restart the server and retry → final attempt uncached. A sick
# cache costs build time, never the run.
- name: Clippy (with sccache escalation)
run: |
for attempt in 1 2 3; do
echo "::group::clippy attempt ${attempt}"
if [ "${attempt}" -eq 3 ]; then
echo "final attempt: building without sccache"
export RUSTC_WRAPPER=""
fi
if cargo clippy --workspace -- -D warnings; then
echo "::endgroup::"
exit 0
fi
echo "::endgroup::"
echo "clippy failed on attempt ${attempt}"
if [ "${attempt}" -eq 1 ]; then
sccache --stop-server || true
sccache --start-server || true
fi
sleep 5
done
echo "clippy failed after 3 attempts"
exit 1
- run: sccache --show-stats || true
test:
name: Test
needs: prepare
if: needs.prepare.outputs.check_rust == 'true'
runs-on: rust
env:
RUSTC_WRAPPER: sccache
SCCACHE_BUCKET: sccache
SCCACHE_ENDPOINT: http://caveman.kosherinata.internal:9000
SCCACHE_REGION: auto
SCCACHE_S3_USE_SSL: "false"
AWS_ACCESS_KEY_ID: ${{ secrets.SCCACHE_S3_ACCESS_KEY }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.SCCACHE_S3_SECRET_KEY }}
steps:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
# See the lint job for the escalation rationale.
- name: Test (with sccache escalation)
run: |
for attempt in 1 2 3; do
echo "::group::test attempt ${attempt}"
if [ "${attempt}" -eq 3 ]; then
echo "final attempt: building without sccache"
export RUSTC_WRAPPER=""
fi
if cargo test --workspace; then
echo "::endgroup::"
exit 0
fi
echo "::endgroup::"
echo "test failed on attempt ${attempt}"
if [ "${attempt}" -eq 1 ]; then
sccache --stop-server || true
sccache --start-server || true
fi
sleep 5
done
echo "test failed after 3 attempts"
exit 1
- run: sccache --show-stats || true
build-cortex:
name: Build cortex binary
needs: prepare
if: needs.prepare.outputs.build_cortex == 'true'
# runner-rust image already provides rust/cargo/clippy/rustfmt via
# dnf — no rustup install step needed.
runs-on: rust
env:
RUSTC_WRAPPER: sccache
SCCACHE_BUCKET: sccache
SCCACHE_ENDPOINT: http://caveman.kosherinata.internal:9000
SCCACHE_REGION: auto
SCCACHE_S3_USE_SSL: "false"
AWS_ACCESS_KEY_ID: ${{ secrets.SCCACHE_S3_ACCESS_KEY }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.SCCACHE_S3_SECRET_KEY }}
steps:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
# Escalation mirrors the lint/test jobs: restart the sccache
# server and retry → final attempt uncached. A sick cache costs
# build time, never the run.
- name: Build cortex (release, with sccache escalation)
run: |
for attempt in 1 2 3; do
echo "::group::build attempt ${attempt}"
if [ "${attempt}" -eq 3 ]; then
echo "final attempt: building without sccache"
export RUSTC_WRAPPER=""
fi
if cargo build --release -p cortex-cli; then
echo "::endgroup::"
sccache --show-stats || true
exit 0
fi
echo "::endgroup::"
echo "build failed on attempt ${attempt}"
if [ "${attempt}" -eq 1 ]; then
sccache --stop-server || true
sccache --start-server || true
fi
sleep 5
done
echo "build failed after 3 attempts"
exit 1
- name: Stage binary
run: |
mkdir --parents artifacts
cp target/release/cortex artifacts/cortex
./artifacts/cortex --version || true
- uses: actions/upload-artifact@v3
with:
name: cortex-fc43
path: artifacts/cortex
retention-days: 1
build-bench:
name: Build helexa-bench binary
needs: prepare
if: needs.prepare.outputs.build_bench == 'true'
# Pure-Rust, non-CUDA binary — same runner as cortex.
runs-on: rust
env:
RUSTC_WRAPPER: sccache
SCCACHE_BUCKET: sccache
SCCACHE_ENDPOINT: http://caveman.kosherinata.internal:9000
SCCACHE_REGION: auto
SCCACHE_S3_USE_SSL: "false"
AWS_ACCESS_KEY_ID: ${{ secrets.SCCACHE_S3_ACCESS_KEY }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.SCCACHE_S3_SECRET_KEY }}
steps:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
- name: Build helexa-bench (release, with sccache escalation)
run: |
# Stamp the SHA helexa-bench records as bench_sha against every
# run (option_env! in sweep.rs reads it at compile time).
export HELEXA_BUILD_SHA="$(git rev-parse HEAD)"
for attempt in 1 2 3; do
echo "::group::build attempt ${attempt}"
if [ "${attempt}" -eq 3 ]; then
echo "final attempt: building without sccache"
export RUSTC_WRAPPER=""
fi
if cargo build --release -p helexa-bench; then
echo "::endgroup::"
sccache --show-stats || true
exit 0
fi
echo "::endgroup::"
echo "build failed on attempt ${attempt}"
if [ "${attempt}" -eq 1 ]; then
sccache --stop-server || true
sccache --start-server || true
fi
sleep 5
done
echo "build failed after 3 attempts"
exit 1
- name: Stage binary
run: |
mkdir --parents artifacts
cp target/release/helexa-bench artifacts/helexa-bench
./artifacts/helexa-bench --version || true
- uses: actions/upload-artifact@v3
with:
name: bench-fc43
path: artifacts/helexa-bench
retention-days: 1
build-neuron:
name: Build neuron-${{ matrix.flavour }}
needs: prepare
if: needs.prepare.outputs.build_neuron == 'true'
strategy:
fail-fast: false
matrix:
include:
- flavour: ampere
compute_cap: "86"
runner: cuda-13.0
cuda_home: /usr/local/cuda-13.0
build_jobs: 8
nvcc_threads: 4
cargo_features: "cuda cudnn"
- flavour: ada
compute_cap: "89"
runner: cuda-13.0
cuda_home: /usr/local/cuda-13.0
build_jobs: 8
nvcc_threads: 4
cargo_features: "cuda cudnn"
- flavour: blackwell
compute_cap: "120"
runner: cuda-13.0
cuda_home: /usr/local/cuda-13.0
build_jobs: 8
nvcc_threads: 4
cargo_features: "cuda cudnn"
runs-on: ${{ matrix.runner }}
env:
SCCACHE_BUCKET: sccache
SCCACHE_ENDPOINT: http://caveman.kosherinata.internal:9000
SCCACHE_REGION: auto
SCCACHE_S3_USE_SSL: "false"
AWS_ACCESS_KEY_ID: ${{ secrets.SCCACHE_S3_ACCESS_KEY }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.SCCACHE_S3_SECRET_KEY }}
steps:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
# Escalation mirrors the lint/test jobs: restart the sccache
# server and retry → final attempt uncached.
#
# The CUDA image may or may not ship sccache — probe inside this
# step (NOT via GITHUB_ENV from a prior step, which this runner
# does not propagate; observed: probe step said "enabled", build
# ran unwrapped, server stats showed 4 compile requests). A
# missing binary degrades to an uncached build rather than
# failing cargo at `sccache rustc -vV`. The cache covers the
# ~600-crate host-side dep tree (the bulk of the 10-14 min
# build); rustc compilations are shared across all three
# flavours, so even one run seeds the next.
- name: Build neuron with CUDA (${{ matrix.flavour }})
run: |
set -ux
if command -v sccache >/dev/null 2>&1; then
export RUSTC_WRAPPER=sccache
sccache --start-server 2>/dev/null || true
echo "sccache enabled"
else
echo "sccache not on PATH — building uncached"
fi
export PATH="${{ matrix.cuda_home }}/bin:${PATH}"
export LD_LIBRARY_PATH="${{ matrix.cuda_home }}/targets/x86_64-linux/lib:${{ matrix.cuda_home }}/lib64:${LD_LIBRARY_PATH:-}"
export LIBRARY_PATH="${{ matrix.cuda_home }}/targets/x86_64-linux/lib:${{ matrix.cuda_home }}/lib64:${LIBRARY_PATH:-}"
# Pin the build SHA neuron reports from GET /version. The git
# fallback in build.rs would also work on a full checkout, but
# injecting the exact checked-out commit is unambiguous under
# shallow/detached states and makes the artifact self-describing.
export HELEXA_BUILD_SHA="$(git rev-parse HEAD)"
for attempt in 1 2 3; do
echo "::group::build attempt ${attempt}"
if [ "${attempt}" -eq 3 ]; then
echo "final attempt: building without sccache"
export RUSTC_WRAPPER=""
fi
if cargo build --release -p neuron --features "${{ matrix.cargo_features }}"; then
echo "::endgroup::"
command -v sccache >/dev/null 2>&1 && sccache --show-stats || true
exit 0
fi
echo "::endgroup::"
echo "build failed on attempt ${attempt}"
if [ "${attempt}" -eq 1 ] && command -v sccache >/dev/null 2>&1; then
sccache --stop-server || true
sccache --start-server || true
fi
sleep 5
done
echo "build failed after 3 attempts"
exit 1
env:
CUDA_COMPUTE_CAP: ${{ matrix.compute_cap }}
CARGO_BUILD_JOBS: ${{ matrix.build_jobs }}
NVCC_THREADS: ${{ matrix.nvcc_threads }}
- name: Stage binary
run: |
mkdir --parents artifacts
cp target/release/neuron artifacts/neuron-${{ matrix.flavour }}
file "artifacts/neuron-${{ matrix.flavour }}"
- uses: actions/upload-artifact@v3
with:
name: neuron-${{ matrix.flavour }}-fc43
path: artifacts/neuron-${{ matrix.flavour }}
retention-days: 1
package-cortex:
name: Package cortex RPM
needs: [prepare, build-cortex]
runs-on: rpm
steps:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
- uses: actions/download-artifact@v3
with:
name: cortex-fc43
path: artifacts/
- name: Build RPM
run: |
set -eux
rm -f ~/.rpmmacros
rpmdev-setuptree
cp artifacts/cortex ~/rpmbuild/SOURCES/
cp data/cortex.service ~/rpmbuild/SOURCES/
cp data/cortex-sysusers.conf ~/rpmbuild/SOURCES/
cp data/cortex-firewalld.xml ~/rpmbuild/SOURCES/
cp cortex.example.toml ~/rpmbuild/SOURCES/
cp models.example.toml ~/rpmbuild/SOURCES/
cp LICENSE ~/rpmbuild/SOURCES/
rpmbuild -bb rpm/cortex-prerelease.spec \
--define "cortex_version ${{ needs.prepare.outputs.version }}" \
--define "cortex_prerelease ${{ needs.prepare.outputs.release }}" \
--undefine dist \
--define "dist .fc43"
- uses: actions/upload-artifact@v3
with:
name: rpm-cortex-fc43
path: ~/rpmbuild/RPMS/x86_64/*.rpm
retention-days: 7
package-bench:
name: Package helexa-bench RPM
needs: [prepare, build-bench]
runs-on: rpm
steps:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
- uses: actions/download-artifact@v3
with:
name: bench-fc43
path: artifacts/
- name: Build RPM
run: |
set -eux
rm -f ~/.rpmmacros
rpmdev-setuptree
cp artifacts/helexa-bench ~/rpmbuild/SOURCES/
cp data/helexa-bench.service ~/rpmbuild/SOURCES/
cp data/helexa-bench-sysusers.conf ~/rpmbuild/SOURCES/
cp helexa-bench.example.toml ~/rpmbuild/SOURCES/
cp LICENSE ~/rpmbuild/SOURCES/
rpmbuild -bb rpm/helexa-bench-prerelease.spec \
--define "bench_version ${{ needs.prepare.outputs.version }}" \
--define "bench_prerelease ${{ needs.prepare.outputs.release }}" \
--undefine dist \
--define "dist .fc43"
- uses: actions/upload-artifact@v3
with:
name: rpm-bench-fc43
path: ~/rpmbuild/RPMS/x86_64/*.rpm
retention-days: 7
package-neuron:
name: Package helexa-neuron-${{ matrix.flavour }} RPM
needs: [prepare, build-neuron]
runs-on: rpm
strategy:
fail-fast: false
matrix:
include:
- flavour: ampere
- flavour: ada
- flavour: blackwell
steps:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
- uses: actions/download-artifact@v3
with:
name: neuron-${{ matrix.flavour }}-fc43
path: artifacts/
- name: Build RPM
run: |
set -eux
rm -f ~/.rpmmacros
rpmdev-setuptree
cp artifacts/neuron-${{ matrix.flavour }} ~/rpmbuild/SOURCES/
cp data/neuron.service ~/rpmbuild/SOURCES/
cp data/neuron-sysusers.conf ~/rpmbuild/SOURCES/
cp data/neuron-firewalld.xml ~/rpmbuild/SOURCES/
cp neuron.example.toml ~/rpmbuild/SOURCES/
cp LICENSE ~/rpmbuild/SOURCES/
rpmbuild -bb rpm/helexa-neuron-prerelease.spec \
--define "neuron_version ${{ needs.prepare.outputs.version }}" \
--define "neuron_flavour ${{ matrix.flavour }}" \
--define "neuron_prerelease ${{ needs.prepare.outputs.release }}" \
--undefine dist \
--define "dist .fc43"
- uses: actions/upload-artifact@v3
with:
name: rpm-neuron-${{ matrix.flavour }}-fc43
path: ~/rpmbuild/RPMS/x86_64/*.rpm
retention-days: 7
publish:
name: Publish to rpm.lair.cafe (unstable)
needs: [lint, test, package-cortex, package-neuron, package-bench]
# Runs when at least one package was built and nothing failed.
# lint/test may be skipped (docs-only refs never get here because
# no packages build), but a real failure in any blocks the
# fleet from receiving the RPMs.
if: >-
${{
!cancelled()
&& (needs.lint.result == 'success' || needs.lint.result == 'skipped')
&& (needs.test.result == 'success' || needs.test.result == 'skipped')
&& (needs.package-cortex.result == 'success' || needs.package-neuron.result == 'success' || needs.package-bench.result == 'success')
&& needs.package-cortex.result != 'failure'
&& needs.package-neuron.result != 'failure'
&& needs.package-bench.result != 'failure'
}}
runs-on: rpm
concurrency:
group: rpm-publish
cancel-in-progress: false
env:
RPM_REPO_HOST: oolon.kosherinata.internal
FEDORA_VERSION: "43"
steps:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
- name: Download all built RPMs
uses: actions/download-artifact@v3
with:
path: rpms/
pattern: rpm-*-fc43
- name: Flatten RPM artifacts
run: |
set -eux
find rpms/ -name '*.rpm' -exec mv --target-directory=rpms/ {} +
find rpms/ -mindepth 1 -type d -empty -delete
ls -la rpms/
- name: Check for sequoia-sq
run: |
if ! command -v sq &> /dev/null; then
echo "ERROR: sequoia-sq is not installed. Install with: sudo dnf install sequoia-sq"
exit 1
fi
- name: Import signing key
env:
# Pass secrets via env so values stay out of the rendered shell
# script (which Gitea includes in step logs). Template
# expansion of ${{ secrets.X }} inside `run:` writes the literal
# value into the script and depends on Gitea's log masker to
# scrub it — fragile for multi-line keys.
RPM_SIGNING_KEY: ${{ secrets.RPM_SIGNING_KEY }}
RPM_SIGNING_KEY_ID: ${{ secrets.RPM_SIGNING_KEY_ID }}
run: |
echo "$RPM_SIGNING_KEY" | gpg --batch --import
fpr=$(gpg --batch --with-colons --list-keys "$RPM_SIGNING_KEY_ID" | awk -F: '/^fpr:/ { print $10; exit }')
echo "${fpr}:6:" | gpg --batch --import-ownertrust
sed "s/@GPG_NAME@/$RPM_SIGNING_KEY_ID/" rpm/rpmmacros > ~/.rpmmacros
- name: Sign RPMs
run: |
set -eux
for rpm in rpms/*.rpm; do
echo "signing ${rpm}..."
rpm --addsign "${rpm}"
done
- name: Set up SSH for rsync
run: |
install --directory --mode 700 ~/.ssh
echo "${RSYNC_SSH_KEY}" | install --mode 600 /dev/stdin ~/.ssh/id_ed25519
env:
RSYNC_SSH_KEY: ${{ secrets.RSYNC_SSH_KEY }}
- name: Test SSH connectivity
run: |
ssh -o StrictHostKeyChecking=accept-new "gitea_ci@${RPM_REPO_HOST}" exit
- name: Ensure unstable repo directory exists
run: |
ssh "gitea_ci@${RPM_REPO_HOST}" \
"mkdir --parents /var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable"
- name: Sync RPMs to unstable repo
run: |
rsync \
--archive \
--verbose \
--chmod D755,F644 \
rpms/*.rpm \
"gitea_ci@${RPM_REPO_HOST}:/var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable/"
- name: Update unstable repo metadata
run: |
ssh "gitea_ci@${RPM_REPO_HOST}" \
"cd /var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable && createrepo_c --update ."
- name: Generate packages.json manifest
run: |
scp script/generate-packages-json.py "gitea_ci@${RPM_REPO_HOST}:/tmp/"
ssh "gitea_ci@${RPM_REPO_HOST}" \
"python3 /tmp/generate-packages-json.py \
--repodata-dir /var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable/repodata \
--output /var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable/packages.json \
--base-url https://rpm.lair.cafe/fedora/${FEDORA_VERSION}/x86_64/unstable"

View File

@@ -1,26 +1,12 @@
name: CI
# Pushes to main are deliberately excluded: build-prerelease.yml runs
# its own lint/test jobs there (gating publish), and running both
# workflows on the same push made them queue against each other on the
# same runner labels — ~12 minutes of added latency per deploy. Feature
# branches, PRs to main, and release tags keep the full gate here.
on:
push:
branches-ignore: [main]
branches: ["**"]
tags: ["v*"]
pull_request:
branches: [main]
# Coalesce same-ref pushes; a newer push supersedes the in-flight run.
# (The old shared `cortex-runner-pool` group with build-prerelease.yml
# is gone — the workflows no longer trigger on the same refs, and
# ephemeral one-VM-per-job runners removed the shared-workspace race
# that group existed to serialize.)
concurrency:
group: ci-${{ github.ref }}
cancel-in-progress: true
env:
CARGO_INCREMENTAL: "0"
RUSTC_WRAPPER: sccache
@@ -30,163 +16,40 @@ env:
SCCACHE_S3_USE_SSL: "false"
AWS_ACCESS_KEY_ID: ${{ secrets.SCCACHE_S3_ACCESS_KEY }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.SCCACHE_S3_SECRET_KEY }}
# fmt, clippy, and test all run in parallel on the same `rust` runner
# and would otherwise share /root/.cache/act/<hash>/hostexecutor/target/,
# racing each other's cargo temp files (.tmpXXXXXX) and failing builds
# mid-compile. Give each job its own target directory so the invocations
# don't collide. sccache still backs the actual rustc cache, so the
# rebuild penalty is small.
CARGO_TARGET_DIR: target-${{ github.job }}
jobs:
fmt:
name: Format
runs-on: rust
check:
name: Format, lint, build, test
runs-on: fedora
steps:
- uses: actions/checkout@v4
- run: cargo fmt --check --all
clippy:
name: Clippy
runs-on: rust
steps:
- uses: actions/checkout@v4
# sccache failures come in two modes: transient races (a plain
# retry clears them) and a wedged/dead server, where every
# same-VM retry fails identically. Escalate: restart the server
# and retry → final attempt uncached. A sick cache costs build
# time, never the run. Keep in sync with build-prerelease.yml.
- name: Clippy (with sccache escalation)
- name: Ensure sccache with S3 support
env:
RUSTC_WRAPPER: ""
run: |
for attempt in 1 2 3; do
echo "::group::clippy attempt ${attempt}"
if [ "${attempt}" -eq 3 ]; then
echo "final attempt: building without sccache"
export RUSTC_WRAPPER=""
fi
if cargo clippy --workspace -- -D warnings; then
echo "::endgroup::"
exit 0
fi
echo "::endgroup::"
echo "clippy failed on attempt ${attempt}"
if [ "${attempt}" -eq 1 ]; then
sccache --stop-server || true
sccache --start-server || true
fi
sleep 5
done
echo "clippy failed after 3 attempts"
exit 1
- run: sccache --show-stats || true
test:
name: Test
runs-on: rust
steps:
- uses: actions/checkout@v4
# See the clippy job for the escalation rationale.
- name: Test (with sccache escalation)
run: |
for attempt in 1 2 3; do
echo "::group::test attempt ${attempt}"
if [ "${attempt}" -eq 3 ]; then
echo "final attempt: building without sccache"
export RUSTC_WRAPPER=""
fi
if cargo test --workspace; then
echo "::endgroup::"
exit 0
fi
echo "::endgroup::"
echo "test failed on attempt ${attempt}"
if [ "${attempt}" -eq 1 ]; then
sccache --stop-server || true
sccache --start-server || true
fi
sleep 5
done
echo "test failed after 3 attempts"
exit 1
- run: sccache --show-stats || true
# Type-check the CUDA-only code path. Borrow-check-only — we
# never run the tests here (the runner has no GPU). This catches
# the category of bug where a refactor compiles fine under the
# default feature set (which is what the `clippy` and `test` jobs
# exercise) but fails inside a `#[cfg(feature = "cuda")]` block.
# `runs-on: cuda-13.0` selects the runner that ships nvcc /
# cudarc's build prerequisites. The generic `rust` and `rpm`
# runners don't have them (the previous label `rpm` was tried
# first and tripped cudarc's `nvcc --version` build script —
# see commit history).
cuda-check:
name: CUDA type-check
runs-on: cuda-13.0
# The workflow-level env sets `RUSTC_WRAPPER: sccache`
# unconditionally, which hard-fails cargo if the CUDA image
# doesn't ship sccache. Clear it at job level; the "Enable
# sccache when available" step opts back in only after probing
# for the binary. SCCACHE_*/AWS creds stay set — harmless when
# the wrapper is off, required when it's on.
env:
RUSTC_WRAPPER: ""
# candle-kernels' build script falls back to `nvidia-smi` for
# compute-cap detection when this is unset — and the GPU-less
# builder image doesn't ship nvidia-smi. Any valid cap works for
# a borrow-check; the real per-flavour caps live in
# build-prerelease.yml's matrix.
CUDA_COMPUTE_CAP: "86"
steps:
- uses: actions/checkout@v4
# sccache is probed inside this step (NOT via GITHUB_ENV from a
# prior step — this runner doesn't propagate it; see
# build-prerelease.yml for the observed failure).
- name: cargo check --features cuda (with sccache escalation)
run: |
if command -v sccache >/dev/null 2>&1; then
export RUSTC_WRAPPER=sccache
sccache --start-server 2>/dev/null || true
echo "sccache enabled"
if sccache --version 2>/dev/null && sccache --show-stats 2>/dev/null; then
echo "sccache with S3 support already installed"
else
echo "sccache not on PATH — building uncached"
cargo install sccache --features s3 --locked
fi
# act launches the step shell without /etc/profile, so the
# gitea_runner user's inherited PATH lacks /usr/local/cuda-13.0/bin.
# cudarc's build.rs:157 shells out to `nvcc --version` (because
# the neuron crate enables cuda-version-from-build-system) and
# panics with ENOENT if nvcc isn't resolvable. build-prerelease.yml
# does the same export — keep them in sync.
export PATH="/usr/local/cuda-13.0/bin:${PATH}"
export LD_LIBRARY_PATH="/usr/local/cuda-13.0/targets/x86_64-linux/lib:/usr/local/cuda-13.0/lib64:${LD_LIBRARY_PATH:-}"
export LIBRARY_PATH="/usr/local/cuda-13.0/targets/x86_64-linux/lib:/usr/local/cuda-13.0/lib64:${LIBRARY_PATH:-}"
# Escalation mirrors the lint/test jobs: sccache server
# restart and retry → final attempt uncached.
for attempt in 1 2 3; do
echo "::group::cuda-check attempt ${attempt}"
if [ "${attempt}" -eq 3 ]; then
echo "final attempt: building without sccache"
export RUSTC_WRAPPER=""
fi
if cargo check -p neuron --features cuda --all-targets; then
echo "::endgroup::"
exit 0
fi
echo "::endgroup::"
echo "cuda-check failed on attempt ${attempt}"
if [ "${attempt}" -eq 1 ] && command -v sccache >/dev/null 2>&1; then
sccache --stop-server || true
sccache --start-server || true
fi
sleep 5
done
echo "cuda-check failed after 3 attempts"
exit 1
- name: Check formatting
run: cargo fmt --check --all
- name: Clippy
run: cargo clippy --workspace -- -D warnings
- name: Test
run: cargo test --workspace
- name: Show sccache stats
run: sccache --show-stats
srpm-cortex:
name: Build cortex SRPM
runs-on: rpm
needs: [fmt, clippy, test, cuda-check]
runs-on: fedora
needs: check
if: startsWith(github.ref, 'refs/tags/v')
steps:
- uses: actions/checkout@v4
@@ -245,8 +108,8 @@ jobs:
srpm-neuron:
name: Build neuron SRPM
runs-on: rpm
needs: [fmt, clippy, test, cuda-check]
runs-on: fedora
needs: check
if: startsWith(github.ref, 'refs/tags/v')
steps:
- uses: actions/checkout@v4
@@ -305,7 +168,7 @@ jobs:
copr-cortex:
name: Publish cortex to COPR
runs-on: fedora-43
runs-on: fedora
needs: srpm-cortex
steps:
- name: Download SRPM
@@ -322,7 +185,7 @@ jobs:
copr-neuron:
name: Publish neuron to COPR
runs-on: fedora-43
runs-on: fedora
needs: srpm-neuron
steps:
- name: Download SRPM
@@ -339,7 +202,7 @@ jobs:
bump-version:
name: Bump version in source
runs-on: rust
runs-on: fedora
needs: [copr-cortex, copr-neuron]
steps:
- uses: actions/checkout@v4
@@ -382,6 +245,6 @@ jobs:
echo "Nothing to commit for ${VERSION}"
else
git commit -m "chore: bump version to ${VERSION}"
git remote set-url origin "https://gitea-actions:${GITEA_TOKEN}@git.lair.cafe/${{ github.repository }}.git"
git remote set-url origin "https://gitea-actions:${GITEA_TOKEN}@git.lair.cafe/helexa/cortex.git"
git push origin HEAD:main
fi

View File

@@ -1,134 +0,0 @@
name: deploy-dev
# Fast-path iteration deploy for a SINGLE neuron host: build one CUDA
# flavour, copy the raw binary to the host, restart neuron.service.
# Skips the other two flavours, all RPM packaging, signing, repo
# publish, and dnf — push-to-testable drops from ~20 min to roughly
# one CUDA build plus a service restart.
#
# This is a DEV convenience, not a release path:
# - the binary lands at /usr/bin/neuron *outside* RPM ownership;
# the next regular deploy.yml run reconciles the host back to the
# packaged binary (dnf sees the newer RPM and reinstalls). `rpm -V
# helexa-neuron-<flavour>` flagging a modified /usr/bin/neuron in
# the interim is expected.
# - nothing is published; other hosts are untouched.
# - requires the `install` sudoers rule from
# asset/sudoers.d/neuron-host.conf (re-run script/infra-setup.sh
# after updating it).
#
# Trigger from the Gitea UI: Actions → deploy-dev → Run workflow,
# pick the target host. Defaults to the ref you dispatch from, so it
# works from feature branches without touching main.
on:
workflow_dispatch:
inputs:
target:
description: "neuron host to deploy to"
required: true
type: choice
options: [beast, benjy, quadbrat]
default: beast
# One dev deploy at a time per host; a newer dispatch for the same host cancels the in-flight run.
concurrency:
group: deploy-dev-${{ inputs.target }}
cancel-in-progress: true
env:
CARGO_INCREMENTAL: "0"
CARGO_TERM_COLOR: "always"
jobs:
build:
name: Build neuron (${{ inputs.target }})
runs-on: cuda-13.0
outputs:
flavour: ${{ steps.map.outputs.flavour }}
steps:
- uses: actions/checkout@v4
# host → flavour → compute cap. Keep in sync with the
# build-neuron matrix in build-prerelease.yml and the
# deploy-neurons matrix in deploy.yml.
#
# The dispatch input is handed to the script through the environment
# instead of being expanded into the script text, so an unexpected
# value can only ever fall through to the catch-all arm — it is never
# parsed as shell.
- id: map
  env:
    TARGET: ${{ inputs.target }}
  run: |
    case "${TARGET}" in
      beast) flavour=blackwell cap=120 ;;
      benjy) flavour=ada cap=89 ;;
      quadbrat) flavour=ampere cap=86 ;;
      *) echo "unknown target ${TARGET}"; exit 1 ;;
    esac
    # Published as step outputs: `flavour` feeds the job output,
    # `cap` feeds CUDA_COMPUTE_CAP in the build step.
    echo "flavour=${flavour}" >> "$GITHUB_OUTPUT"
    echo "cap=${cap}" >> "$GITHUB_OUTPUT"
- name: Build neuron with CUDA
run: |
set -eux
export PATH="/usr/local/cuda-13.0/bin:${PATH}"
export LD_LIBRARY_PATH="/usr/local/cuda-13.0/targets/x86_64-linux/lib:/usr/local/cuda-13.0/lib64:${LD_LIBRARY_PATH:-}"
export LIBRARY_PATH="/usr/local/cuda-13.0/targets/x86_64-linux/lib:/usr/local/cuda-13.0/lib64:${LIBRARY_PATH:-}"
cargo build --release -p neuron --features "cuda cudnn"
env:
CUDA_COMPUTE_CAP: ${{ steps.map.outputs.cap }}
CARGO_BUILD_JOBS: "8"
NVCC_THREADS: "4"
- name: Stage binary
run: |
mkdir --parents artifacts
cp target/release/neuron artifacts/neuron-dev
file artifacts/neuron-dev
- uses: actions/upload-artifact@v3
with:
name: neuron-dev-${{ inputs.target }}
path: artifacts/neuron-dev
retention-days: 1
deploy:
name: Deploy to ${{ inputs.target }}
needs: build
runs-on: fedora-43
env:
DEPLOY_KEY: |
${{ secrets.RSYNC_SSH_KEY }}
TARGET_HOST: ${{ inputs.target }}.hanzalova.internal
steps:
- name: SSH init
  run: |
    # Restrict permissions before anything is created: without this
    # the key file is written under the ambient umask (typically 022)
    # and stays readable by other users until the chmod below runs.
    umask 077
    mkdir -p ~/.ssh
    echo "${DEPLOY_KEY}" > ~/.ssh/id_ed25519
    # Still needed when the file already existed with looser modes —
    # truncating with `>` keeps the old permission bits.
    chmod 600 ~/.ssh/id_ed25519
    # Connectivity check; accept-new records the host key on first
    # contact and rejects a changed key afterwards.
    ssh -o ConnectTimeout=5 -o StrictHostKeyChecking=accept-new \
      "gitea_ci@${TARGET_HOST}" 'hostname -f'
- uses: actions/download-artifact@v3
with:
name: neuron-dev-${{ inputs.target }}
path: artifacts/
- name: Copy binary to host
run: |
scp artifacts/neuron-dev "gitea_ci@${TARGET_HOST}:/var/lib/gitea_ci/neuron-dev"
- name: Install binary and restart neuron.service
run: |
ssh "gitea_ci@${TARGET_HOST}" '
set -eu
if systemctl is-active --quiet neuron.service; then
sudo /usr/bin/systemctl stop neuron.service
fi
# Exact command form required by the sudoers rule in
# asset/sudoers.d/neuron-host.conf — change both together.
sudo /usr/bin/install -o root -g root -m 0755 /var/lib/gitea_ci/neuron-dev /usr/bin/neuron
sudo /usr/bin/systemctl start neuron.service
rm -f /var/lib/gitea_ci/neuron-dev'
- name: Capture neuron.service startup journal
if: always()
run: |
sleep 10
ssh "gitea_ci@${TARGET_HOST}" \
'journalctl --unit neuron.service -I --no-pager'

View File

@@ -1,252 +0,0 @@
name: deploy
# Roll the freshly-published unstable RPMs onto the helexa fleet:
# cortex on the gateway, helexa-neuron-<flavour> on each neuron host.
#
# Triggered automatically after `build-prerelease` succeeds (by which
# point the new RPMs are live on rpm.lair.cafe/unstable), and also
# re-runnable manually from the Gitea UI.
#
# Each host self-gates: if dnf sees no newer package than what is
# installed, the service is left alone — no stop, no restart, no model
# cold-load. Combined with build-prerelease's change detection this
# means a docs- or gateway-only push never restarts the neurons (a
# neuron restart costs ~5 min of 27B cold-load, see issue #1).
#
# Per-host one-time setup (gitea_ci user, authorized_keys, scoped
# sudoers drop-in) lives in script/infra-setup.sh — run that once per
# host before this workflow can succeed.
on:
workflow_run:
workflows: [build-prerelease]
types: [completed]
workflow_dispatch:
# Serialize deploys. Overlapping runs would race on dnf metadata
# refresh and service-restart timing; queueing keeps the fleet
# predictable. Don't cancel an in-flight deploy — a half-applied dnf
# transaction is worse than a slightly stale deploy.
concurrency:
group: deploy
cancel-in-progress: false
env:
DEPLOY_KEY: |
${{ secrets.RSYNC_SSH_KEY }}
jobs:
deploy-cortex:
runs-on: fedora-43
# Two trigger paths: manual dispatch always runs; workflow_run
# only runs if the upstream `build-prerelease` actually succeeded.
if: >-
${{
github.event_name == 'workflow_dispatch'
|| github.event.workflow_run.conclusion == 'success'
}}
steps:
- name: SSH init
  run: |
    # Restrict permissions before anything is created: without this
    # the key file is written under the ambient umask (typically 022)
    # and stays readable by other users until the chmod below runs.
    umask 077
    mkdir -p ~/.ssh
    echo "${DEPLOY_KEY}" > ~/.ssh/id_ed25519
    # Still needed when the file already existed with looser modes —
    # truncating with `>` keeps the old permission bits.
    chmod 600 ~/.ssh/id_ed25519
    # Connectivity check; accept-new records the host key on first
    # contact and rejects a changed key afterwards.
    ssh -o ConnectTimeout=5 -o StrictHostKeyChecking=accept-new \
      gitea_ci@hanzalova.internal 'hostname -f'
# Gating compares `rpm -q` against the packages.json manifest the
# publish job maintains — NOT unprivileged `dnf check-update`,
# which proved unreliable as the gitea_ci user (hung on metadata
# locks on one host, silently reported "no updates" on others).
# An unreadable/unparsable manifest fails open: deploy proceeds.
- name: Deploy cortex (skips when already current)
run: |
ssh gitea_ci@hanzalova.internal 'bash -s' <<'DEPLOY'
set -eu
pkg=cortex
installed=$(rpm -q --qf '%{VERSION}-%{RELEASE}' "${pkg}" 2>/dev/null || echo "not-installed")
latest=$(curl -fsS --max-time 15 "https://rpm.lair.cafe/fedora/43/x86_64/unstable/packages.json" 2>/dev/null \
| python3 -c '
import json, sys
name = sys.argv[1]
cands = [p for p in json.load(sys.stdin)["packages"] if p.get("name") == name]
if cands:
p = max(cands, key=lambda p: p.get("buildTime", 0))
print(p["version"] + "-" + p["release"])
' "${pkg}" 2>/dev/null || true)
if [ -n "${latest}" ] && [ "${latest}" = "${installed}" ]; then
echo "${pkg}-${installed} already current — leaving service untouched"
exit 0
fi
echo "installed=${installed} published=${latest:-unknown} — deploying"
if systemctl is-active --quiet cortex.service; then
sudo /usr/bin/systemctl stop cortex.service
fi
if rpm -q "${pkg}" >/dev/null 2>&1; then
sudo /usr/bin/dnf upgrade --refresh --allowerasing -y cortex
else
sudo /usr/bin/dnf install --refresh --allowerasing -y cortex
fi
sudo /usr/bin/systemctl daemon-reload
sudo /usr/bin/systemctl start cortex.service
DEPLOY
# Wait for the service to either come up or wedge, then capture
# the latest-invocation journal. Runs even on prior failure so a
# failed start step still leaves a usable record in the deploy log.
- name: Capture cortex.service startup journal
if: always()
run: |
sleep 10
ssh gitea_ci@hanzalova.internal \
'journalctl --unit cortex.service -I --no-pager'
deploy-neurons:
needs: [deploy-cortex]
runs-on: fedora-43
strategy:
# One neuron failing must not cancel the others. Cortex is up
# already; a partial neuron deploy is strictly better than
# rolling back to zero.
fail-fast: false
matrix:
include:
# load_timeout: how long to wait for default_models to finish
# loading after a restart. beast cold-loads Qwen3.6-27B Q6K
# TP=2 (~5-6 min typical, see #1); benjy/quadbrat load small
# single-GPU models in well under a minute.
- host: beast.hanzalova.internal
flavour: blackwell
load_timeout: 900
- host: benjy.hanzalova.internal
flavour: ada
load_timeout: 300
- host: quadbrat.hanzalova.internal
flavour: ampere
load_timeout: 300
steps:
- name: SSH init
  run: |
    # Restrict permissions before anything is created: without this
    # the key file is written under the ambient umask (typically 022)
    # and stays readable by other users until the chmod below runs.
    umask 077
    mkdir -p ~/.ssh
    echo "${DEPLOY_KEY}" > ~/.ssh/id_ed25519
    # Still needed when the file already existed with looser modes —
    # truncating with `>` keeps the old permission bits.
    chmod 600 ~/.ssh/id_ed25519
    # Connectivity check; accept-new records the host key on first
    # contact and rejects a changed key afterwards.
    ssh -o ConnectTimeout=5 -o StrictHostKeyChecking=accept-new \
      gitea_ci@${{ matrix.host }} 'hostname -f'
# See deploy-cortex for why gating uses the publish manifest and
# not unprivileged `dnf check-update`.
- name: Deploy helexa-neuron-${{ matrix.flavour }} (skips when already current)
run: |
ssh gitea_ci@${{ matrix.host }} 'bash -s' <<'DEPLOY'
set -eu
pkg=helexa-neuron-${{ matrix.flavour }}
installed=$(rpm -q --qf '%{VERSION}-%{RELEASE}' "${pkg}" 2>/dev/null || echo "not-installed")
latest=$(curl -fsS --max-time 15 "https://rpm.lair.cafe/fedora/43/x86_64/unstable/packages.json" 2>/dev/null \
| python3 -c '
import json, sys
name = sys.argv[1]
cands = [p for p in json.load(sys.stdin)["packages"] if p.get("name") == name]
if cands:
p = max(cands, key=lambda p: p.get("buildTime", 0))
print(p["version"] + "-" + p["release"])
' "${pkg}" 2>/dev/null || true)
if [ -n "${latest}" ] && [ "${latest}" = "${installed}" ]; then
echo "${pkg}-${installed} already current — leaving service untouched"
exit 0
fi
echo "installed=${installed} published=${latest:-unknown} — deploying"
if systemctl is-active --quiet neuron.service; then
sudo /usr/bin/systemctl stop neuron.service
fi
if rpm -q "${pkg}" >/dev/null 2>&1; then
sudo /usr/bin/dnf upgrade --refresh --allowerasing -y "${pkg}"
else
sudo /usr/bin/dnf install --refresh --allowerasing -y "${pkg}"
fi
sudo /usr/bin/systemctl daemon-reload
sudo /usr/bin/systemctl start neuron.service
# ── Post-deploy validation ────────────────────────────────
# A deploy only goes green if the neuron (a) finishes loading
# its default models and (b) answers a trivial prompt like an
# LLM should. Catches the class of bug where the binary
# starts fine but model load or inference is broken — which
# previously surfaced only when a human noticed. The wait
# polls /health activation (the structured source of the
# "loaded default model" journal line, plus per-model failure
# detail); the journal-capture step below still runs for
# forensics either way.
load_timeout=${{ matrix.load_timeout }}
echo "waiting for default models (timeout ${load_timeout}s)"
deadline=$(( $(date +%s) + load_timeout ))
health=""
while :; do
health=$(curl -fsS --max-time 5 http://localhost:13131/health 2>/dev/null || true)
state=$(printf %s "${health}" | python3 -c '
import json, sys
try:
print(json.load(sys.stdin).get("activation", {}).get("state", ""))
except Exception:
print("")
')
if [ "${state}" = "ready" ]; then
break
fi
if [ "$(date +%s)" -ge "${deadline}" ]; then
echo "FAIL: activation not ready within ${load_timeout}s (last state: ${state:-unreachable})"
exit 1
fi
sleep 10
done
model=$(printf %s "${health}" | python3 -c '
import json, sys
a = json.load(sys.stdin).get("activation", {})
failed = a.get("failed", [])
if failed:
for f in failed:
msg = "FAILED " + str(f.get("model_id")) + ": " + str(f.get("error", ""))[:400]
sys.stderr.write(msg + chr(10))
sys.exit(1)
completed = a.get("completed", [])
print(completed[0] if completed else "")
')
if [ -z "${model}" ]; then
echo "no default models configured — skipping LLM probe"
exit 0
fi
echo "LLM probe against ${model}"
probe_body=$(printf '{"model":"%s","messages":[{"role":"user","content":"Reply with exactly one word: pineapple"}],"max_tokens":512,"temperature":0}' "${model}")
resp=$(curl -fsS --max-time 180 -H "content-type: application/json" \
-d "${probe_body}" http://localhost:13131/v1/chat/completions) || {
echo "FAIL: probe request errored"
exit 1
}
if printf %s "${resp}" | grep -qi pineapple; then
echo "LLM probe passed"
else
echo "FAIL: probe response missing expected token"
printf %s "${resp}" | head -c 2000
echo
exit 1
fi
DEPLOY
- name: Ensure firewalld allows helexa-neuron
run: |
ssh gitea_ci@${{ matrix.host }} '
if ! sudo /usr/bin/firewall-cmd --query-service=helexa-neuron --quiet 2>/dev/null; then
sudo /usr/bin/firewall-cmd --add-service=helexa-neuron --permanent
sudo /usr/bin/firewall-cmd --reload
fi'
# Wait for the service to either come up or wedge, then capture
# the latest-invocation journal. Runs even on prior failure so a
# failed start step still leaves a usable record in the deploy log.
- name: Capture neuron.service startup journal
if: always()
run: |
sleep 10
ssh gitea_ci@${{ matrix.host }} \
'journalctl --unit neuron.service -I --no-pager'

3
.gitignore vendored
View File

@@ -4,7 +4,4 @@
.idea/
.vscode/
cortex.toml
models.toml
doc/plan/*
/target-cuda/
.claude/

237
CLAUDE.md
View File

@@ -1,26 +1,16 @@
# CLAUDE.md — helexa
# CLAUDE.md — cortex
## Project overview
helexa is a self-hosted LLM serving stack for multi-node GPU inference
clusters. It has two components:
- **cortex** — the per-operator control plane and LLM proxy. A Rust
reverse-proxy that sits in front of the fleet and presents a unified
OpenAI + Anthropic compatible API surface. It handles model routing,
lifecycle management (load/unload/evict), request translation, and
metrics collection.
- **neuron** — the per-host LLM harness. One instance runs on every GPU
host, serving candle-based in-process inference and managing local
hardware discovery and model lifecycle.
(Historical note: cortex originally proxied to mistral.rs nodes; neuron
replaced that — see the 2026-05-18 candle-native addendum below.)
cortex is a Rust reverse-proxy that sits in front of multiple
mistral.rs inference nodes and presents a unified OpenAI + Anthropic
compatible API surface. It handles model routing, lifecycle management
(load/unload/evict), request translation, and metrics collection.
## Repository layout
```
helexa/
cortex/
├── Cargo.toml # workspace root
├── cortex.toml # example gateway config
├── README.md
@@ -94,63 +84,6 @@ Per-request: model, node, prompt_tokens, completion_tokens, total_tokens,
tok_per_sec, time_to_first_token_ms, total_latency_ms.
Exposed as Prometheus histograms/counters on a separate port.
### Per-device worker thread (neuron)
The neuron daemon dedicates one OS thread per CUDA device it loads
onto. That thread binds the device's `CudaContext` once at startup and
owns it for the daemon's lifetime; every model load, forward step,
KV-cache reset, VRAM query, NCCL init/sanity, NCCL all_reduce, and
model drop on that device routes through this thread via a
`std::sync::mpsc` job channel. Replies cross back via
`tokio::sync::oneshot`.
Three properties this gives us, in order of weight:
1. **Context locality.** cudarc binds the CUDA context per OS thread
via `cuCtxSetCurrent`. Before this refactor, ad-hoc
`tokio::task::spawn_blocking` calls bound the context onto a
different thread per request — and `device_vram_mb()` from an
async task bound it onto whichever tokio worker happened to be
running. Pinning the context to one named thread ends that.
2. **Drop safety.** Every `CudaSlice` in a `Tensor`, every
`cudarc::nccl::Comm`, and the `CudaContext` itself call `cuMemFree` /
`ncclCommDestroy` / `cuCtxDestroy` during `Drop` — and require the
right context current. With the worker owning the model slab,
`Drop` always runs on the right thread. The cudarc Drop constraint
is structurally enforced.
3. **Poisoning blast radius.** When a CUDA driver error makes the
context unrecoverable, the poison flag lives on the
`DeviceWorkerHandle` itself. Subsequent `submit()` calls fast-reject
at the channel boundary with a clear "device worker is poisoned"
error before any further CUDA work is attempted. The thread doesn't
exit (dropping the slab would re-touch the broken context) — it
enters a drain-only mode and replies error to everything until the
daemon restarts.
Tensors never escape the worker thread alive. Inference replies carry
`Vec<f32>` CPU-side logits; the async caller wraps them in a CPU
candle tensor and runs `apply_repeat_penalty` + `LogitsProcessor::sample`
without ever rebinding the device context. Sampled tokens come back as
`u32`; VRAM queries as `(u64, u64)`. The opaque `ArchHandle(u64)` and
`TpHandle(u64)` are the only "references" callers hold to loaded
models — they're indices into the worker's state slab, not pointers.
The TP worker subprocesses in `harness/tp/worker.rs` are the same
pattern out-of-process — a dedicated context-owning process per
non-zero NCCL rank. The in-process worker in `harness/device_worker/`
brings the discipline to rank 0.
CPU loads (`Device::Cpu` fallback when CUDA is unavailable) keep the
legacy `tokio::task::spawn_blocking + Arc<Mutex<ModelArch>>` path —
there's no context to own and the channel hop would only add latency.
Four `spawn_blocking` references in `harness/candle.rs` are deliberate
CPU fallback.
Canonical narrative lives in
`crates/neuron/src/harness/device_worker/mod.rs`'s module
doc-comment; touch points (the `Job` enum, the dispatch handlers, the
`DeviceWorkerState` struct) are in the sibling `jobs.rs` and
`dispatch.rs`.
## Tech stack
- **Rust 2024 edition** — workspace with 4 crates
@@ -558,7 +491,7 @@ and the hardcoded `vram_mb` per node.
## Revised repository layout
```
helexa/
cortex/
├── Cargo.toml
├── cortex.toml # gateway config (neurons only)
├── models.toml # model catalogue
@@ -683,120 +616,58 @@ dnf install cortex # gateway host
dnf install helexa-neuron # GPU nodes
```
## 2026-05-18 addendum: candle-native pivot
### Phase 11: llama.cpp harness stub
Phases 11 (llama.cpp harness) and 12 (mistral.rs COPR) below are
**superseded**. The project no longer treats mistral.rs or llama.cpp as
dependencies — both are conceptually out of scope. neuron becomes a
candle-native inference daemon, with `Harness` retained as an
internal seam for adding future engines (vision/audio/diffusion) but
its only implementation being in-process candle.
**Goal:** Prove the harness abstraction works with a second engine.
The full staged plan for this pivot lives at
`~/.claude/plans/create-a-more-aggressive-calm-naur.md`. Summary:
**Steps:**
1. `crates/neuron/src/harness/llamacpp.rs` — implement the `Harness`
trait for llama.cpp's `llama-server`.
- `start()` — launch `llama-server` with the correct model path,
`--port`, `--n-gpu-layers`, `--tensor-split` args. Track the
child process.
- `stop()` — send SIGTERM to the child process.
- `list_models()` — llama-server serves one model per process, so
return a single-element list.
- `load_model()` — start a new llama-server process for this model.
- `unload_model()` — stop the process.
- `inference_endpoint()` — return `http://localhost:{assigned_port}`.
2. Port allocation: neuron assigns ports from a range (e.g. 8100-8199)
to llama-server instances.
3. Register in `HarnessRegistry` when configured:
```toml
[[harnesses]]
name = "llamacpp"
binary = "/usr/local/bin/llama-server"
port_range = [8100, 8199]
```
4. Tests: mock llama-server (simple HTTP server returning canned
responses), test load/unload/endpoint lifecycle.
- **Stage 1 (this commit):** delete `mistralrs.rs` and `llamacpp.rs`,
scaffold inert `CandleHarness`, drop `endpoint`/`systemd_unit` from
`HarnessConfig`, default no-op `start`/`stop` on the `Harness` trait.
- **Stages 2–4:** wire up candle model load/unload (quantized Qwen3
first), add OpenAI-compatible inference endpoint in neuron, then SSE
streaming.
- **Stages 5–6:** load-on-activation (default models in config) and
unload-on-deactivation (graceful shutdown).
- **Stages 7–8:** multi-GPU tensor parallelism and broader model/quant
coverage.
**Done when:** A model with `harness = "llamacpp"` in `models.toml` can
be loaded and served through cortex. Tests pass with mock llama-server.
Sections of this document that describe mistral.rs HTTP behaviour
("mistral.rs API gotchas") are retained as historical context for
Phases 1–10 — they document what was true while the project depended
on mistral.rs. They do not describe current behaviour.
### Phase 12 (lower priority): mistral.rs COPR packaging
---
**Goal:** Fedora RPMs for mistral.rs built against specific CUDA versions.
### Phase 11 (superseded): llama.cpp harness stub
**Steps:**
1. `mistralrs-cuda.spec` — RPM spec that clones a pinned mistral.rs git
tag, builds with `--features cuda`, links against the system CUDA
toolkit. Produces `mistralrs-cuda13-server` (CUDA 13.x / sm_120) and
`mistralrs-cuda12-server` (CUDA 12.x / sm_89). Install binary to
`/usr/local/bin/mistralrs`.
2. COPR build config: enable the NVIDIA CUDA repo as a build dependency.
Pin the CUDA toolkit version in `BuildRequires`.
3. Gitea Actions or manual workflow: bump the mistral.rs tag in the spec,
trigger COPR rebuild.
4. neuron's mistralrs harness config references which binary/package
provides the mistral.rs binary. neuron could warn at startup if the
installed mistral.rs CUDA version doesn't match the discovered driver.
~~Originally planned as a second engine to prove the harness
abstraction.~~ Replaced by the candle harness work in the 2026-05-18
addendum above. llama.cpp's any-model/any-hardware breadth is no
longer in scope for helexa.
**Done when:** `dnf install mistralrs-cuda13-server` on beast provides a
working `mistralrs` binary built for Blackwell GPUs. `dnf install
mistralrs-cuda12-server` on benjy provides one built for Ada GPUs.
### Phase 12 (superseded): mistral.rs COPR packaging
~~Originally planned to ship CUDA-versioned mistral.rs RPMs.~~ Replaced
by the candle harness work in the 2026-05-18 addendum above. With
mistral.rs out of the dependency tree, there is nothing to package.
## 2026-05-27 addendum: per-device worker thread
Replaced the ad-hoc `tokio::task::spawn_blocking` pattern that drove
every leader-side CUDA op with one dedicated OS thread per CUDA device,
permanently bound to that device's `CudaContext`. All leader-side
inference work (GGUF + dense + TP shard load, forward, kv-cache clear,
NCCL init/sanity, NCCL all_reduce, VRAM query, model drop) routes
through the worker via a `std::sync::mpsc` channel; tensors never
escape the worker thread alive. See "Per-device worker thread (neuron)"
above and `crates/neuron/src/harness/device_worker/mod.rs` for the
canonical narrative.
Motivated by the 2026-05-26 silent-hang on beast: a CUDA OOM cascade
poisoned the device context on whichever spawn_blocking thread caught
it, and subsequent requests stalled invisibly on the pool lock. After
the refactor, the same failure mode shows up in journalctl as
`prefill sample failed; logits unhealthy nan: 248320/248320` followed
by `failed, model marked poisoned`. The thread stays alive and rejects
subsequent requests at the channel boundary.
Landed in four PRs:
- **Phase 1** (`081b532`) — device_worker module + 8 VRAM-query sites
route through the worker. CPU build only; smoke on beast confirmed
a persistent `cuda-dev-0` thread.
- **Phase 2** (`b179204`) — single-GPU forward + clear_kv + drop via
the worker. `LoadedModel.arch_handle: Option<ArchHandle>` replaces
`Arc<Mutex<ModelArch>>` for CUDA loads. CPU keeps the legacy path.
- **Phase 3** (`76ab24d`) — TP forward + NCCL init/sanity + leader
KV-clear routed through the worker. `WorkerPool.leader_nccl` moves
into the worker's state. `TpLoadedModel.leader_handle: TpHandle`
replaces `Arc<Mutex<TpLeaderModel>>`. CUDA-only TP smoke deferred to
next deploy.
- **Phase 4** (`b4f3576`) — GGUF + dense + TP shard loads move onto
the worker. The `Job::TransferIn` / `Job::CloneLeaderComm` bridges
from Phases 2/3 deleted; `SendComm` newtype no longer needed in the
load path. `grep -rn spawn_blocking crates/neuron/src/harness/`
returns only deliberate CPU-fallback hits after this PR.
## 2026-06-13 addendum: build metadata + helexa-bench
Two coupled additions so fleet performance can be tracked automatically
across neuron updates instead of by hand-running `script/bench.py` and
editing `doc/benchmarks.md`.
**neuron build metadata + `GET /version`.** neuron's `build.rs` now also
captures build identity (`HELEXA_GIT_SHA` — preferring a CI/RPM-injected
`HELEXA_BUILD_SHA`, falling back to git, else `unknown` — plus dirty
flag, build timestamp, rustc version, profile, enabled cargo features,
and a best-effort `candle-core` version from `Cargo.lock`). These are
exposed as `cortex_core::build_info::BuildInfo` (new module) from a new
`GET /version` endpoint (`neuron/src/version.rs`, wired in `api.rs`) and
in clap's `--version` long form. The SHA is injected in CI
(`build-prerelease.yml` build-neuron step: `export HELEXA_BUILD_SHA=$(git
rev-parse HEAD)`) and via `--define helexa_commit` in the source-build
spec, so tarball-built RPMs report the real SHA. `/version` is now the
canonical "which build is live" probe (supersedes the per-host RPM-sha
check in the fleet-validation flow).
**`crates/helexa-bench`** — a new binary: a continuous, version-aware
benchmark harness (one systemd unit, typically on the metrics host). It
hits each neuron **directly** on `:13131`, exercises each **warm**
(`status == "loaded"`) model with an extensible `Scenario` suite (phase
1: the chat-latency family ported verbatim from `bench.py` — synthetic
128/4096-tok prompts, `/no_think`, streamed TTFT + decode-window
tok/s), and records each run into a SQLite system-of-record stamped with
the neuron's full `BuildInfo`. The loop is **version-aware**: it skips
any (target, build SHA, model, scenario) cell already at
`samples_per_version`, so a steady fleet costs only cheap `/version` +
`/models` polls until a new SHA ships. `helexa-bench report` regenerates
the `benchmarks.md`-style table from the DB. `kind = "openai"` targets
(mistral.rs/llama.cpp comparison) are scaffolded but not yet wired.
Packaged as the `helexa-bench` RPM (prebuilt-binary spec, outbound-only
so no firewalld service) via the same `build-prerelease.yml` pipeline.
This is a separate repo/spec — not part of the cortex workspace — but
tightly coupled operationally. Track it as a sibling project.

2626
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -5,15 +5,13 @@ members = [
"crates/cortex-gateway",
"crates/cortex-cli",
"crates/neuron",
"crates/helexa-acp",
"crates/helexa-bench",
]
[workspace.package]
version = "0.1.16"
version = "0.1.12"
edition = "2024"
license = "GPL-3.0-or-later"
repository = "https://git.lair.cafe/helexa/helexa"
repository = "https://git.lair.cafe/helexa/cortex"
[workspace.dependencies]
# async runtime
@@ -29,7 +27,7 @@ serde = { version = "1", features = ["derive"] }
serde_json = "1"
toml = "0.8"
# http client (for proxying to neuron backends)
# http client (for proxying to mistralrs backends)
reqwest = { version = "0.12", features = ["json", "stream"] }
# observability
@@ -62,12 +60,3 @@ eventsource-stream = "0.2"
# workspace crates
cortex-core = { path = "crates/cortex-core" }
cortex-gateway = { path = "crates/cortex-gateway" }
# Patched cudarc (affects neuron's 0.19.x only; candle's 0.17.x is
# untouched since the fork is 0.19.7 and doesn't satisfy a 0.17 req). Adds
# Comm::abort / get_async_error / raw comm() — needed for #17 Stage 2 TP
# hang-recovery (abort a wedged collective from another thread, then
# rebuild the comm). Pinned to a fork revision pending upstream review
# (grenade/cudarc @ nccl-comm-abort).
[patch.crates-io]
cudarc = { git = "https://github.com/grenade/cudarc", rev = "63327a256059f8252641ae46c6bb9eefe707f382" }

223
README.md
View File

@@ -1,68 +1,24 @@
# helexa
# cortex
**Near-frontier AI for mortals.**
A Rust reverse-proxy and fleet management layer for multi-node
[mistral.rs](https://github.com/EricLBuehler/mistral.rs) inference clusters.
helexa is a self-hosted LLM serving stack, written in Rust, for people
who run open-weight models on their own consumer GPUs. It has two
components:
## Problem
- **cortex** — the per-operator control plane and LLM proxy. It sits in
front of your GPU fleet and presents a unified OpenAI + Anthropic
compatible API surface, handling model routing, lifecycle management
(load / unload / evict), request translation, and metrics.
- **neuron** — the per-host LLM harness. One instance runs on every GPU
host, serving candle-based in-process inference and managing local
hardware discovery and model lifecycle.
Running local LLMs across multiple GPU nodes (different VRAM tiers, different
model affinities) requires a unified API surface that:
## Why
Two principles constrain everything in this repository:
1. **Frontier or close to it.** helexa serves the open-weight models
that get nearest to frontier capability — not every architecture
ever published.
2. **Consumer hardware.** Everything must run on the cards mortals can
actually buy: a 3060 here, a 4090 there, a 5090 if you got lucky.
Mixed VRAM tiers across mismatched boxes are the expected topology,
not a degraded case.
GPU acquisition is harder than it was a year ago, and the gap between
what cloud providers charge and what your own silicon costs keeps
widening. The intersection of those two principles — near-frontier
models, squeezed onto hardware you own — is helexa's entire niche.
The secondary objective is **predictable consumption**. If you own the
hardware, your tooling shouldn't break because a cloud provider changed
billing, deprecated a model, or reshaped an API. cortex's OpenAI and
Anthropic surfaces are a stability contract: point your editor, agent,
or CLI at it once, and it keeps working.
## What helexa is not
This is an intentionally different path from vLLM, SGLang, and peers —
not a smaller version of them. Out of scope, permanently:
- Any-model breadth. Architectures are ported because they're at or
near the frontier, not to complete a compatibility matrix.
- Datacenter-class scheduling. No sophisticated continuous-batching /
paged-attention machinery — the workload is a handful of operators
and their agents, not 200 QPS.
- Wrapping external inference engines. neuron builds directly on
[candle](https://github.com/huggingface/candle); every model
architecture it serves is implemented in this repository, ported
against the HuggingFace reference.
One thing that is *not* a principle: CUDA exclusivity. All high-end
consumer hardware is in scope. helexa is CUDA-only today because
that's the hardware on the bench — nothing ships untested — and ROCm
or other consumer accelerators join as soon as there's real hardware
to build against.
In scope, and where the engineering effort goes: aggressive
quantization (GGUF Q4_K_M / Q6_K / Q8_0), NCCL tensor parallelism
across heterogeneous consumer GPUs, careful CUDA failure handling, and
single-request latency — the performance that one operator at a
keyboard actually feels.
- Presents a **single `/v1/models` catalogue** merging every model across every
node.
- **Routes requests** to the correct node based on where a model is loaded (or
*can* be loaded).
- Manages **model lifecycle** — unload cold models, reload on demand, pin
critical ones — using the mistral.rs
`/v1/models/{unload,reload,status}` HTTP API (PR #1828+).
- Translates between **OpenAI and Anthropic** request/response envelopes so
every client in the homelab speaks whichever dialect it prefers.
- Captures **per-request metrics** (tokens, tok/s, TTFT, latency) and exposes
them as Prometheus counters/histograms.
## Architecture
@@ -72,79 +28,65 @@ keyboard actually feels.
└──────┬───────┘ └─────┬────┘ └──────┬─────┘ └──────┬─────┘
│ │ │ │
└────────────────┴──────┬───────┴───────────────┘
OpenAI + Anthropic APIs
┌──────────▼──────────┐
cortex
│ (cortex-gateway) │
│ cortex │
(cortex-gateway)
│ │
│ Router · Metrics │
│ Evictor · Translate│
└──┬──────┬────────┬──┘
│ │ │
┌──────────▼┐ ┌──▼─────┐ ┌▼──────────┐
neuron │ │ neuron │ │ neuron
:13131 │ │ :13131 │ │ :13131
candle │ │ candle │ │ candle
gpu-large │ │gpu-med │ │ gpu-small
mistralrs │ │mistral │ │ mistralrs
serve │ │rs serve│ │ serve
│ :8080 │ │ :8080 │ │ :8080 │
└───────────┘ └────────┘ └───────────┘
private network (.internal)
```
cortex discovers each neuron's hardware (devices, VRAM, compute
capability) at runtime and matches it against a model catalogue
(`models.toml`) to decide placement: which models fit where, what to
evict when VRAM is tight, where to route a request right now. Adding a
GPU host to the fleet is one `[[neurons]]` entry — no device specs in
config.
### Crates
| Crate | Purpose |
|---|---|
| `cortex-core` | Shared types: config, node/model state, metrics, OpenAI/Anthropic envelopes, harness trait, discovery types |
| `cortex-gateway` | Axum HTTP server: proxy, router, evictor, poller, metrics exporter |
| `neuron` | Per-host daemon: GPU discovery, in-process candle inference, NCCL tensor parallelism, model lifecycle API |
| `cortex-core` | Shared types: config, node/model state, metrics, OpenAI/Anthropic request/response envelopes |
| `cortex-gateway` | Axum HTTP server: proxy, router, evictor, metrics exporter |
| `cortex-agent` | Per-node sidecar: polls local mistralrs, reports to gateway, handles restart/defrag |
| `cortex-cli` | CLI entrypoint (`cortex serve`, `cortex status`, etc.) |
| `helexa-acp` | Agent Client Protocol bridge — connects ACP editors (Zed, etc.) to any OpenAI-compatible endpoint, cortex by default |
## The engine
## Node setup
neuron runs inference in-process on candle — there is no external
inference server to babysit. The parts that earn their keep:
Each GPU node runs `mistralrs serve` with a multi-model config. Models are
declared but start **unloaded** — mistral.rs lazy-loads on first request and
the gateway can explicitly unload/reload via the HTTP API.
- **Per-device worker threads.** Every CUDA device gets one dedicated
OS thread that owns its CUDA context for the daemon's lifetime. All
loads, forward passes, KV-cache resets, NCCL collectives, VRAM
queries, and unloads route through it; tensors never escape it
alive. Context binding is pinned to a known thread, the CUDA `Drop`
contract is structurally safe, and a driver error poisons one worker
— visibly — instead of hanging the whole process.
- **Tensor parallelism on consumer cards.** Megatron-style row/column
parallel layers with NCCL all-reduce, spanning the mismatched GPUs
you actually have. A step watchdog aborts wedged collectives instead
of letting a request hang forever.
- **Current model focus: the Qwen3 family** — dense and GGUF-quantized,
including the hybrid linear-attention (Gated DeltaNet) generation.
Vision support is in progress. Each architecture is ported against
its HuggingFace reference implementation.
Example node systemd unit:
See `CLAUDE.md` for design rationale and
`crates/neuron/src/harness/device_worker/` for the worker narrative.
```ini
# /etc/systemd/system/mistralrs.service
[Unit]
Description=mistral.rs inference server
After=network-online.target
Wants=network-online.target
## Install
[Service]
Type=simple
ExecStart=/usr/local/bin/mistralrs serve \
--from-config /etc/mistralrs/config.toml \
--port 8080
Restart=on-failure
RestartSec=5
Environment=CUDA_VISIBLE_DEVICES=0,1
Pre-built RPMs for Fedora:
```sh
dnf copr enable helexa/helexa
dnf install cortex # on the gateway host
dnf install helexa-neuron # on each GPU host
systemctl enable --now cortex # or neuron, respectively
[Install]
WantedBy=multi-user.target
```
## Configure
## Gateway config
```toml
# /etc/cortex/cortex.toml
# cortex.toml
[gateway]
listen = "0.0.0.0:31313"
metrics_listen = "0.0.0.0:31314"
@@ -153,38 +95,35 @@ metrics_listen = "0.0.0.0:31314"
strategy = "lru" # lru | priority
defrag_after_cycles = 50
[[neurons]]
name = "beast"
endpoint = "http://beast.internal:13131"
[[nodes]]
name = "gpu-large"
endpoint = "http://gpu-large.internal:8080"
vram_mb = 49_152 # e.g. 2x RTX 4090
pinned = ["your-org/large-model"]
[[neurons]]
name = "benjy"
endpoint = "http://benjy.internal:13131"
[[nodes]]
name = "gpu-medium"
endpoint = "http://gpu-medium.internal:8080"
vram_mb = 24_576 # e.g. RTX 4090
pinned = ["your-org/medium-model"]
[[nodes]]
name = "gpu-small"
endpoint = "http://gpu-small.internal:8080"
vram_mb = 12_288 # e.g. RTX 3060
pinned = ["your-org/embedding-model"]
```
Model placement profiles (VRAM requirements, quant, device minimums,
pinning) live in `models.toml` — see `models.example.toml`.
## Run
```sh
# start the gateway
cortex serve --config /etc/cortex/cortex.toml
# check fleet status
cortex status
# one catalogue across every node
curl http://localhost:31313/v1/models
```
## Build from source
## Building
```sh
cargo build --release
```
CI runs on every push; keep it green locally:
## CI
Every push triggers format, lint, and test checks. Ensure these pass
locally before pushing:
```sh
cargo fmt --check --all # must be clean
@@ -192,18 +131,20 @@ cargo clippy --workspace -- -D warnings # warnings are errors
cargo test --workspace # all tests must pass
```
Tagged releases (`v*`) build SRPMs for `cortex` and `helexa-neuron`
and publish to COPR.
Tagged releases (`v*`) additionally build an SRPM and publish to COPR.
## Status
## Running
Pre-1.0 and moving fast. The gateway path (routing, eviction,
translation, metrics) is stable and tested; the candle-native engine
is under active development — expect the supported-model list to track
the open-weight frontier, deliberately narrowly.
```sh
# start the gateway
cortex serve --config cortex.toml
Development happens at <https://git.lair.cafe/helexa/helexa>;
<https://github.com/helexa-ai/helexa> is a read-only mirror.
# check fleet status
cortex status
# list all models across nodes
curl http://localhost:31313/v1/models
```
## License

View File

@@ -1,24 +0,0 @@
# neuron.toml for beast.hanzalova.internal
#
# 2x RTX 5090 (32 GB each) — TP-2 capable. Pre-warms Qwen3.6-27B with
# q6k ISQ (see `quant` below) across both GPUs at activation. The
# validate-neuron invocation for this host is `validate-neuron.sh
# beast.hanzalova.internal Qwen/Qwen3.6-27B q5k 2` — note that it passes
# q5k, so keep the two in sync when changing either.
#
# Synced to /etc/neuron/neuron.toml by script/infra-setup.sh. Edits
# take effect after the next deploy workflow run restarts the service
# (default_models is read at activation).
port = 13131
[[harnesses]]
name = "candle"
[harness.candle]
[[default_models]]
model_id = "Qwen/Qwen3.6-27B"
harness = "candle"
quant = "q6k"
tensor_parallel = 2
devices = [0, 1]

View File

@@ -1,19 +0,0 @@
# neuron.toml for benjy.hanzalova.internal
#
# 1x RTX 4090 (24 GB) — largest single-GPU host on the fleet. Pre-warms
# Qwen3-8B (bf16, ~18 GB), leaving ~6 GB for KV cache + activations on
# moderate-length contexts.
#
# Synced to /etc/neuron/neuron.toml by script/infra-setup.sh.
port = 13131
[[harnesses]]
name = "candle"
[harness.candle]
[[default_models]]
model_id = "Qwen/Qwen3-8B"
harness = "candle"
devices = [0]

View File

@@ -1,19 +0,0 @@
# neuron.toml for quadbrat.hanzalova.internal
#
# 1x RTX 3060 (12 GB) — small / quantised tier. Pre-warms Qwen3-1.7B
# (bf16, ~4 GB), leaving ~7 GB for KV cache so long contexts on a small
# model still have plenty of room.
#
# Synced to /etc/neuron/neuron.toml by script/infra-setup.sh.
port = 13131
[[harnesses]]
name = "candle"
[harness.candle]
[[default_models]]
model_id = "Qwen/Qwen3-1.7B"
harness = "candle"
devices = [0]

View File

@@ -1,20 +0,0 @@
# Install on the cortex gateway host as /etc/sudoers.d/helexa_gitea_ci
# (owner root:root, mode 0440). Required by .gitea/workflows/deploy.yml,
# which SSHes as gitea_ci@<gateway> to roll out cortex package upgrades
# and config changes.
#
# Filename convention `helexa_gitea_ci` (vs bare `gitea_ci`) so other
# helexa-org apps can drop their own sudoers files on the same host
# without overwriting this one.
gitea_ci ALL=(root) NOPASSWD: /usr/bin/rsync * /etc/cortex/cortex.toml
gitea_ci ALL=(root) NOPASSWD: /usr/bin/rsync * /etc/cortex/models.toml
gitea_ci ALL=(root) NOPASSWD: /usr/bin/systemctl start cortex.service
gitea_ci ALL=(root) NOPASSWD: /usr/bin/systemctl stop cortex.service
gitea_ci ALL=(root) NOPASSWD: /usr/bin/systemctl daemon-reload
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf install --refresh --allowerasing -y cortex
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf upgrade --refresh --allowerasing -y cortex
# sudoers reserves `:` and `=` and requires `\` escaping inside command
# arguments — without it visudo errors at the first `:` in `https://`.
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf config-manager addrepo --from-repofile\=https\://rpm.lair.cafe/lair-cafe-unstable.repo
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf config-manager setopt lair-cafe-unstable.enabled\=1

View File

@@ -1,38 +0,0 @@
# Install on every neuron host as /etc/sudoers.d/helexa_gitea_ci
# (owner root:root, mode 0440). Required by .gitea/workflows/deploy.yml,
# which SSHes as gitea_ci@<neuron-host> to roll out helexa-neuron-<flavour>
# package upgrades and config changes.
#
# Filename convention `helexa_gitea_ci` (vs bare `gitea_ci`) so other
# helexa-org apps can drop their own sudoers files on the same host
# without overwriting this one.
#
# All three CUDA flavours are listed because a host's flavour can change
# (e.g. GPU swap) and we don't want the sudoers file to need to change
# in lockstep. Only one flavour can be installed at a time (the packages
# Conflict: with each other), so the attack surface is bounded to "wrong
# flavour installed" — vandalism, not privilege escalation.
gitea_ci ALL=(root) NOPASSWD: /usr/bin/rsync * /etc/neuron/neuron.toml
gitea_ci ALL=(root) NOPASSWD: /usr/bin/systemctl start neuron.service
gitea_ci ALL=(root) NOPASSWD: /usr/bin/systemctl stop neuron.service
gitea_ci ALL=(root) NOPASSWD: /usr/bin/systemctl daemon-reload
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf install --refresh --allowerasing -y helexa-neuron-ampere
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf upgrade --refresh --allowerasing -y helexa-neuron-ampere
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf install --refresh --allowerasing -y helexa-neuron-ada
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf upgrade --refresh --allowerasing -y helexa-neuron-ada
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf install --refresh --allowerasing -y helexa-neuron-blackwell
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf upgrade --refresh --allowerasing -y helexa-neuron-blackwell
# sudoers reserves `:` and `=` and requires `\` escaping inside command
# arguments — without it visudo errors at the first `:` in `https://`.
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf config-manager addrepo --from-repofile\=https\://rpm.lair.cafe/lair-cafe-unstable.repo
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf config-manager setopt lair-cafe-unstable.enabled\=1
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf config-manager addrepo --from-repofile\=https\://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf install -y libcudnn9-cuda-13
gitea_ci ALL=(root) NOPASSWD: /usr/bin/firewall-cmd --add-service=helexa-neuron --permanent
gitea_ci ALL=(root) NOPASSWD: /usr/bin/firewall-cmd --reload
# deploy-dev.yml fast path: install a freshly-built dev binary over the
# packaged one. Exact source path + args; the workflow must use this
# command form verbatim. The next deploy.yml run reconciles the host
# back to the RPM-owned binary.
gitea_ci ALL=(root) NOPASSWD: /usr/bin/install -o root -g root -m 0755 /var/lib/gitea_ci/neuron-dev /usr/bin/neuron

View File

@@ -11,14 +11,14 @@ metrics_listen = "0.0.0.0:31314"
[eviction]
strategy = "lru"
# Restart neurons after this many load/unload cycles to defragment VRAM.
# Restart mistralrs after this many load/unload cycles to defragment VRAM.
# Set to 0 to disable.
defrag_after_cycles = 50
# -- Nodes ---------------------------------------------------------------
# Each [[nodes]] entry declares a neuron daemon in the fleet.
# Models are discovered by polling the neuron's /models endpoint.
# Pinned models (see models.toml) are never evicted.
# Each [[nodes]] entry declares a mistral.rs instance in the fleet.
# Models are discovered by polling the node's /v1/models endpoint.
# Pinned models are never evicted.
[[nodes]]
name = "gpu-large"

View File

@@ -1,10 +1,10 @@
Name: cortex
Version: 0.1.16
Version: 0.1.12
Release: 1%{?dist}
Summary: Inference gateway for multi-node GPU clusters
License: GPL-3.0-or-later
URL: https://git.lair.cafe/helexa/helexa
URL: https://git.lair.cafe/helexa/cortex
Source0: %{name}-%{version}.tar.gz
Source1: %{name}-%{version}-vendor.tar.gz
@@ -21,7 +21,6 @@ BuildRequires: systemd-rpm-macros
Requires(pre): shadow-utils
Requires: systemd
Requires: firewalld-filesystem
# systemd-rpm-macros ships a unit dep generator that parses User=/Group=
# from our .service file and emits Requires: user(cortex)/group(cortex).
@@ -57,7 +56,6 @@ cargo build --release -p cortex-cli
install -Dm755 target/release/cortex %{buildroot}%{_bindir}/cortex
install -Dm644 data/cortex.service %{buildroot}%{_unitdir}/cortex.service
install -Dm644 data/cortex-sysusers.conf %{buildroot}%{_sysusersdir}/cortex.conf
install -Dm644 data/cortex-firewalld.xml %{buildroot}%{_prefix}/lib/firewalld/services/cortex.xml
install -dm755 %{buildroot}%{_sysconfdir}/cortex
install -Dm644 cortex.example.toml %{buildroot}%{_sysconfdir}/cortex/cortex.toml
install -Dm644 models.example.toml %{buildroot}%{_sysconfdir}/cortex/models.toml
@@ -74,53 +72,16 @@ install -Dm644 models.example.toml %{buildroot}%{_sysconfdir}/cortex/models.toml
%postun
%systemd_postun_with_restart cortex.service
%posttrans
# Migration: older cortex packages shipped the firewalld service as
# `helexa-cortex` and (in some build streams) with wrong port numbers
# (9301/9302/9304). Operators who enabled that legacy service in their
# zone end up with the wrong-port override taking precedence over the
# vendor `cortex.xml` now in /usr/lib/firewalld/services/. Clean up the
# stale /etc/ override here and migrate any zone bindings to the new
# service name.
if [ -f /etc/firewalld/services/helexa-cortex.xml ]; then
rm -f /etc/firewalld/services/helexa-cortex.xml
fi
if [ -x /usr/bin/firewall-cmd ] && /usr/bin/firewall-cmd --state >/dev/null 2>&1; then
# Drop the legacy service name from every zone where it was enabled
# and add the new `cortex` service in its place. Operators who never
# ran firewall-cmd against either name see no zone change.
for zone in $(/usr/bin/firewall-cmd --get-active-zones 2>/dev/null \
| awk '!/^[[:space:]]/ {print $1}'); do
if /usr/bin/firewall-cmd --permanent --zone="$zone" --query-service=helexa-cortex >/dev/null 2>&1; then
/usr/bin/firewall-cmd --permanent --zone="$zone" --remove-service=helexa-cortex >/dev/null 2>&1 || :
/usr/bin/firewall-cmd --permanent --zone="$zone" --add-service=cortex >/dev/null 2>&1 || :
fi
done
/usr/bin/firewall-cmd --reload >/dev/null 2>&1 || :
fi
:
%files
%license LICENSE
%doc README.md
%{_bindir}/cortex
%{_unitdir}/cortex.service
%{_sysusersdir}/cortex.conf
%{_prefix}/lib/firewalld/services/cortex.xml
%dir %{_sysconfdir}/cortex
%config(noreplace) %{_sysconfdir}/cortex/cortex.toml
%config(noreplace) %{_sysconfdir}/cortex/models.toml
%changelog
* Thu Apr 16 2026 Gitea Actions <actions@git.lair.cafe> - 0.1.16-1
- chore: ignore local deploy script
- chore: move default ports out of common-collision ranges
- ci: drop actions/cache for cargo registry and target
* Thu Apr 16 2026 Gitea Actions <actions@git.lair.cafe> - 0.1.14-1
- ci: publish both packages to a single helexa/helexa COPR project
- fix(rpm): rename neuron package to helexa-neuron
- ci: commit generated %changelog entries back to main
* Wed Apr 15 2026 Rob Thijssen <grenade@rob.tn> - 0.1.0-1
- Initial package

View File

@@ -5,7 +5,7 @@ use tracing_subscriber::EnvFilter;
#[derive(Parser)]
#[command(name = "cortex")]
#[command(about = "Unified inference gateway for multi-node GPU clusters")]
#[command(about = "Unified inference gateway for multi-node mistral.rs clusters")]
#[command(version)]
struct Cli {
#[command(subcommand)]

View File

@@ -2,7 +2,7 @@
//!
//! These mirror the `/v1/messages` format used by the Anthropic API.
//! The gateway accepts these, translates to OpenAI format, proxies to
//! the inference backend (neuron), then translates the response back.
//! mistral.rs, then translates the response back.
use serde::{Deserialize, Serialize};
use serde_json::Value;

View File

@@ -1,119 +0,0 @@
//! Build/version metadata shared between cortex and neuron.
//!
//! neuron captures these facts at compile time in its `build.rs`
//! (git SHA, enabled cargo features, rustc/candle versions, …) and
//! serves them from `GET /version`. cortex and `helexa-bench`
//! deserialize the same struct so a benchmark run can be attributed to
//! the exact daemon build that produced it — not just the host's CUDA
//! and driver versions that `/discovery` already reports.
//!
//! Every field beyond the always-present package version is
//! `#[serde(default)]` so a newer reader stays compatible with an
//! older neuron that omits a field (and vice versa) — the same
//! forward/backward-compat discipline as
//! [`crate::discovery::ActivationStatus`].
use serde::{Deserialize, Serialize};
/// Build-time identity of a neuron daemon.
///
/// Returned by `GET /version`. The `git_sha` is the canonical "which
/// build is live" key — benchmark records are bucketed by it, so a
/// regression can be pinned to a daemon change rather than a host
/// change. When neuron is built from a source tarball with no git
/// metadata available (and no `HELEXA_BUILD_SHA` injected by CI/RPM),
/// `git_sha` is the string `"unknown"`.
///
/// Wire compatibility: `package_version` is the only field without a
/// serde default, so it is the only field a payload must contain.
/// Every other field falls back to its default when absent.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct BuildInfo {
/// Crate version from `CARGO_PKG_VERSION` (e.g. `"0.1.16"`).
/// Required when deserializing (no `#[serde(default)]`).
pub package_version: String,
/// Short git SHA, or `"unknown"` when unavailable at build time.
/// Also deserializes to `"unknown"` when the field is missing.
#[serde(default = "unknown")]
pub git_sha: String,
/// Full 40-char git SHA when available.
#[serde(default)]
pub git_sha_long: Option<String>,
/// Whether the working tree had uncommitted changes at build time.
/// `false` when the SHA is unknown (tarball build).
#[serde(default)]
pub git_dirty: bool,
/// RFC3339 build timestamp.
#[serde(default)]
pub build_timestamp: Option<String>,
/// `rustc --version` output of the compiler used.
#[serde(default)]
pub rustc_version: Option<String>,
/// Cargo build profile: `"release"` or `"debug"`.
#[serde(default)]
pub profile: Option<String>,
/// Target triple the binary was compiled for.
#[serde(default)]
pub target: Option<String>,
/// Enabled cargo features (e.g. `["cuda", "cudnn"]`). These define
/// the performance envelope, so they are recorded against every
/// benchmark run. Deserializes to an empty list when missing.
#[serde(default)]
pub features: Vec<String>,
/// Locked `candle-core` version, best-effort from `Cargo.lock`.
#[serde(default)]
pub candle_version: Option<String>,
}
/// Serde default for [`BuildInfo::git_sha`]: the literal `"unknown"`.
fn unknown() -> String {
    String::from("unknown")
}
impl BuildInfo {
    /// Builds a placeholder for callers that have no build metadata to
    /// report, such as non-neuron benchmark targets and tests.
    ///
    /// `package_version` is this crate's `CARGO_PKG_VERSION` captured at
    /// compile time, `git_sha` is the `"unknown"` sentinel, and every
    /// other field is empty (`None`, `false`, or an empty list).
    pub fn unknown() -> Self {
        let package_version = String::from(env!("CARGO_PKG_VERSION"));
        // Unqualified call: resolves to the module-level `unknown()`
        // helper, not to this associated function.
        let git_sha = unknown();
        Self {
            package_version,
            git_sha,
            git_sha_long: None,
            git_dirty: false,
            build_timestamp: None,
            rustc_version: None,
            profile: None,
            target: None,
            features: Vec::new(),
            candle_version: None,
        }
    }
}
// Unit tests for `BuildInfo` (de)serialization.
#[cfg(test)]
mod tests {
use super::*;
// A fully-populated value must survive a JSON serialize/deserialize
// cycle unchanged (compared via the derived `PartialEq`).
#[test]
fn round_trips_full() {
let info = BuildInfo {
package_version: "0.1.16".into(),
git_sha: "30d50d6".into(),
git_sha_long: Some("30d50d6abc123".into()),
git_dirty: true,
build_timestamp: Some("2026-06-13T10:00:00+00:00".into()),
rustc_version: Some("rustc 1.85.0".into()),
profile: Some("release".into()),
target: Some("x86_64-unknown-linux-gnu".into()),
features: vec!["cuda".into(), "cudnn".into()],
candle_version: Some("0.10.2".into()),
};
let json = serde_json::to_string(&info).unwrap();
let back: BuildInfo = serde_json::from_str(&json).unwrap();
assert_eq!(info, back);
}
// A payload carrying only `package_version` must still parse, with the
// remaining fields taking their serde defaults.
#[test]
fn deserializes_minimal_payload() {
// An older neuron might send only the package version; every
// other field must default rather than fail.
let back: BuildInfo = serde_json::from_str(r#"{"package_version":"0.1.0"}"#).unwrap();
assert_eq!(back.package_version, "0.1.0");
// `git_sha` uses the custom `unknown` default, not an empty string.
assert_eq!(back.git_sha, "unknown");
assert!(!back.git_dirty);
assert!(back.features.is_empty());
assert!(back.candle_version.is_none());
}
}

View File

@@ -1,8 +1,6 @@
//! Model catalogue — profiles describing how to serve each model.
use crate::discovery::DeviceInfo;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
/// A model serving profile loaded from models.toml.
@@ -24,17 +22,6 @@ pub struct ModelProfile {
/// Neurons where this model should never be evicted.
#[serde(default)]
pub pinned_on: Vec<String>,
/// Source scheme this profile's weights come from. When set, the
/// router prefixes `id` with `scheme:` before forwarding the load
/// request to neuron, ensuring the daemon fetches from the right
/// registry regardless of which entry happens to match `id`.
///
/// `None` lets neuron substitute its own `default_source` (typically
/// `huggingface`). Set to `"helexa"` when the model is hosted in
/// the helexa registry — operator-procurement-grade audit relies
/// on this being explicit per model rather than implicit.
#[serde(default)]
pub source: Option<String>,
}
fn default_min_devices() -> u32 {
@@ -46,14 +33,6 @@ fn default_min_devices() -> u32 {
pub struct ModelCatalogue {
#[serde(default)]
pub models: Vec<ModelProfile>,
/// Tier aliases — clients can send a request with `model: "helexa/small"`
/// and the gateway transparently rewrites + routes to the concrete
/// model id this maps to. Lets operators define latency/quality
/// tiers (`small`/`balanced`/`large`, `fast`/`thinking`, etc.)
/// without imposing knowledge of specific model ids on clients.
/// Loaded from the `[aliases]` table in models.toml.
#[serde(default)]
pub aliases: HashMap<String, String>,
}
impl ModelCatalogue {
@@ -85,162 +64,4 @@ impl ModelCatalogue {
.iter()
.any(|p| p.id == model_id && p.pinned_on.contains(&neuron_name.to_string()))
}
/// Look up the profile whose `id` equals `model_id`.
///
/// Returns `None` when the catalogue has no such entry; if several
/// entries share the id, the first one in `models` is returned.
pub fn get(&self, model_id: &str) -> Option<&ModelProfile> {
    for candidate in &self.models {
        if candidate.id == model_id {
            return Some(candidate);
        }
    }
    None
}
/// Map an alias onto the concrete model id it stands for.
///
/// Any `id` that is not a key in `aliases` is handed back unchanged.
/// Resolution is a single lookup: the target is never looked up again,
/// so aliases do not chain.
pub fn resolve_alias<'a>(&'a self, id: &'a str) -> &'a str {
    match self.aliases.get(id) {
        Some(target) => target.as_str(),
        None => id,
    }
}
}
impl ModelProfile {
    /// Reports whether the named neuron, with the given devices, can
    /// satisfy every placement constraint of this profile.
    ///
    /// All three checks must hold:
    /// - `pinned_on`: when non-empty, `neuron_name` must be listed.
    /// - `min_devices`: the neuron must expose at least this many devices.
    /// - `min_device_vram_mb`: when set, at least `min_devices` devices
    ///   must each have `vram_total_mb` at or above the floor.
    pub fn is_feasible_on(&self, neuron_name: &str, devices: &[DeviceInfo]) -> bool {
        // An empty pin list places no restriction on the neuron.
        let placement_allowed = self.pinned_on.is_empty()
            || self.pinned_on.iter().any(|pinned| pinned == neuron_name);

        let enough_devices = (devices.len() as u32) >= self.min_devices;

        // Without a VRAM floor only the device count matters.
        let enough_vram = match self.min_device_vram_mb {
            None => true,
            Some(floor) => {
                let qualifying = devices
                    .iter()
                    .filter(|dev| dev.vram_total_mb >= floor)
                    .count() as u32;
                qualifying >= self.min_devices
            }
        };

        placement_allowed && enough_devices && enough_vram
    }
}
// Unit tests for placement feasibility, alias resolution, and TOML
// parsing of the catalogue.
#[cfg(test)]
mod tests {
use super::*;
use crate::discovery::DeviceInfo;
/// Fixture: a device with the given index and total VRAM. The name is
/// derived from the index and the compute capability is fixed.
fn device(idx: u32, vram_mb: u64) -> DeviceInfo {
DeviceInfo {
index: idx,
name: format!("DEV-{idx}"),
vram_total_mb: vram_mb,
compute_capability: "8.6".into(),
}
}
/// Fixture: an unpinned profile that needs 2 devices, each with at
/// least 24_000 MB of VRAM. Tests mutate a copy to vary one constraint.
fn profile() -> ModelProfile {
ModelProfile {
id: "Qwen/Qwen3.6-27B".into(),
harness: "candle".into(),
quant: None,
vram_mb: Some(45_000),
min_devices: 2,
min_device_vram_mb: Some(24_000),
pinned_on: vec![],
source: None,
}
}
// Both devices clear the 24_000 MB floor and the count matches.
#[test]
fn feasible_when_two_devices_meet_vram_floor() {
let p = profile();
let devices = [device(0, 32_000), device(1, 32_000)];
assert!(p.is_feasible_on("beast", &devices));
}
// A single device fails `min_devices` no matter how much VRAM it has.
#[test]
fn infeasible_when_only_one_device() {
let p = profile();
let devices = [device(0, 64_000)];
assert!(!p.is_feasible_on("benjy", &devices));
}
// Two devices are present, but only one meets the per-device floor.
#[test]
fn infeasible_when_one_device_underspec() {
let p = profile();
let devices = [device(0, 32_000), device(1, 12_000)];
assert!(!p.is_feasible_on("mixed", &devices));
}
// A non-empty `pinned_on` restricts placement to the listed neurons,
// even when another neuron has identical hardware.
#[test]
fn pinned_on_excludes_other_neurons() {
let mut p = profile();
p.pinned_on = vec!["beast".into()];
let devices = [device(0, 32_000), device(1, 32_000)];
assert!(p.is_feasible_on("beast", &devices));
assert!(!p.is_feasible_on("benjy", &devices));
}
// With no VRAM floor, only the device count is checked.
#[test]
fn no_vram_floor_just_needs_min_devices() {
let mut p = profile();
p.min_device_vram_mb = None;
let devices = [device(0, 1_000), device(1, 1_000)];
assert!(p.is_feasible_on("anywhere", &devices));
}
// A known alias resolves to its configured target id.
#[test]
fn resolve_alias_returns_target_when_alias_present() {
let mut cat = ModelCatalogue::default();
cat.aliases
.insert("helexa/small".into(), "Qwen/Qwen3-1.7B".into());
assert_eq!(cat.resolve_alias("helexa/small"), "Qwen/Qwen3-1.7B");
}
// An id that is not an alias key is returned verbatim.
#[test]
fn resolve_alias_passes_through_when_not_an_alias() {
let mut cat = ModelCatalogue::default();
cat.aliases
.insert("helexa/small".into(), "Qwen/Qwen3-1.7B".into());
assert_eq!(cat.resolve_alias("Qwen/Qwen3-8B"), "Qwen/Qwen3-8B");
}
// Omitting `source` from a `[[models]]` entry must parse and yield `None`.
#[test]
fn source_defaults_to_none_when_absent_from_toml() {
let src = r#"
[[models]]
id = "Qwen/Qwen3-30B"
harness = "candle"
"#;
let cat: ModelCatalogue = toml::from_str(src).expect("parse models table");
assert!(cat.models[0].source.is_none());
}
// An explicit `source` value is carried through parsing unchanged.
#[test]
fn source_round_trips_through_toml() {
let src = r#"
[[models]]
id = "Helexa/Qwen3.6-27B-Uncensored"
harness = "candle"
source = "helexa"
"#;
let cat: ModelCatalogue = toml::from_str(src).expect("parse models table");
assert_eq!(cat.models[0].source.as_deref(), Some("helexa"));
}
// An `[aliases]` table (with no `[[models]]` at all) parses, and each
// entry is usable through `resolve_alias`.
#[test]
fn aliases_table_round_trips_through_toml() {
let src = r#"
[aliases]
"helexa/small" = "Qwen/Qwen3-1.7B"
"helexa/large" = "Qwen/Qwen3.6-27B"
"#;
let cat: ModelCatalogue = toml::from_str(src).expect("parse aliases table");
assert_eq!(cat.resolve_alias("helexa/small"), "Qwen/Qwen3-1.7B");
assert_eq!(cat.resolve_alias("helexa/large"), "Qwen/Qwen3.6-27B");
}
}

View File

@@ -22,17 +22,6 @@ pub struct DiscoveryResponse {
pub driver_version: Option<String>,
pub devices: Vec<DeviceInfo>,
pub harnesses: Vec<String>,
/// Set when the host has an NVIDIA stack that is currently
/// unusable — specifically the userspace↔kernel-module version
/// skew after an un-rebooted driver update ("Driver/library
/// version mismatch"), where every CUDA call including nvidia-smi
/// fails (#19). `None` on healthy hosts AND on hosts with no
/// NVIDIA stack at all (CPU-only is not an error). Carries an
/// operator-actionable description; cortex can read it to route
/// around the node instead of cold-loading into a guaranteed
/// failure.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub cuda_unavailable_reason: Option<String>,
}
/// Runtime health metrics for a single GPU device.
@@ -47,72 +36,8 @@ pub struct DeviceHealth {
/// Runtime health response from a neuron endpoint.
/// Returned by `GET /health`.
///
/// `activation` was added on 2026-05-26 to distinguish "process is up
/// and reachable" from "process is ready to serve traffic". A `Type=simple`
/// systemd unit reports `active` the moment the binary starts — but a
/// neuron whose `default_models` list takes minutes to materialise
/// won't bind its listener (or, in the new flow, won't have any models
/// loaded) until pre-warm completes. The new field is `#[serde(default)]`
/// so a pre-2026-05-26 gateway polling a new neuron — or vice versa —
/// keeps working.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthResponse {
/// Daemon uptime in seconds (per the field name — NOTE(review):
/// confirm against the producer whether this is process or host uptime).
pub uptime_secs: u64,
/// One health record per GPU device.
pub devices: Vec<DeviceHealth>,
/// Pre-warm progress. When the field is absent from the payload it
/// takes `ActivationStatus::default()`, whose state is `Ready`.
#[serde(default)]
pub activation: ActivationStatus,
}
/// High-level activation state of the neuron daemon. The HTTP listener
/// is bound during both states; what differs is whether the configured
/// `default_models` have finished loading.
///
/// Serialized in snake_case, i.e. `"pre_warming"` and `"ready"`. The
/// default is `Ready`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ActivationState {
/// At least one `default_models` entry is still loading. The
/// neuron's other endpoints work, but inference against
/// not-yet-loaded models will 404.
PreWarming,
/// Every `default_models` entry has either loaded or failed; the
/// neuron is steady-state. Subsequent on-demand loads via
/// `/models/load` don't flip back to PreWarming — the state
/// reflects the activation-time set only.
#[default]
Ready,
}
/// Per-model failure record surfaced in [`ActivationStatus::failed`].
/// The error string is the rendered anyhow chain at the time of the
/// failure; operators read it from `/health` to decide whether to
/// retry, edit the spec, or unload+reload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PreWarmFailure {
/// Id of the model whose pre-warm load failed.
pub model_id: String,
/// Rendered error chain captured when the load failed.
pub error: String,
}
/// Activation-time progress snapshot. The four progress fields
/// (`pending`, `in_progress`, `completed`, `failed`) are populated by
/// the neuron's pre-warm task and read by the `/health` handler. The
/// snapshot is consistent: a model id appears in exactly one of
/// `pending`, `in_progress` (as `Option<String>`), `completed`, or
/// `failed` at any point in time.
///
/// `state` carries no `#[serde(default)]`, so it must be present
/// whenever this object appears on the wire; the progress fields may be
/// omitted and then default to empty.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ActivationStatus {
/// Overall activation state; `Ready` in a default-constructed value.
pub state: ActivationState,
/// Model ids queued but not yet started. Empty in `Ready` state.
#[serde(default)]
pub pending: Vec<String>,
/// Model id currently materialising. None when between models or
/// in `Ready` state.
#[serde(default)]
pub in_progress: Option<String>,
/// Model ids that finished loading successfully during this
/// activation. Cleared on process restart.
#[serde(default)]
pub completed: Vec<String>,
/// Model ids that failed during this activation, with the rendered
/// error chain. Cleared on process restart.
#[serde(default)]
pub failed: Vec<PreWarmFailure>,
}

View File

@@ -9,13 +9,13 @@ use async_trait::async_trait;
use serde::{Deserialize, Serialize};
/// Configuration for a harness instance on a neuron.
///
/// All current harnesses are in-process (candle); per-harness tuning
/// (cache paths, device policies, etc.) lives in dedicated config
/// blocks rather than on this struct.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HarnessConfig {
pub name: String,
/// Base URL of the harness (e.g. "http://localhost:8080" for mistral.rs).
pub endpoint: Option<String>,
/// Systemd unit name, if the harness is managed via systemd.
pub systemd_unit: Option<String>,
}
/// Health status of a harness process.
@@ -44,37 +44,19 @@ pub struct ModelInfo {
pub status: String,
pub devices: Vec<u32>,
pub vram_used_mb: Option<u64>,
/// Modalities this loaded model supports. Today: `["text"]` for
/// text-only checkpoints, `["text", "vision"]` for vision-capable
/// ones (Stage B7 of the vision plan). Clients like litellm /
/// agent0 can gate `image_url` submission on the advertised set.
///
/// Optional in the wire format so older clients that don't read
/// it stay compatible. Default-empty for absent/older data, which
/// callers can interpret as "text".
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub capabilities: Vec<String>,
}
/// What an inference harness must do, from neuron's perspective.
///
/// All current harnesses are in-process — they share neuron's address
/// space and lifecycle. `start`/`stop` therefore default to no-ops; a
/// future process-supervising harness would override them.
#[async_trait]
pub trait Harness: Send + Sync {
/// Human-readable name (e.g. "candle").
/// Human-readable name (e.g. "mistralrs", "llamacpp", "comfyui").
fn name(&self) -> &str;
/// Start the harness. Default no-op for in-process harnesses.
async fn start(&self, _config: &HarnessConfig) -> Result<()> {
Ok(())
}
/// Start the harness process if it is not already running.
async fn start(&self, config: &HarnessConfig) -> Result<()>;
/// Stop the harness. Default no-op for in-process harnesses.
async fn stop(&self) -> Result<()> {
Ok(())
}
/// Stop the harness process gracefully.
async fn stop(&self) -> Result<()>;
/// Health check. Returns the harness process status.
async fn health(&self) -> HarnessHealth;

View File

@@ -1,5 +1,4 @@
pub mod anthropic;
pub mod build_info;
pub mod catalogue;
pub mod config;
pub mod discovery;
@@ -7,6 +6,4 @@ pub mod harness;
pub mod metrics;
pub mod node;
pub mod openai;
pub mod responses;
pub mod source;
pub mod translate;

View File

@@ -1,4 +1,3 @@
use crate::discovery::{ActivationStatus, DiscoveryResponse};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
@@ -14,18 +13,6 @@ pub struct NodeState {
/// Number of load/unload cycles since last process restart.
pub lifecycle_cycles: u32,
pub last_poll: Option<DateTime<Utc>>,
/// Result of the most recent successful `GET /discovery` against
/// this neuron. Cached forever once obtained — device topology is
/// invariant for a given neuron process. `None` until the first
/// successful poll. Used by the router and `/v1/models` to do
/// catalogue × topology feasibility checks.
pub discovery: Option<DiscoveryResponse>,
/// Last-seen pre-warm progress from this neuron's `/health`
/// endpoint. `None` until the first /health poll succeeds. The
/// `/v1/models` handler reads `in_progress` + `pending` from here
/// to synthesize `Loading` locations so clients see a catalogued
/// model that's mid-prewarm as "loading", not "missing".
pub activation: Option<ActivationStatus>,
}
/// A model registered on a node, with its runtime status.
@@ -37,72 +24,25 @@ pub struct ModelEntry {
pub last_accessed: Option<DateTime<Utc>>,
/// Estimated VRAM usage in MB when loaded.
pub vram_estimate_mb: Option<u64>,
/// Modalities the loaded model advertises (e.g. `["text", "vision"]`),
/// copied verbatim from the neuron's `ModelInfo.capabilities` at poll
/// time. Empty when the neuron reports none. `#[serde(default)]` keeps
/// older persisted/serialised entries deserialisable.
#[serde(default)]
pub capabilities: Vec<String>,
}
/// Model lifecycle status.
///
/// `Loading` is a gateway-side synthetic status: neurons never emit it
/// on `/models` (that endpoint only knows about already-loaded handles).
/// The gateway populates it from a neuron's `/health` activation
/// snapshot so the unified `/v1/models` can distinguish "model is
/// catalogued but no one has it" from "model is materialising on
/// neuron N right now". Other status values are reported verbatim by
/// neurons.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ModelStatus {
Loaded,
Unloaded,
Reloading,
Loading,
/// Reported by neuron while a poisoned model auto-recovers via
/// unload→reload (#17/#20). Temporarily unservable but NOT
/// evicted: the gateway holds the route, answers with a transient
/// retry error instead of 404, and must not race a second
/// placement elsewhere.
Recovering,
}
/// Unified model entry as exposed by the gateway's `/v1/models` endpoint.
///
/// The first four fields (`id`, `object`, `created`, `owned_by`) match
/// OpenAI's `/v1/models` shape verbatim, so existing OpenAI-aware
/// tooling deserialises this without custom code. The remaining fields
/// are helexa-specific extensions — OpenAI clients ignore unknown
/// fields and other consumers can read them for placement / debugging.
/// Includes which node(s) host this model and their status.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CortexModelEntry {
pub id: String,
/// Always `"model"` per OpenAI's contract.
pub object: String,
/// Unix-second timestamp; cortex stamps this at response time.
pub created: u64,
/// OpenAI's "publisher" field — `"helexa"` for everything we serve.
pub owned_by: String,
/// True if any neuron currently has this model loaded. False for
/// catalogue entries that are feasible but not yet loaded.
pub loaded: bool,
/// Neurons whose discovered topology can satisfy this model's
/// catalogue placement constraints. Empty for models that are
/// loaded somewhere but not present in the catalogue (cortex has
/// no feasibility opinion on those).
pub feasible_on: Vec<String>,
/// Where this model is actually loaded right now. Subset of (or
/// disjoint from) `feasible_on` depending on whether the catalogue
/// covers this model.
/// Which nodes have this model (and their status).
pub locations: Vec<ModelLocation>,
/// Union of the modalities advertised by every neuron that has this
/// model loaded (e.g. `["text", "vision"]`). Empty for catalogue-only
/// entries with no loaded location — the catalogue profile doesn't
/// declare capabilities yet (tracked separately from C3).
#[serde(default)]
pub capabilities: Vec<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]

View File

@@ -3,7 +3,7 @@
//! These are a subset sufficient for chat completions (streaming + non-streaming).
//! Fields not relevant to proxying are captured as `serde_json::Value` via
//! `#[serde(flatten)]` so we forward them without needing to enumerate every
//! extension field a backend might support.
//! extension field mistral.rs supports.
use serde::{Deserialize, Serialize};
use serde_json::Value;
@@ -22,7 +22,7 @@ pub struct ChatCompletionRequest {
pub max_tokens: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
/// All other fields (tools, response_format, backend extensions, etc.)
/// All other fields (tools, response_format, mistral.rs extensions, etc.)
#[serde(flatten)]
pub extra: Value,
}
@@ -71,18 +71,10 @@ pub struct ChatCompletionChoice {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionChunk {
#[serde(default)]
pub id: String,
#[serde(default)]
pub object: String,
#[serde(default)]
pub created: u64,
// Lenient deserialization throughout: the gateway parses chunks
// from arbitrary OpenAI-compatible upstreams, and some engines
// omit fields on special frames (e.g. usage-only final chunks).
#[serde(default)]
pub model: String,
#[serde(default)]
pub choices: Vec<ChunkChoice>,
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<Usage>,

View File

@@ -1,346 +0,0 @@
//! OpenAI Responses API (`POST /v1/responses`) envelope types.
//!
//! This is OpenAI's newer chat surface, distinct from
//! `/v1/chat/completions` in three ways that matter for us:
//!
//! 1. **Input shape**. Instead of a `messages` array, the request
//! carries `input` — either a plain string (single user turn)
//! or an array of typed items (messages, function calls,
//! function-call outputs, reasoning blocks, …).
//! 2. **Output shape**. The response carries a single `output`
//! array of items, each typed. We always emit one
//! `OutputItem::Message` containing the assistant's reply (plus,
//! when we get there, separate `function_call` items).
//! 3. **Streaming events**. Where chat completions stream
//! structurally-identical `chat.completion.chunk` frames over
//! `data:` lines, Responses streams *named* events
//! (`response.created`, `response.output_text.delta`,
//! `response.completed`, …) over `event:` + `data:` SSE pairs.
//! The wire projector in `neuron::wire::openai_responses` builds
//! these from the same [`crate::openai`]-shaped
//! `InferenceEvent` stream the chat projector consumes.
//!
//! Scope cuts for this first cut:
//!
//! - **`previous_response_id` is rejected at parse time**. Stateful
//! chained conversations need a persistence layer we don't have.
//! - **Reasoning items are accepted-and-ignored** (no Qwen3
//! `<think>` routing yet). Audio and embedded resources are
//! rejected as unsupported.
//! - **Tool calls** (function_call / function_call_output) are
//! carried as round-trip types but the candle harness doesn't
//! emit them yet — wired so the surface is in place for the
//! day we add proper tool-call extraction.
use serde::{Deserialize, Serialize};
use serde_json::Value;
// ── Request ──────────────────────────────────────────────────────────
/// Body of a `POST /v1/responses` request.
///
/// Known keys are typed fields; every other top-level key is gathered
/// into `extra` by `#[serde(flatten)]` instead of being rejected.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponsesRequest {
/// Value of the request's `model` key. Required.
pub model: String,
/// The `input` payload — a bare string or an array of typed items
/// (see [`ResponsesInput`]). Required.
pub input: ResponsesInput,
/// System-prompt-style instructions. The Responses API
/// separates these from input so a caller doesn't have to
/// build a `system` message item by hand.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub instructions: Option<String>,
/// Value of the `stream` key; `false` when the key is absent
/// (`#[serde(default)]`).
#[serde(default)]
pub stream: bool,
/// Optional `max_output_tokens` value. Omitted on serialisation
/// when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_output_tokens: Option<u64>,
/// Optional `temperature` value. Omitted on serialisation when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>,
/// Optional `top_p` value. Omitted on serialisation when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>,
/// Chained-conversation identifier. We don't store responses
/// server-side yet; if this is `Some`, the handler returns 400.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub previous_response_id: Option<String>,
/// Catch-all for anything we don't model yet (tools, tool_choice,
/// reasoning, response_format, …). Lets a client send a
/// forward-compatible request without our parser rejecting it.
#[serde(flatten)]
pub extra: Value,
}
/// `input` is either a single string or an array of typed items.
/// `#[serde(untagged)]` so the wire shape `"input": "hi"` and
/// `"input": [{...}]` both deserialize.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ResponsesInput {
/// The plain-string form (`"input": "hi"`).
Text(String),
/// The array form (`"input": [{...}]`), one entry per typed item.
Items(Vec<ResponsesInputItem>),
}
/// One typed element of an `input` array.
///
/// Internally tagged: the JSON `type` key selects the variant, using the
/// snake_case form of the variant name (`message`, `function_call`,
/// `function_call_output`, `reasoning`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponsesInputItem {
/// A user / assistant / system turn.
Message {
role: String,
content: ResponsesMessageContent,
},
/// Assistant emitted a tool call. Round-trip only — neuron
/// doesn't synthesise these yet.
FunctionCall {
call_id: String,
name: String,
arguments: String,
},
/// User is feeding a tool result back into the model.
FunctionCallOutput { call_id: String, output: String },
/// Reasoning items emitted by o-series models. Accepted but
/// not forwarded to the model — neuron's candle path doesn't
/// surface reasoning separately yet.
Reasoning {
// Defaults to an empty list when the key is absent.
#[serde(default)]
content: Vec<Value>,
},
}
/// Inside a `Message` item, content is either a plain string or an
/// array of typed parts. Mirrors the chat-completions Parts shape.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ResponsesMessageContent {
/// Content given as a bare JSON string.
Text(String),
/// Content given as an array of typed parts.
Parts(Vec<ResponsesContentPart>),
}
/// One typed part inside a message's content array.
///
/// Internally tagged on the JSON `type` key with snake_case variant
/// names (`input_text`, `input_image`, `output_text`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponsesContentPart {
/// Plain text inside a user / system turn.
InputText { text: String },
/// An image. `image_url` is either a remote URL or a
/// `data:image/png;base64,…` URI; the request translator just
/// forwards the string.
InputImage {
image_url: String,
// Optional; left out of serialised output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
detail: Option<String>,
},
/// Returned text inside an assistant turn — only relevant when
/// the caller is feeding an assistant turn back in to continue
/// a conversation manually (no `previous_response_id`).
OutputText {
text: String,
// Defaults to empty on input; left out of serialised output when empty.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
annotations: Vec<Value>,
},
}
// ── Response (non-streaming) ─────────────────────────────────────────
/// Body of a `POST /v1/responses` response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponsesResponse {
/// Response identifier (the module's own test uses `"resp_1"`).
pub id: String,
/// Always `"response"`.
pub object: String,
/// Creation timestamp — presumably Unix seconds; confirm against the
/// code that stamps it.
pub created_at: u64,
/// `"completed"`, `"incomplete"`, or — for the initial event of
/// a streaming response — `"in_progress"`.
pub status: String,
/// Value of the response's `model` key.
pub model: String,
/// Typed output items (see [`ResponsesOutputItem`]).
pub output: Vec<ResponsesOutputItem>,
/// Populated on completion; `None` while streaming.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub usage: Option<ResponsesUsage>,
}
/// One typed element of a response's `output` array.
///
/// Internally tagged on the JSON `type` key with snake_case variant
/// names (`message`, `function_call`). A missing `status` key
/// deserialises to `"completed"` via `default_item_status`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponsesOutputItem {
Message {
id: String,
/// Always `"assistant"` for model output.
role: String,
/// Output content parts. We always emit a single
/// `OutputText` today; multi-part output would land here
/// once we have e.g. image generation.
content: Vec<ResponsesOutputContent>,
/// Item-level status. `"in_progress"` while streaming the
/// content parts, `"completed"` when done.
#[serde(default = "default_item_status")]
status: String,
},
/// Reserved for the day tool-call extraction lands. The wire
/// shape mirrors `ResponsesInputItem::FunctionCall`.
FunctionCall {
id: String,
call_id: String,
name: String,
arguments: String,
#[serde(default = "default_item_status")]
status: String,
},
}
/// Serde fallback for an output item's `status` key: items that arrive
/// without one are treated as `"completed"`.
fn default_item_status() -> String {
    String::from("completed")
}
/// One typed content part inside an output `Message` item.
///
/// Internally tagged on the JSON `type` key; the only variant today
/// serialises as `output_text`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponsesOutputContent {
OutputText {
text: String,
/// Citations / inline annotations. Empty today; reserved
/// for the day we wire in web search / file search.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
annotations: Vec<Value>,
},
}
/// Token accounting attached to a completed response.
///
/// NOTE(review): nothing in this type enforces
/// `total_tokens == input_tokens + output_tokens`; the producer is
/// presumably responsible for keeping them consistent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponsesUsage {
pub input_tokens: u64,
pub output_tokens: u64,
pub total_tokens: u64,
}
// ── Streaming event names ────────────────────────────────────────────
/// Event names the SSE projector emits, hoisted as constants so
/// the projector and the wire shape stay in sync without
/// string-typos. The strings are dictated by OpenAI's published
/// Responses API.
pub mod events {
/// Wire name `response.created`.
pub const CREATED: &str = "response.created";
/// Fired between `response.created` and the first output-item
/// event. Marks "request validated, model is generating" —
/// some clients use it to differentiate the "warming up" state
/// from "streaming tokens" in their UI.
pub const IN_PROGRESS: &str = "response.in_progress";
/// Wire name `response.output_item.added`.
pub const OUTPUT_ITEM_ADDED: &str = "response.output_item.added";
/// Wire name `response.content_part.added`.
pub const CONTENT_PART_ADDED: &str = "response.content_part.added";
/// Wire name `response.output_text.delta`.
pub const OUTPUT_TEXT_DELTA: &str = "response.output_text.delta";
/// Wire name `response.output_text.done`.
pub const OUTPUT_TEXT_DONE: &str = "response.output_text.done";
/// Wire name `response.content_part.done`.
pub const CONTENT_PART_DONE: &str = "response.content_part.done";
/// Wire name `response.output_item.done`.
pub const OUTPUT_ITEM_DONE: &str = "response.output_item.done";
/// Wire name `response.completed`.
pub const COMPLETED: &str = "response.completed";
}
// Serde shape tests for the Responses envelope types: both `input`
// forms, typed content parts, the `extra` catch-all, and a response
// round trip.
#[cfg(test)]
mod tests {
use super::*;
// `"input": "hello"` must land in `ResponsesInput::Text`.
#[test]
fn deserialises_input_string_form() {
let raw = r#"{"model": "m", "input": "hello"}"#;
let req: ResponsesRequest = serde_json::from_str(raw).unwrap();
match req.input {
ResponsesInput::Text(s) => assert_eq!(s, "hello"),
other => panic!("expected Text, got {other:?}"),
}
}
// An array `input` with one message item whose content is a bare string.
#[test]
fn deserialises_input_items_form() {
let raw = r#"{
"model": "m",
"input": [
{"type": "message", "role": "user", "content": "hi"}
]
}"#;
let req: ResponsesRequest = serde_json::from_str(raw).unwrap();
match req.input {
ResponsesInput::Items(items) => {
assert_eq!(items.len(), 1);
match &items[0] {
ResponsesInputItem::Message { role, content } => {
assert_eq!(role, "user");
match content {
ResponsesMessageContent::Text(t) => assert_eq!(t, "hi"),
other => panic!("expected Text content, got {other:?}"),
}
}
other => panic!("expected Message item, got {other:?}"),
}
}
other => panic!("expected Items, got {other:?}"),
}
}
// Message content given as typed parts: one text part, one image part.
#[test]
fn deserialises_input_with_image() {
let raw = r#"{
"model": "m",
"input": [
{"type": "message", "role": "user", "content": [
{"type": "input_text", "text": "what is this"},
{"type": "input_image", "image_url": "data:image/png;base64,AAA="}
]}
]
}"#;
let req: ResponsesRequest = serde_json::from_str(raw).unwrap();
let items = match req.input {
ResponsesInput::Items(i) => i,
other => panic!("expected Items, got {other:?}"),
};
let parts = match &items[0] {
ResponsesInputItem::Message {
content: ResponsesMessageContent::Parts(p),
..
} => p,
other => panic!("expected Parts, got {other:?}"),
};
assert_eq!(parts.len(), 2);
assert!(matches!(
&parts[0],
ResponsesContentPart::InputText { text } if text == "what is this"
));
assert!(matches!(
&parts[1],
ResponsesContentPart::InputImage { image_url, .. }
if image_url == "data:image/png;base64,AAA="
));
}
// Keys without a typed field (`tools`, `reasoning`) must be captured
// in `extra` rather than rejected.
#[test]
fn unknown_fields_round_trip_via_extra() {
let raw = r#"{
"model": "m",
"input": "hi",
"tools": [{"type": "web_search"}],
"reasoning": {"effort": "medium"}
}"#;
let req: ResponsesRequest = serde_json::from_str(raw).unwrap();
assert!(req.extra.get("tools").is_some());
assert!(req.extra.get("reasoning").is_some());
}
// Serialise then deserialise a fully-populated response and spot-check it.
#[test]
fn response_round_trips_through_serde() {
let r = ResponsesResponse {
id: "resp_1".into(),
object: "response".into(),
created_at: 1700,
status: "completed".into(),
model: "m".into(),
output: vec![ResponsesOutputItem::Message {
id: "msg_1".into(),
role: "assistant".into(),
content: vec![ResponsesOutputContent::OutputText {
text: "hi there".into(),
annotations: vec![],
}],
status: "completed".into(),
}],
usage: Some(ResponsesUsage {
input_tokens: 5,
output_tokens: 3,
total_tokens: 8,
}),
};
let json = serde_json::to_string(&r).unwrap();
let parsed: ResponsesResponse = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.id, "resp_1");
assert_eq!(parsed.output.len(), 1);
}
}

View File

@@ -1,267 +0,0 @@
//! Scheme-qualified model identifiers.
//!
//! cortex/neuron historically resolves every model id through hf-hub
//! against `https://huggingface.co`. Helexa is adding an EU-hosted
//! registry (`registry.helexa.ai`) alongside HF — both speak the same
//! HF-compatible wire format, but the bytes, jurisdiction, and trust
//! root differ. Model ids therefore need a scheme:
//!
//! - `huggingface:Qwen/Qwen3.6-27B` — HF-hosted bytes
//! - `helexa:Qwen/Qwen3.6-27B-Uncensored` — helexa registry bytes
//! - `helexa:SomeOperator/CustomFinetune` — operator publishing
//! under the helexa namespace; same scheme handles all `org/name`
//! pairs hosted in that registry.
//!
//! Bare `org/name` parses with an empty scheme; the caller (typically
//! a harness) substitutes its configured default scheme so existing
//! configs keep working through the transition.
use serde::{Deserialize, Serialize};
use std::fmt;
use std::str::FromStr;
/// Parsed `scheme:org/name`. Bare `org/name` produces an empty scheme
/// — call `with_default_scheme` (or check `is_scheme_unset`) to
/// resolve before using.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ModelSourceId {
/// Text before the first `:`; the empty string when the id was
/// written without a scheme prefix.
pub scheme: String,
/// Text between the scheme prefix (if any) and the first `/`.
pub org: String,
/// Text after the first `/`.
pub name: String,
}
/// Errors from `ModelSourceId::from_str`. Carries the offending input
/// so log lines / API errors can echo what the operator typed.
///
/// Every variant except `Empty` holds the complete original input
/// string, not just the offending fragment.
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum ParseError {
/// The input was the empty string.
#[error("empty model id")]
Empty,
/// No `/` was found after the (optional) scheme prefix.
#[error("model id '{0}' is missing the '/' between org and name")]
MissingSlash(String),
/// The input starts with `:`.
#[error("model id '{0}' has an empty scheme before ':'")]
EmptyScheme(String),
/// Nothing precedes the `/`.
#[error("model id '{0}' has an empty org")]
EmptyOrg(String),
/// Nothing follows the `/`.
#[error("model id '{0}' has an empty name")]
EmptyName(String),
/// The text before the first `:` contains a `/`.
#[error("model id '{0}' has a scheme containing '/' which is reserved for org/name")]
SchemeContainsSlash(String),
/// A second `:` appears in the name part.
#[error("model id '{0}' has a name containing ':' which is reserved for the scheme prefix")]
NameContainsColon(String),
}
impl ModelSourceId {
    /// Assemble an id from three already-separated parts.
    ///
    /// No validation happens here; input that needs checking should go
    /// through `FromStr` instead.
    pub fn new(scheme: impl Into<String>, org: impl Into<String>, name: impl Into<String>) -> Self {
        let (scheme, org, name) = (scheme.into(), org.into(), name.into());
        Self { scheme, org, name }
    }

    /// Reports whether the scheme is empty, i.e. the id was written as a
    /// bare `org/name` and still needs a default scheme applied.
    pub fn is_scheme_unset(&self) -> bool {
        self.scheme.is_empty()
    }

    /// Fill in `default` when no scheme is set; an explicit scheme is
    /// never overridden. Consumes and returns the id so it chains:
    /// `id.parse::<ModelSourceId>()?.with_default_scheme("huggingface")`.
    pub fn with_default_scheme(self, default: &str) -> Self {
        if !self.is_scheme_unset() {
            return self;
        }
        Self {
            scheme: default.to_string(),
            ..self
        }
    }

    /// The scheme-less `org/name` pair, joined with a single `/`.
    pub fn repo_path(&self) -> String {
        [self.org.as_str(), self.name.as_str()].join("/")
    }
}
impl fmt::Display for ModelSourceId {
    /// Renders `scheme:org/name`, dropping the `scheme:` prefix entirely
    /// when the scheme is empty so bare ids print as plain `org/name`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if !self.scheme.is_empty() {
            write!(f, "{}:", self.scheme)?;
        }
        write!(f, "{}/{}", self.org, self.name)
    }
}
impl FromStr for ModelSourceId {
    type Err = ParseError;

    /// Parse `scheme:org/name` or bare `org/name`.
    ///
    /// A bare id yields an empty `scheme`. Everything up to the first
    /// `:` is the scheme; the remainder is split into org and name at
    /// its first `/`.
    ///
    /// # Errors
    ///
    /// - `Empty` for the empty string.
    /// - `EmptyScheme` when the input starts with `:`.
    /// - `SchemeContainsSlash` when a `/` precedes the first `:`.
    /// - `MissingSlash` when no `/` follows the scheme prefix.
    /// - `EmptyOrg` / `EmptyName` when either side of the `/` is empty.
    /// - `NameContainsColon` when a further `:` appears anywhere after
    ///   the scheme prefix — in the name *or* in the org. `ParseError`
    ///   has no org-specific variant, so the existing one is reused.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s.is_empty() {
            return Err(ParseError::Empty);
        }

        // Scheme split. Only the *first* colon counts — any later colon
        // lands in org/name and is rejected below.
        let (scheme, rest) = match s.split_once(':') {
            Some(("", _)) => return Err(ParseError::EmptyScheme(s.to_string())),
            Some((scheme, _)) if scheme.contains('/') => {
                return Err(ParseError::SchemeContainsSlash(s.to_string()));
            }
            Some((scheme, rest)) => (scheme, rest),
            None => ("", s),
        };

        let (org, name) = rest
            .split_once('/')
            .ok_or_else(|| ParseError::MissingSlash(s.to_string()))?;
        if org.is_empty() {
            return Err(ParseError::EmptyOrg(s.to_string()));
        }
        if name.is_empty() {
            return Err(ParseError::EmptyName(s.to_string()));
        }
        // `:` is reserved for the scheme prefix. Checking only `name`
        // used to let ids like `a:b:c/d` through with org `b:c`.
        if org.contains(':') || name.contains(':') {
            return Err(ParseError::NameContainsColon(s.to_string()));
        }

        Ok(Self {
            scheme: scheme.to_string(),
            org: org.to_string(),
            name: name.to_string(),
        })
    }
}
// Parser, `Display`, and serde tests for `ModelSourceId`: accepted
// shapes first, then one test per rejection path.
#[cfg(test)]
mod tests {
use super::*;
// Fully-qualified id: all three parts populated.
#[test]
fn parses_qualified() {
let id: ModelSourceId = "huggingface:Qwen/Qwen3.6-27B".parse().unwrap();
assert_eq!(id.scheme, "huggingface");
assert_eq!(id.org, "Qwen");
assert_eq!(id.name, "Qwen3.6-27B");
assert_eq!(id.repo_path(), "Qwen/Qwen3.6-27B");
assert!(!id.is_scheme_unset());
}
#[test]
fn parses_helexa_scheme() {
let id: ModelSourceId = "helexa:SomeOperator/Qwen3.6-27B-Uncensored"
.parse()
.unwrap();
assert_eq!(id.scheme, "helexa");
assert_eq!(id.org, "SomeOperator");
assert_eq!(id.name, "Qwen3.6-27B-Uncensored");
}
// No scheme prefix: the scheme comes back as the empty string.
#[test]
fn parses_bare_id_with_empty_scheme() {
let id: ModelSourceId = "Qwen/Qwen3-30B-A3B-Instruct".parse().unwrap();
assert_eq!(id.scheme, "");
assert_eq!(id.org, "Qwen");
assert_eq!(id.name, "Qwen3-30B-A3B-Instruct");
assert!(id.is_scheme_unset());
}
#[test]
fn substitutes_default_scheme_only_when_unset() {
let id: ModelSourceId = "Qwen/Q3".parse().unwrap();
assert_eq!(id.with_default_scheme("huggingface").scheme, "huggingface");
let id: ModelSourceId = "helexa:Qwen/Q3".parse().unwrap();
assert_eq!(
id.with_default_scheme("huggingface").scheme,
"helexa",
"default substitution must not override an explicit scheme"
);
}
// parse → to_string must reproduce the input for both id shapes.
#[test]
fn display_roundtrips_qualified_id() {
let s = "helexa:Helexa/Qwen3.6-27B";
let id: ModelSourceId = s.parse().unwrap();
assert_eq!(id.to_string(), s);
}
#[test]
fn display_roundtrips_bare_id() {
let s = "Qwen/Q3";
let id: ModelSourceId = s.parse().unwrap();
assert_eq!(id.to_string(), s);
}
#[test]
fn rejects_empty() {
assert_eq!("".parse::<ModelSourceId>().unwrap_err(), ParseError::Empty);
}
// The error payload is the complete original input, with or without a scheme.
#[test]
fn rejects_missing_slash() {
match "Qwen".parse::<ModelSourceId>().unwrap_err() {
ParseError::MissingSlash(s) => assert_eq!(s, "Qwen"),
other => panic!("expected MissingSlash, got {other:?}"),
}
match "huggingface:Qwen".parse::<ModelSourceId>().unwrap_err() {
ParseError::MissingSlash(s) => assert_eq!(s, "huggingface:Qwen"),
other => panic!("expected MissingSlash, got {other:?}"),
}
}
#[test]
fn rejects_empty_scheme() {
match ":Qwen/Q3".parse::<ModelSourceId>().unwrap_err() {
ParseError::EmptyScheme(s) => assert_eq!(s, ":Qwen/Q3"),
other => panic!("expected EmptyScheme, got {other:?}"),
}
}
#[test]
fn rejects_scheme_with_slash() {
match "hugg/ingface:Q/N".parse::<ModelSourceId>().unwrap_err() {
ParseError::SchemeContainsSlash(s) => assert_eq!(s, "hugg/ingface:Q/N"),
other => panic!("expected SchemeContainsSlash, got {other:?}"),
}
}
#[test]
fn rejects_empty_org_or_name() {
match "huggingface:/N".parse::<ModelSourceId>().unwrap_err() {
ParseError::EmptyOrg(_) => {}
other => panic!("expected EmptyOrg, got {other:?}"),
}
match "huggingface:Q/".parse::<ModelSourceId>().unwrap_err() {
ParseError::EmptyName(_) => {}
other => panic!("expected EmptyName, got {other:?}"),
}
}
// NOTE(review): only a colon in the *name* is exercised here; a colon
// in the org part (e.g. `a:b:c/d`) has no test.
#[test]
fn rejects_name_with_colon() {
match "huggingface:Q/N:weird"
.parse::<ModelSourceId>()
.unwrap_err()
{
ParseError::NameContainsColon(s) => assert_eq!(s, "huggingface:Q/N:weird"),
other => panic!("expected NameContainsColon, got {other:?}"),
}
}
#[test]
fn serde_roundtrips_via_struct() {
// We serialize as a struct (scheme/org/name fields) so the
// shape is self-describing in API payloads. Callers that want
// the compact `scheme:org/name` string use `Display`/`FromStr`.
let id = ModelSourceId::new("helexa", "Helexa", "Qwen3.6-27B");
let json = serde_json::to_string(&id).unwrap();
let back: ModelSourceId = serde_json::from_str(&json).unwrap();
assert_eq!(back, id);
}
}

View File

@@ -75,7 +75,11 @@ pub fn openai_to_anthropic(resp: ChatCompletionResponse) -> MessagesResponse {
MessageContent::Text(t) => t,
MessageContent::Parts(parts) => serde_json::to_string(&parts).unwrap_or_default(),
};
let stop = c.finish_reason.map(|r| map_stop_reason(&r));
let stop = c.finish_reason.map(|r| match r.as_str() {
"stop" => "end_turn".to_string(),
"length" => "max_tokens".to_string(),
other => other.to_string(),
});
(text, stop)
}
None => (String::new(), None),
@@ -104,374 +108,3 @@ pub fn openai_to_anthropic(resp: ChatCompletionResponse) -> MessagesResponse {
extra: Value::Null,
}
}
// ── Streaming SSE translation (#24) ──────────────────────────────────
/// Translate an OpenAI `finish_reason` into the matching Anthropic
/// `stop_reason`.
///
/// `stop`, `length` and `tool_calls` have dedicated mappings; any other
/// value is returned unchanged.
pub fn map_stop_reason(openai: &str) -> String {
    let anthropic = match openai {
        "stop" => "end_turn",
        "length" => "max_tokens",
        "tool_calls" => "tool_use",
        // No known equivalent: pass the upstream value through as-is.
        unmapped => unmapped,
    };
    anthropic.to_owned()
}
/// Stateful OpenAI-SSE → Anthropic-SSE event translator.
///
/// Feed each parsed OpenAI [`crate::openai::ChatCompletionChunk`] to
/// [`on_chunk`](Self::on_chunk) and call [`finish`](Self::finish) on
/// `[DONE]` (or upstream EOF); both return ordered
/// `(event_name, payload)` pairs ready to be framed as
/// `event: <name>\ndata: <payload>\n\n`. The translation is stateless
/// across requests — one instance per stream — and never buffers
/// content: every text delta maps to a `content_block_delta`
/// immediately.
///
/// Event sequence produced (per Anthropic's streaming spec):
/// `message_start` → `content_block_start` / `content_block_delta`* /
/// `content_block_stop` (text and `tool_use` blocks, indexed) →
/// `message_delta` (stop_reason + output usage) → `message_stop`.
#[derive(Debug, Default)]
pub struct AnthropicStreamTranslator {
/// Set by the first `on_chunk` call, which is also the one that
/// emits `message_start`.
started: bool,
/// Set by `finish`; later `finish` calls return no events.
finished: bool,
/// Index of the currently-open content block, with its kind.
open_block: Option<(u32, OpenBlock)>,
/// Index the next opened content block will get; bumped each time a
/// block is opened.
next_index: u32,
/// Anthropic-mapped stop reason from the most recent choice that
/// carried a `finish_reason`.
stop_reason: Option<String>,
/// Most recent upstream usage frame, if one has been seen.
usage: Option<Usage>,
/// Visible text deltas counted as an output-token estimate for
/// streams whose upstream never sends a usage frame (neuron emits
/// one chunk per token, so this is exact there).
text_deltas: u64,
}
/// Kind of the content block the translator currently has open.
#[derive(Debug, PartialEq, Eq)]
enum OpenBlock {
/// A `text` block receiving `text_delta` events.
Text,
/// A `tool_use` block receiving `input_json_delta` events.
ToolUse,
}
impl AnthropicStreamTranslator {
/// Creates a translator in its initial state (same as `Self::default()`).
pub fn new() -> Self {
Self::default()
}
/// Translate one upstream chunk into zero or more Anthropic events.
///
/// The first call also emits `message_start`, using that chunk's `id`
/// and `model`. A usage frame on the chunk replaces any stored one.
/// For each choice, in order:
///
/// - a non-empty string `delta.content` becomes a text
///   `content_block_delta`, opening a text block first if none is open;
/// - each `delta.tool_calls` entry with a function name closes the open
///   block and starts a new `tool_use` block; non-empty `arguments` are
///   emitted as `input_json_delta` while a `tool_use` block is open;
/// - a `finish_reason` is mapped through `map_stop_reason` and kept
///   for `finish`.
///
/// Returns the `(event_name, payload)` pairs in emission order.
pub fn on_chunk(&mut self, chunk: &crate::openai::ChatCompletionChunk) -> Vec<(String, Value)> {
let mut out = Vec::new();
if !self.started {
self.started = true;
out.push((
"message_start".to_string(),
json!({
"type": "message_start",
"message": {
// Upstream ids are opaque to Anthropic clients;
// prefix for shape-compatibility with msg_* ids.
"id": format!("msg_{}", chunk.id),
"type": "message",
"role": "assistant",
"content": [],
"model": chunk.model,
"stop_reason": null,
"stop_sequence": null,
// Input tokens are unknown until (if ever) a
// usage frame arrives; corrected in
// message_delta. Anthropic clients sum deltas.
"usage": { "input_tokens": 0, "output_tokens": 0 }
}
}),
));
}
// Keep only the latest usage frame; it is reported from `finish`.
if let Some(usage) = &chunk.usage {
self.usage = Some(usage.clone());
}
for choice in &chunk.choices {
// Empty content strings produce no event and open no block.
if let Some(text) = choice.delta.get("content").and_then(Value::as_str)
&& !text.is_empty()
{
self.ensure_text_block(&mut out);
self.text_deltas += 1;
let index = self.open_block.as_ref().map(|(i, _)| *i).unwrap_or(0);
out.push((
"content_block_delta".to_string(),
json!({
"type": "content_block_delta",
"index": index,
"delta": { "type": "text_delta", "text": text }
}),
));
}
if let Some(calls) = choice.delta.get("tool_calls").and_then(Value::as_array) {
for call in calls {
let name = call
.get("function")
.and_then(|f| f.get("name"))
.and_then(Value::as_str);
let arguments = call
.get("function")
.and_then(|f| f.get("arguments"))
.and_then(Value::as_str)
.unwrap_or_default();
if let Some(name) = name {
// A named entry begins a new tool_use block.
self.close_open_block(&mut out);
// Fall back to a placeholder id when upstream omits one.
let id = call
.get("id")
.and_then(Value::as_str)
.unwrap_or("toolu_unknown");
let index = self.next_index;
self.next_index += 1;
self.open_block = Some((index, OpenBlock::ToolUse));
out.push((
"content_block_start".to_string(),
json!({
"type": "content_block_start",
"index": index,
"content_block": {
"type": "tool_use",
"id": id,
"name": name,
"input": {}
}
}),
));
}
// Argument fragments go to whichever tool_use block is open;
// they are dropped when the open block is not a tool_use block.
if !arguments.is_empty()
&& let Some((index, OpenBlock::ToolUse)) = &self.open_block
{
out.push((
"content_block_delta".to_string(),
json!({
"type": "content_block_delta",
"index": index,
"delta": {
"type": "input_json_delta",
"partial_json": arguments
}
}),
));
}
}
}
if let Some(reason) = &choice.finish_reason {
self.stop_reason = Some(map_stop_reason(reason));
}
}
out
}
/// Close the stream: emits the trailing block-stop, message_delta
/// (stop_reason + output usage) and message_stop. Idempotent.
///
/// Returns no events when called again, or when no chunk was ever
/// seen (so `message_start` was never emitted).
pub fn finish(&mut self) -> Vec<(String, Value)> {
let mut out = Vec::new();
if self.finished || !self.started {
self.finished = true;
return out;
}
self.finished = true;
self.close_open_block(&mut out);
// Prefer the upstream completion count; otherwise fall back to the
// number of text deltas seen.
let output_tokens = self
.usage
.as_ref()
.map(|u| u.completion_tokens)
.unwrap_or(self.text_deltas);
let mut usage = json!({ "output_tokens": output_tokens });
// `input_tokens` is reported only when a usage frame was seen.
if let Some(u) = &self.usage {
usage["input_tokens"] = json!(u.prompt_tokens);
}
out.push((
"message_delta".to_string(),
json!({
"type": "message_delta",
"delta": {
// No finish_reason seen upstream → report "end_turn".
"stop_reason": self.stop_reason.as_deref().unwrap_or("end_turn"),
"stop_sequence": null
},
"usage": usage
}),
));
out.push((
"message_stop".to_string(),
json!({ "type": "message_stop" }),
));
out
}
/// Ensure a text block is open: a no-op when one already is; otherwise
/// closes the open block (if any) and emits `content_block_start` for a
/// new text block at the next index.
fn ensure_text_block(&mut self, out: &mut Vec<(String, Value)>) {
match &self.open_block {
Some((_, OpenBlock::Text)) => {}
_ => {
self.close_open_block(out);
let index = self.next_index;
self.next_index += 1;
self.open_block = Some((index, OpenBlock::Text));
out.push((
"content_block_start".to_string(),
json!({
"type": "content_block_start",
"index": index,
"content_block": { "type": "text", "text": "" }
}),
));
}
}
}
/// Emit `content_block_stop` for the open block, if any, and clear it.
fn close_open_block(&mut self, out: &mut Vec<(String, Value)>) {
if let Some((index, _)) = self.open_block.take() {
out.push((
"content_block_stop".to_string(),
json!({ "type": "content_block_stop", "index": index }),
));
}
}
}
// Unit tests for `AnthropicStreamTranslator`: each test feeds hand-built
// OpenAI chunks through `on_chunk` / `finish` and checks the emitted
// `(event name, payload)` pairs.
#[cfg(test)]
mod stream_tests {
use super::*;
use crate::openai::{ChatCompletionChunk, ChunkChoice};
/// Build a chunk with exactly one choice carrying `delta` and the given
/// optional finish reason. Id, model and timestamp are fixed test
/// values and no usage is attached.
fn chunk(delta: Value, finish: Option<&str>) -> ChatCompletionChunk {
ChatCompletionChunk {
id: "abc123".into(),
object: "chat.completion.chunk".into(),
created: 1,
model: "Qwen/Qwen3-8B".into(),
choices: vec![ChunkChoice {
index: 0,
delta,
finish_reason: finish.map(String::from),
extra: Value::Null,
}],
usage: None,
extra: Value::Null,
}
}
/// Project a list of events down to just their event names, in order.
fn names(events: &[(String, Value)]) -> Vec<&str> {
events.iter().map(|(n, _)| n.as_str()).collect()
}
// A plain text stream must yield the complete Anthropic event sequence
// with the text split across two deltas.
#[test]
fn text_stream_produces_full_anthropic_sequence() {
let mut t = AnthropicStreamTranslator::new();
let mut all = Vec::new();
all.extend(t.on_chunk(&chunk(json!({"role": "assistant"}), None)));
all.extend(t.on_chunk(&chunk(json!({"content": "Hel"}), None)));
all.extend(t.on_chunk(&chunk(json!({"content": "lo"}), None)));
all.extend(t.on_chunk(&chunk(json!({}), Some("stop"))));
all.extend(t.finish());
assert_eq!(
names(&all),
vec![
"message_start",
"content_block_start",
"content_block_delta",
"content_block_delta",
"content_block_stop",
"message_delta",
"message_stop",
]
);
// message_start carries role/model; deltas carry the text.
assert_eq!(all[0].1["message"]["model"], "Qwen/Qwen3-8B");
assert_eq!(all[2].1["delta"]["text"], "Hel");
assert_eq!(all[3].1["delta"]["text"], "lo");
// stop → end_turn; without a usage frame the output count
// falls back to the delta count (engine-exact for neuron's
// one-chunk-per-token streams).
let md = &all[5].1;
assert_eq!(md["delta"]["stop_reason"], "end_turn");
assert_eq!(md["usage"]["output_tokens"], 2);
}
// OpenAI "length" maps to Anthropic "max_tokens"; a stream that never
// reports a finish reason defaults to "end_turn". In both cases index 1
// of `finish()` is the `message_delta` (index 0 is the
// `content_block_stop` for the still-open text block).
#[test]
fn length_maps_to_max_tokens_and_missing_finish_defaults_to_end_turn() {
let mut t = AnthropicStreamTranslator::new();
t.on_chunk(&chunk(json!({"content": "x"}), Some("length")));
let fin = t.finish();
assert_eq!(fin[1].1["delta"]["stop_reason"], "max_tokens");
let mut t2 = AnthropicStreamTranslator::new();
t2.on_chunk(&chunk(json!({"content": "x"}), None));
let fin2 = t2.finish();
assert_eq!(fin2[1].1["delta"]["stop_reason"], "end_turn");
}
// A tool call arriving after text closes the text block and opens a
// `tool_use` block at the next index, streaming its arguments as an
// `input_json_delta`.
#[test]
fn tool_call_becomes_tool_use_block() {
let mut t = AnthropicStreamTranslator::new();
let mut all = Vec::new();
all.extend(t.on_chunk(&chunk(json!({"content": "Let me check."}), None)));
all.extend(t.on_chunk(&chunk(
json!({"tool_calls": [{
"index": 0,
"id": "call_7",
"function": {"name": "get_weather", "arguments": "{\"city\":\"Brno\"}"}
}]}),
None,
)));
all.extend(t.on_chunk(&chunk(json!({}), Some("tool_calls"))));
all.extend(t.finish());
assert_eq!(
names(&all),
vec![
"message_start",
"content_block_start", // text
"content_block_delta", // text delta
"content_block_stop", // text closed by tool block
"content_block_start", // tool_use
"content_block_delta", // input_json_delta
"content_block_stop",
"message_delta",
"message_stop",
]
);
let tool_start = &all[4].1;
assert_eq!(tool_start["content_block"]["type"], "tool_use");
assert_eq!(tool_start["content_block"]["id"], "call_7");
assert_eq!(tool_start["content_block"]["name"], "get_weather");
assert_eq!(tool_start["index"], 1);
assert_eq!(all[5].1["delta"]["partial_json"], "{\"city\":\"Brno\"}");
assert_eq!(all[7].1["delta"]["stop_reason"], "tool_use");
}
// A trailing usage-only chunk (no choices) must override the delta-count
// fallback and also surface the prompt token count.
#[test]
fn usage_frame_feeds_message_delta() {
let mut t = AnthropicStreamTranslator::new();
t.on_chunk(&chunk(json!({"content": "hi"}), Some("stop")));
let mut usage_chunk = chunk(json!({}), None);
usage_chunk.choices.clear();
usage_chunk.usage = Some(crate::openai::Usage {
prompt_tokens: 225,
completion_tokens: 42,
total_tokens: 267,
});
t.on_chunk(&usage_chunk);
let fin = t.finish();
let md = &fin[1].1;
assert_eq!(md["usage"]["output_tokens"], 42);
assert_eq!(md["usage"]["input_tokens"], 225);
}
// `finish()` emits nothing for a stream that never started, and nothing
// on any call after the first successful one.
#[test]
fn finish_is_idempotent_and_silent_without_start() {
let mut t = AnthropicStreamTranslator::new();
assert!(t.finish().is_empty(), "no events for an empty stream");
assert!(t.finish().is_empty());
let mut t2 = AnthropicStreamTranslator::new();
t2.on_chunk(&chunk(json!({"content": "x"}), None));
assert!(!t2.finish().is_empty());
assert!(t2.finish().is_empty(), "second finish must emit nothing");
}
}

View File

@@ -24,7 +24,6 @@ tokio-stream.workspace = true
eventsource-stream.workspace = true
bytes = "1"
urlencoding = "2"
url = "2"
[dev-dependencies]
tokio = { workspace = true, features = ["test-util"] }

View File

@@ -1,178 +0,0 @@
//! Streaming Anthropic SSE translation (#24).
//!
//! The `/v1/messages` handler translates the request envelope to
//! OpenAI before proxying (see `cortex_core::translate`); this module
//! completes the round trip for `stream: true` — the upstream OpenAI
//! SSE stream is re-framed, event by event, into Anthropic's
//! `message_start` / `content_block_*` / `message_delta` /
//! `message_stop` sequence as it arrives. True streaming: each
//! upstream chunk is translated and forwarded immediately; nothing is
//! buffered beyond the current SSE event's bytes.
//!
//! The translation state machine itself is pure and lives in
//! [`cortex_core::translate::AnthropicStreamTranslator`]; this module
//! owns the wire concerns — splitting the upstream byte stream into
//! SSE events, parsing `data:` payloads, and framing the translated
//! events as `event: <name>\ndata: <json>\n\n`.
use axum::body::Body;
use axum::http::StatusCode;
use axum::response::Response;
use bytes::Bytes;
use cortex_core::openai::ChatCompletionChunk;
use cortex_core::translate::AnthropicStreamTranslator;
use futures::StreamExt;
use tokio_stream::wrappers::ReceiverStream;
/// Forward the translated OpenAI request to the upstream node and
/// return the response translated to Anthropic SSE framing.
///
/// * `client` — HTTP client used for the upstream POST.
/// * `endpoint` — base URL of the node; `/v1/chat/completions` is appended.
/// * `openai_body` — already-translated OpenAI request body, sent as JSON.
/// * `model_id`, `node_name` — used for log fields only.
///
/// Returns an Anthropic-shaped error response when the upstream request
/// fails (502) or answers non-2xx (upstream status, generic message).
/// Otherwise returns a 200 `text/event-stream` response whose body is fed
/// by a spawned task that translates upstream events as they arrive.
pub async fn stream_translated(
client: &reqwest::Client,
endpoint: &str,
openai_body: axum::body::Bytes,
model_id: &str,
node_name: &str,
) -> Response {
let url = format!("{endpoint}/v1/chat/completions");
tracing::info!(
handler = "anthropic_messages",
model = %model_id,
node = %node_name,
url = %url,
"proxying streaming request (anthropic SSE translation)"
);
let upstream = match client
.post(&url)
.header("content-type", "application/json")
.body(openai_body)
.send()
.await
{
Ok(r) => r,
Err(e) => {
// The error detail goes to the log only; the client gets a
// generic message.
tracing::warn!(
handler = "anthropic_messages",
node = %node_name,
url = %url,
error = %e,
"anthropic stream: upstream request failed"
);
return anthropic_error(StatusCode::BAD_GATEWAY, "upstream request failed");
}
};
let status = upstream.status();
if !status.is_success() {
tracing::warn!(
handler = "anthropic_messages",
node = %node_name,
url = %url,
status = status.as_u16(),
"anthropic stream: upstream returned non-2xx"
);
// Mirror the upstream status code; fall back to 502 if it cannot be
// represented as a `StatusCode`.
return anthropic_error(
StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::BAD_GATEWAY),
"upstream returned an error",
);
}
// Bounded channel: a slow client back-pressures the pump task,
// which back-pressures the upstream read — same propagation
// discipline as neuron's own projectors.
let (tx, rx) = tokio::sync::mpsc::channel::<Result<Bytes, std::convert::Infallible>>(32);
let node = node_name.to_string();
// Pump task: reads upstream bytes, splits them into SSE events,
// translates each `data:` payload and pushes the resulting frames into
// the channel backing the response body.
tokio::spawn(async move {
let mut upstream = upstream.bytes_stream();
let mut translator = AnthropicStreamTranslator::new();
// Bytes received so far that do not yet form a complete SSE event.
let mut buf: Vec<u8> = Vec::new();
// Set once the upstream `[DONE]` sentinel has been seen.
let mut done = false;
'outer: while let Some(block) = upstream.next().await {
let block = match block {
Ok(b) => b,
Err(e) => {
tracing::warn!(node = %node, error = %e, "anthropic stream: upstream read failed mid-stream");
break;
}
};
buf.extend_from_slice(&block);
// SSE events are separated by a blank line.
while let Some(pos) = find_event_boundary(&buf) {
// Remove the event plus its two-byte terminator from the buffer.
let event: Vec<u8> = buf.drain(..pos + 2).collect();
let text = String::from_utf8_lossy(&event);
for line in text.lines() {
// Only `data:` lines are translated; every other line of the
// event is ignored.
let Some(data) = line.strip_prefix("data:") else {
continue;
};
let data = data.trim();
if data == "[DONE]" {
done = true;
// `send_frames` returns false once the client has gone away;
// stop pumping in that case.
if !send_frames(&tx, translator.finish()).await {
break 'outer;
}
continue;
}
let Ok(chunk) = serde_json::from_str::<ChatCompletionChunk>(data) else {
tracing::debug!(node = %node, "anthropic stream: unparsable upstream frame skipped");
continue;
};
if !send_frames(&tx, translator.on_chunk(&chunk)).await {
break 'outer;
}
}
}
}
// Upstream ended without [DONE] (error or truncation): still
// close the Anthropic event sequence so clients aren't left
// with an unterminated message.
if !done {
let _ = send_frames(&tx, translator.finish()).await;
}
});
Response::builder()
.status(StatusCode::OK)
.header("content-type", "text/event-stream")
.header("cache-control", "no-cache")
.body(Body::from_stream(ReceiverStream::new(rx)))
.unwrap_or_else(|_| {
anthropic_error(
StatusCode::INTERNAL_SERVER_ERROR,
"failed to build response",
)
})
}
/// Locate the blank line that terminates the first complete SSE event in
/// `buf`, if any.
///
/// The returned index `pos` is chosen so that `pos + 2` is one past the last
/// byte of the terminator; callers drain `..pos + 2` to remove the event.
/// Both LF (`\n\n`) and CRLF (`\r\n\r\n`, or a mix such as `\n\r\n`) framing
/// are recognised, since the SSE format permits either line ending. Bare-CR
/// line endings are not handled.
///
/// Returns `None` when `buf` holds no complete event yet.
fn find_event_boundary(buf: &[u8]) -> Option<usize> {
    buf.iter().enumerate().find_map(|(i, &byte)| {
        if byte != b'\n' {
            return None;
        }
        // A line feed ends one line; the event ends if the next line is empty.
        match buf.get(i + 1..) {
            // "\n\n": the terminator occupies i..i + 2.
            Some([b'\n', ..]) => Some(i),
            // "\n\r\n": the terminator ends at i + 3, i.e. (i + 1) + 2.
            Some([b'\r', b'\n', ..]) => Some(i + 1),
            _ => None,
        }
    })
}
/// Serialise each translated event as an SSE frame
/// (`event: <name>\ndata: <json>\n\n`) and push it into `tx`.
///
/// Stops at the first failed send and returns `false`, which means the
/// receiving side of the channel has been dropped (the client went away).
/// Returns `true` when every frame was handed over.
async fn send_frames(
    tx: &tokio::sync::mpsc::Sender<Result<Bytes, std::convert::Infallible>>,
    events: Vec<(String, serde_json::Value)>,
) -> bool {
    for (name, payload) in events {
        let frame = Bytes::from(format!("event: {name}\ndata: {payload}\n\n"));
        let delivered = tx.send(Ok(frame)).await.is_ok();
        if !delivered {
            return false;
        }
    }
    true
}
/// Build a JSON error response in Anthropic's error envelope:
/// `{"type":"error","error":{"type":"api_error","message":<message>}}`,
/// sent with the given HTTP `status`.
///
/// # Panics
///
/// Panics if the response builder rejects the status/header/body
/// combination.
fn anthropic_error(status: StatusCode, message: &str) -> Response {
    let payload = serde_json::json!({
        "type": "error",
        "error": { "type": "api_error", "message": message }
    })
    .to_string();

    let built = Response::builder()
        .status(status)
        .header("content-type", "application/json")
        .body(Body::from(payload));
    built.expect("static error response must build")
}

View File

@@ -20,7 +20,6 @@ pub fn api_routes() -> Router<Arc<CortexState>> {
Router::new()
.route("/v1/chat/completions", post(chat_completions))
.route("/v1/completions", post(completions))
.route("/v1/responses", post(responses))
.route("/v1/models", get(list_models))
.route("/v1/messages", post(anthropic_messages))
.route("/health", get(health))
@@ -35,94 +34,23 @@ async fn chat_completions(
) -> Response {
let model_id = match extract_model(&body) {
Some(m) => m,
None => {
tracing::warn!(
handler = "chat_completions",
"rejected: missing 'model' field in request body"
);
return error_response(400, "missing 'model' field in request body");
}
None => return error_response(400, "missing 'model' field in request body"),
};
let route = match router::resolve(&fleet, &model_id).await {
Ok(r) => r,
Err(e) => {
tracing::warn!(
handler = "chat_completions",
model = %model_id,
error = %e,
"route resolve failed"
);
// RouteError's Display strings are short and informative
// ("model 'X' not found...", "no healthy nodes available")
// — fine to surface to the caller. The warn above carries
// any extra context for operators.
return error_response(e.http_status(), &e.to_string());
}
Err(e) => return error_response(404, &e.to_string()),
};
touch_model(&fleet, &route.node_name, &route.resolved_model_id).await;
touch_model(&fleet, &route.node_name, &model_id).await;
let body = rewrite_model_in_body(body, &route.resolved_model_id);
proxy_with_metrics(
&fleet,
&route,
"/v1/chat/completions",
headers,
body,
&route.resolved_model_id,
)
.await
}
/// `POST /v1/responses` — proxy to the appropriate backend node.
///
/// Same routing shape as [`chat_completions`]: extract `model` from
/// the body, resolve to a node, forward verbatim. No translation —
/// neuron speaks the Responses API natively (see
/// `crates/neuron/src/wire/openai_responses.rs`), so the gateway is
/// a pass-through. Streaming and non-streaming are handled
/// identically; the upstream `Content-Type` (text/event-stream vs.
/// application/json) propagates through the proxy.
async fn responses(
State(fleet): State<Arc<CortexState>>,
headers: HeaderMap,
body: Bytes,
) -> Response {
let model_id = match extract_model(&body) {
Some(m) => m,
None => {
tracing::warn!(
handler = "responses",
"rejected: missing 'model' field in request body"
);
return error_response(400, "missing 'model' field in request body");
}
};
let route = match router::resolve(&fleet, &model_id).await {
Ok(r) => r,
Err(e) => {
tracing::warn!(
handler = "responses",
model = %model_id,
error = %e,
"route resolve failed"
);
return error_response(e.http_status(), &e.to_string());
}
};
touch_model(&fleet, &route.node_name, &route.resolved_model_id).await;
let body = rewrite_model_in_body(body, &route.resolved_model_id);
proxy_with_metrics(
&fleet,
&route,
"/v1/responses",
headers,
body,
&route.resolved_model_id,
&model_id,
)
.await
}
@@ -135,63 +63,29 @@ async fn completions(
) -> Response {
let model_id = match extract_model(&body) {
Some(m) => m,
None => {
tracing::warn!(
handler = "completions",
"rejected: missing 'model' field in request body"
);
return error_response(400, "missing 'model' field in request body");
}
None => return error_response(400, "missing 'model' field in request body"),
};
let route = match router::resolve(&fleet, &model_id).await {
Ok(r) => r,
Err(e) => {
tracing::warn!(
handler = "completions",
model = %model_id,
error = %e,
"route resolve failed"
);
// RouteError's Display strings are short and informative
// ("model 'X' not found...", "no healthy nodes available")
// — fine to surface to the caller. The warn above carries
// any extra context for operators.
return error_response(e.http_status(), &e.to_string());
}
Err(e) => return error_response(404, &e.to_string()),
};
touch_model(&fleet, &route.node_name, &route.resolved_model_id).await;
touch_model(&fleet, &route.node_name, &model_id).await;
let body = rewrite_model_in_body(body, &route.resolved_model_id);
proxy_with_metrics(
&fleet,
&route,
"/v1/completions",
headers,
body,
&route.resolved_model_id,
)
.await
proxy_with_metrics(&fleet, &route, "/v1/completions", headers, body, &model_id).await
}
/// `POST /v1/messages` — accept Anthropic format, translate, proxy, translate back.
async fn anthropic_messages(
State(fleet): State<Arc<CortexState>>,
_headers: HeaderMap,
headers: HeaderMap,
body: Bytes,
) -> Response {
// Parse as Anthropic request.
let anth_req: cortex_core::anthropic::MessagesRequest = match serde_json::from_slice(&body) {
Ok(r) => r,
Err(e) => {
tracing::warn!(
handler = "anthropic_messages",
error = %e,
"rejected: invalid Anthropic request body"
);
return error_response(400, "invalid Anthropic request body");
}
Err(e) => return error_response(400, &format!("invalid Anthropic request: {e}")),
};
let model_id = anth_req.model.clone();
@@ -201,43 +95,18 @@ async fn anthropic_messages(
let openai_req = cortex_core::translate::anthropic_to_openai(anth_req);
let openai_body = match serde_json::to_vec(&openai_req) {
Ok(b) => Bytes::from(b),
Err(e) => {
tracing::error!(
handler = "anthropic_messages",
model = %model_id,
error = %e,
"internal: failed to serialise translated OpenAI request"
);
return error_response(500, "internal translation error");
}
Err(e) => return error_response(500, &format!("translation error: {e}")),
};
let route = match router::resolve(&fleet, &model_id).await {
Ok(r) => r,
Err(e) => {
tracing::warn!(
handler = "anthropic_messages",
model = %model_id,
error = %e,
"route resolve failed"
);
// RouteError's Display strings are short and informative
// ("model 'X' not found...", "no healthy nodes available")
// — fine to surface to the caller. The warn above carries
// any extra context for operators.
return error_response(e.http_status(), &e.to_string());
}
Err(e) => return error_response(404, &e.to_string()),
};
touch_model(&fleet, &route.node_name, &route.resolved_model_id).await;
// Swap the alias for the concrete id in the translated body so
// neuron's harness sees a model name that matches what it has
// loaded.
let openai_body = rewrite_model_in_body(openai_body, &route.resolved_model_id);
touch_model(&fleet, &route.node_name, &model_id).await;
let labels = [
("model", route.resolved_model_id.clone()),
("model", model_id.clone()),
("node", route.node_name.clone()),
];
metrics::counter!("cortex_requests_total", &labels).increment(1);
@@ -247,37 +116,31 @@ async fn anthropic_messages(
let start = Instant::now();
if is_streaming {
// Anthropic SSE translation (#24): upstream speaks OpenAI SSE;
// re-frame it event-by-event into Anthropic's message_start /
// content_block_* / message_delta / message_stop sequence.
let resp = crate::anthropic_sse::stream_translated(
// TODO: streaming Anthropic translation requires converting SSE format.
// For now, proxy the OpenAI SSE stream directly (clients that can handle
// OpenAI SSE will work; full Anthropic SSE translation is a follow-up).
let result = proxy::forward_request(
&fleet.http_client,
&route.endpoint,
&route,
"/v1/chat/completions",
headers,
openai_body,
&model_id,
&route.node_name,
)
.await;
metrics::histogram!("cortex_request_duration_seconds", &labels)
.record(start.elapsed().as_secs_f64());
if !resp.status().is_success() {
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
match result {
Ok(resp) => resp,
Err(e) => {
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
e.into_response()
}
}
resp
} else {
// Non-streaming: proxy, buffer full response, translate back to Anthropic.
let target_url = format!("{}/v1/chat/completions", route.endpoint);
tracing::info!(
handler = "anthropic_messages",
model = %model_id,
node = %route.node_name,
url = %target_url,
cold_start = route.cold_start,
"proxying request"
);
let upstream_resp = fleet
.http_client
.post(&target_url)
.post(format!("{}/v1/chat/completions", route.endpoint))
.body(openai_body)
.header("content-type", "application/json")
.send()
@@ -287,49 +150,22 @@ async fn anthropic_messages(
Ok(r) => r,
Err(e) => {
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
tracing::warn!(
handler = "anthropic_messages",
model = %model_id,
node = %route.node_name,
url = %target_url,
error = %e,
"upstream request failed (network)"
);
return error_response(502, "upstream request failed");
return error_response(502, &format!("upstream request failed: {e}"));
}
};
let upstream_status = upstream_resp.status();
if !upstream_status.is_success() {
if !upstream_resp.status().is_success() {
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
let status = upstream_status.as_u16();
let status = upstream_resp.status().as_u16();
let body = upstream_resp.text().await.unwrap_or_default();
let body_snippet = body.chars().take(512).collect::<String>();
tracing::warn!(
handler = "anthropic_messages",
model = %model_id,
node = %route.node_name,
url = %target_url,
status,
body = %body_snippet,
"upstream returned non-2xx"
);
return error_response(status, &format!("upstream returned {status}"));
return error_response(status, &format!("upstream error: {body}"));
}
let body_bytes = match upstream_resp.bytes().await {
Ok(b) => b,
Err(e) => {
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
tracing::warn!(
handler = "anthropic_messages",
model = %model_id,
node = %route.node_name,
url = %target_url,
error = %e,
"failed to read upstream response body"
);
return error_response(502, "failed to read upstream response");
return error_response(502, &format!("failed to read upstream response: {e}"));
}
};
@@ -338,20 +174,7 @@ async fn anthropic_messages(
Ok(r) => r,
Err(e) => {
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
let body_snippet = String::from_utf8_lossy(&body_bytes)
.chars()
.take(512)
.collect::<String>();
tracing::warn!(
handler = "anthropic_messages",
model = %model_id,
node = %route.node_name,
url = %target_url,
error = %e,
body = %body_snippet,
"failed to parse upstream response as OpenAI ChatCompletionResponse"
);
return error_response(502, "malformed upstream response");
return error_response(502, &format!("failed to parse upstream response: {e}"));
}
};
@@ -362,65 +185,12 @@ async fn anthropic_messages(
}
}
/// `GET /v1/models` — union of (catalogue × topology feasibility) and
/// (currently loaded somewhere). The result is what the fleet *could*
/// serve, not just what's already loaded — so OpenAI-compatible tools
/// see every model the operator has provisioned, and cortex
/// transparently cold-loads the first time one is requested.
/// `GET /v1/models` — aggregate models from all nodes.
async fn list_models(State(fleet): State<Arc<CortexState>>) -> Json<Value> {
use std::collections::HashMap;
let now = Utc::now().timestamp() as u64;
let nodes = fleet.nodes.read().await;
let catalogue = &fleet.catalogue;
let mut model_map: std::collections::HashMap<String, CortexModelEntry> =
std::collections::HashMap::new();
let mut entries: HashMap<String, CortexModelEntry> = HashMap::new();
// Pass 1: catalogue × topology. For every catalogue profile, find
// healthy neurons whose discovered devices satisfy the profile.
// Catalogue-defined models surface here even if nothing has loaded
// them yet — that's the point of the unified endpoint.
for profile in &catalogue.models {
let mut feasible_on = Vec::new();
for node in nodes.values() {
if !node.healthy {
continue;
}
let Some(disc) = node.discovery.as_ref() else {
continue;
};
if profile.is_feasible_on(&node.name, &disc.devices) {
feasible_on.push(node.name.clone());
}
}
if feasible_on.is_empty() {
// The catalogue lists this model but no neuron's topology
// matches — surface it as not-loaded with no feasible
// location. Hides nothing; lets operators see why a
// configured model isn't reachable.
feasible_on.clear();
}
entries.insert(
profile.id.clone(),
CortexModelEntry {
id: profile.id.clone(),
object: "model".into(),
created: now,
owned_by: "helexa".into(),
loaded: false,
feasible_on,
locations: Vec::new(),
// Catalogue profiles don't declare capabilities yet;
// the union is filled in Pass 2 from loaded locations.
capabilities: Vec::new(),
},
);
}
// Pass 2: layer the actually-loaded state on top. For each
// (node, model) entry, attach a ModelLocation. If the model isn't
// in the catalogue, create a new CortexModelEntry from scratch —
// cortex doesn't refuse to surface a manually-loaded model just
// because the operator didn't enumerate it in models.toml.
for node in nodes.values() {
for (model_id, entry) in &node.models {
let location = ModelLocation {
@@ -428,121 +198,19 @@ async fn list_models(State(fleet): State<Arc<CortexState>>) -> Json<Value> {
status: entry.status,
vram_estimate_mb: entry.vram_estimate_mb,
};
let was_loaded = matches!(entry.status, cortex_core::node::ModelStatus::Loaded);
entries
model_map
.entry(model_id.clone())
.and_modify(|e| {
e.locations.push(location.clone());
if was_loaded {
e.loaded = true;
}
// Union the per-node capabilities so a model loaded
// on several neurons reports every modality any of
// them advertises.
for cap in &entry.capabilities {
if !e.capabilities.contains(cap) {
e.capabilities.push(cap.clone());
}
}
})
.and_modify(|e| e.locations.push(location.clone()))
.or_insert_with(|| CortexModelEntry {
id: model_id.clone(),
object: "model".into(),
created: now,
owned_by: "helexa".into(),
loaded: was_loaded,
// Not in catalogue — cortex has no opinion on
// feasibility; leave empty.
feasible_on: Vec::new(),
locations: vec![location],
capabilities: entry.capabilities.clone(),
});
}
}
// Pass 3: surface pre-warming models. Each neuron's `/health`
// activation snapshot (polled separately from /models) reports
// `in_progress` (the model currently materialising) and `pending`
// (queued behind it). Neither appears on the neuron's `/models`
// yet — that endpoint only knows about fully-loaded handles — so
// without this pass a client polling `/v1/models` during pre-warm
// sees Qwen3.6-27B with no location and concludes "not there".
// Synthesising a Loading location instead tells clients the model
// is on its way. Idempotent against Pass 2: if a Loading location
// for this node already exists (shouldn't, but be safe) we skip.
for node in nodes.values() {
let Some(activation) = node.activation.as_ref() else {
continue;
};
let mut loading_ids: Vec<&str> = Vec::new();
if let Some(id) = activation.in_progress.as_deref() {
loading_ids.push(id);
}
for id in &activation.pending {
loading_ids.push(id.as_str());
}
for model_id in loading_ids {
let location = ModelLocation {
node: node.name.clone(),
status: cortex_core::node::ModelStatus::Loading,
vram_estimate_mb: None,
};
entries
.entry(model_id.to_string())
.and_modify(|e| {
let already = e.locations.iter().any(|l| {
l.node == node.name && l.status == cortex_core::node::ModelStatus::Loading
});
if !already {
e.locations.push(location.clone());
}
})
.or_insert_with(|| CortexModelEntry {
id: model_id.to_string(),
object: "model".into(),
created: now,
owned_by: "helexa".into(),
loaded: false,
feasible_on: Vec::new(),
locations: vec![location],
// A model that's only mid-prewarm has no loaded
// location to read capabilities from yet.
capabilities: Vec::new(),
});
}
}
let data: Vec<Value> = model_map.values().map(|e| json!(e)).collect();
// Pass 4: surface aliases as their own entries pointing at the
// same locations as the target id, so a client browsing /v1/models
// sees "helexa/small" / "helexa/balanced" / "helexa/large" (or
// whatever the operator defined) and can request inference
// against them directly. Aliases that point at unknown targets
// are skipped — surfacing a dead alias would be misleading.
for (alias, target) in &catalogue.aliases {
let Some(target_entry) = entries.get(target).cloned() else {
tracing::warn!(
alias = alias,
target = target,
"alias points at a model not present in catalogue or fleet; skipping"
);
continue;
};
entries.insert(
alias.clone(),
CortexModelEntry {
id: alias.clone(),
object: "model".into(),
created: now,
owned_by: "helexa".into(),
loaded: target_entry.loaded,
feasible_on: target_entry.feasible_on,
locations: target_entry.locations,
capabilities: target_entry.capabilities,
},
);
}
let data: Vec<Value> = entries.values().map(|e| json!(e)).collect();
Json(json!({
"object": "list",
"data": data,
@@ -586,8 +254,7 @@ async fn proxy_with_metrics(
}
let start = Instant::now();
let result =
proxy::forward_request(&fleet.http_client, route, path, headers, body, model_id).await;
let result = proxy::forward_request(&fleet.http_client, route, path, headers, body).await;
let duration = start.elapsed();
match result {
@@ -598,9 +265,6 @@ async fn proxy_with_metrics(
}
Err(e) => {
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
// proxy::forward_request already warn'd with wire-level
// detail (target URL, error, status). ProxyError::into_response
// now returns a generic message — no body leak.
e.into_response()
}
}
@@ -621,38 +285,6 @@ fn extract_model(body: &[u8]) -> Option<String> {
v.get("model")?.as_str().map(|s| s.to_string())
}
/// Replace the `model` field of a JSON request body with `new_model`.
///
/// The input bytes are returned untouched when the body is not valid JSON,
/// when it has no string `model` field, when that field already equals
/// `new_model`, or when re-serialising fails. The caller is expected to
/// have pulled `model` out via `extract_model` already, so an unparsable
/// body is proxied as-is rather than turned into an error.
///
/// This exists because neuron rejects requests whose `model` does not match
/// a loaded model: a client sending an alias such as `helexa/small` would be
/// refused at the harness unless the alias is swapped for the concrete id it
/// resolved to.
fn rewrite_model_in_body(body: Bytes, new_model: &str) -> Bytes {
    let mut parsed: Value = match serde_json::from_slice(&body) {
        Ok(v) => v,
        Err(_) => return body,
    };

    // Only a present, string-valued, different `model` warrants a rewrite.
    match parsed.get("model").and_then(|m| m.as_str()) {
        Some(current) if current != new_model => {}
        _ => return body,
    }

    if let Some(fields) = parsed.as_object_mut() {
        fields.insert("model".into(), Value::String(new_model.to_string()));
    }

    serde_json::to_vec(&parsed)
        .map(Bytes::from)
        .unwrap_or(body)
}
fn error_response(status: u16, message: &str) -> Response {
let code = axum::http::StatusCode::from_u16(status)
.unwrap_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR);

View File

@@ -1,4 +1,3 @@
pub mod anthropic_sse;
pub mod evictor;
pub mod handlers;
pub mod metrics;

View File

@@ -46,14 +46,6 @@ fn describe_metrics() {
"Generation throughput in tokens per second"
);
metrics::describe_counter!("cortex_requests_total", "Total number of proxied requests");
metrics::describe_counter!(
"cortex_prompt_tokens_total",
"Total prompt tokens reported by upstream usage objects"
);
metrics::describe_counter!(
"cortex_completion_tokens_total",
"Total completion tokens reported by upstream usage objects"
);
metrics::describe_counter!(
"cortex_request_errors_total",
"Total number of failed proxy requests"

View File

@@ -3,7 +3,6 @@
use crate::state::CortexState;
use chrono::Utc;
use cortex_core::discovery::{DiscoveryResponse, HealthResponse};
use cortex_core::harness::ModelInfo;
use cortex_core::node::{ModelEntry, ModelStatus};
use std::sync::Arc;
@@ -26,59 +25,7 @@ pub async fn poll_once(fleet: &CortexState) {
}
}
/// Fetch `GET {endpoint}/discovery` once and cache the result on the node.
///
/// Returns immediately when the node already has a cached discovery
/// response, so after the first success this does no network I/O. Probe
/// failures (unreachable, non-success status) are logged at debug level and
/// leave the cache empty for a later attempt; an unparsable body is logged
/// at warn level.
async fn maybe_poll_discovery(fleet: &CortexState, name: &str, endpoint: &str) {
    // Check the cache under a short-lived read lock, released before any I/O.
    let already_cached = {
        let nodes = fleet.nodes.read().await;
        matches!(nodes.get(name), Some(n) if n.discovery.is_some())
    };
    if already_cached {
        return;
    }

    let url = format!("{endpoint}/discovery");
    let outcome = fleet
        .http_client
        .get(&url)
        .timeout(Duration::from_secs(5))
        .send()
        .await;
    let resp = match outcome {
        Err(e) => {
            tracing::debug!(node = name, error = %e, "discovery probe unreachable");
            return;
        }
        Ok(r) if !r.status().is_success() => {
            tracing::debug!(node = name, status = %r.status(), "discovery probe non-success");
            return;
        }
        Ok(r) => r,
    };

    let d = match resp.json::<DiscoveryResponse>().await {
        Ok(d) => d,
        Err(e) => {
            tracing::warn!(node = name, error = %e, "failed to parse /discovery response");
            return;
        }
    };

    // The node may have disappeared while the request was in flight.
    let mut nodes = fleet.nodes.write().await;
    if let Some(node) = nodes.get_mut(name) {
        tracing::info!(
            node = name,
            hostname = %d.hostname,
            devices = d.devices.len(),
            "discovery cached"
        );
        node.discovery = Some(d);
    }
}
async fn poll_neuron(fleet: &CortexState, name: &str, endpoint: &str) {
// Topology first — cheap once cached, and the router needs it to
// route requests against catalogue entries that aren't loaded yet.
maybe_poll_discovery(fleet, name, endpoint).await;
let url = format!("{endpoint}/models");
let result = fleet
@@ -107,14 +54,12 @@ async fn poll_neuron(fleet: &CortexState, name: &str, endpoint: &str) {
.and_modify(|e| {
e.status = status;
e.vram_estimate_mb = upstream.vram_used_mb;
e.capabilities = upstream.capabilities.clone();
})
.or_insert_with(|| ModelEntry {
id: upstream.id.clone(),
status,
last_accessed: None,
vram_estimate_mb: upstream.vram_used_mb,
capabilities: upstream.capabilities.clone(),
});
}
@@ -144,51 +89,6 @@ async fn poll_neuron(fleet: &CortexState, name: &str, endpoint: &str) {
node.healthy = false;
}
}
// Release the write lock before the next HTTP call.
drop(nodes);
// Poll /health for the activation snapshot. We don't want this to
// flip the node to unhealthy on its own — a neuron that's serving
// /models fine is still operational even if /health is briefly
// unavailable — so failures are debug-level and leave the existing
// activation reading in place.
poll_health(fleet, name, endpoint).await;
}
/// Fetch `GET {endpoint}/health` and store its activation snapshot on the
/// node's state.
///
/// Every failure path (request error, non-success status, unparsable body)
/// is logged at debug level and returns without touching the node, so the
/// previous activation reading stays in place and the node's health flag and
/// model list are unaffected.
async fn poll_health(fleet: &CortexState, name: &str, endpoint: &str) {
    let url = format!("{endpoint}/health");
    let outcome = fleet
        .http_client
        .get(&url)
        .timeout(Duration::from_secs(5))
        .send()
        .await;
    let resp = match outcome {
        Err(e) => {
            tracing::debug!(node = name, error = %e, "/health probe failed");
            return;
        }
        Ok(r) if !r.status().is_success() => {
            tracing::debug!(node = name, status = %r.status(), "/health probe non-success");
            return;
        }
        Ok(r) => r,
    };

    let h = match resp.json::<HealthResponse>().await {
        Ok(h) => h,
        Err(e) => {
            tracing::debug!(node = name, error = %e, "failed to parse /health response");
            return;
        }
    };

    let mut nodes = fleet.nodes.write().await;
    if let Some(node) = nodes.get_mut(name) {
        node.activation = Some(h.activation);
    }
}
fn parse_status(s: &str) -> ModelStatus {
@@ -196,8 +96,6 @@ fn parse_status(s: &str) -> ModelStatus {
"loaded" => ModelStatus::Loaded,
"unloaded" => ModelStatus::Unloaded,
"reloading" => ModelStatus::Reloading,
"loading" => ModelStatus::Loading,
"recovering" => ModelStatus::Recovering,
_ => ModelStatus::Loaded,
}
}

View File

@@ -1,4 +1,4 @@
//! Streaming HTTP reverse proxy to neuron backends.
//! Streaming HTTP reverse proxy to mistral.rs backends.
//!
//! For streaming requests, SSE chunks are forwarded as they arrive.
//! The proxy captures timing information for metrics but does not
@@ -9,30 +9,16 @@ use anyhow::Result;
use axum::body::Body;
use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response};
use futures::Stream;
use futures::stream::BoxStream;
use reqwest::Client;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Instant;
/// Proxy a request body to the resolved backend node and stream the response.
///
/// Logging contract: every call emits exactly one structured event at
/// info / warn level for operator visibility, regardless of outcome.
/// Network-level failures and non-2xx upstream statuses are warn'd here
/// (closest to the wire); the user-facing response carries only the
/// status code and a generic message — implementation detail (body,
/// error chain) lives in the log, never in the API surface.
pub async fn forward_request(
client: &Client,
route: &RouteDecision,
path: &str,
headers: HeaderMap,
body: bytes::Bytes,
model_id: &str,
) -> Result<Response, ProxyError> {
let request_start = Instant::now();
let url = format!("{}{}", route.endpoint, path);
tracing::info!(
node = %route.node_name,
@@ -51,39 +37,13 @@ pub async fn forward_request(
req_builder = req_builder.header(key, value);
}
let upstream_resp = match req_builder.send().await {
Ok(r) => r,
Err(e) => {
tracing::warn!(
node = %route.node_name,
url = %url,
error = %e,
"proxy: upstream request failed (network)"
);
return Err(ProxyError::Upstream(e));
}
};
let upstream_resp = req_builder.send().await.map_err(ProxyError::Upstream)?;
let upstream_status = upstream_resp.status();
if !upstream_status.is_success() {
// Streaming body — can't snippet without breaking the stream
// pass-through. Log status + URL; the client still gets the
// upstream status, just without the leaked body.
tracing::warn!(
node = %route.node_name,
url = %url,
status = upstream_status.as_u16(),
"proxy: upstream returned non-2xx"
);
}
let status = StatusCode::from_u16(upstream_status.as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
let status =
StatusCode::from_u16(upstream_resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
let resp_headers = upstream_resp.headers().clone();
let stream = TokenMetricsStream::new(
Box::pin(upstream_resp.bytes_stream()),
TokenMetrics::new(model_id, &route.node_name, request_start),
);
let stream = upstream_resp.bytes_stream();
let body = Body::from_stream(stream);
@@ -92,261 +52,31 @@ pub async fn forward_request(
response = response.header(key, value);
}
response.body(body).map_err(|e| {
tracing::warn!(
node = %route.node_name,
url = %url,
error = %e,
"proxy: failed to build response"
);
ProxyError::ResponseBuild(e.to_string())
})
response
.body(body)
.map_err(|e| ProxyError::ResponseBuild(e.to_string()))
}
#[derive(Debug, thiserror::Error)]
pub enum ProxyError {
#[error("upstream request failed")]
#[error("upstream request failed: {0}")]
Upstream(reqwest::Error),
#[error("failed to build response")]
#[error("failed to build response: {0}")]
ResponseBuild(String),
}
impl IntoResponse for ProxyError {
fn into_response(self) -> Response {
let (status, message) = match &self {
ProxyError::Upstream(_) => (StatusCode::BAD_GATEWAY, "upstream request failed"),
ProxyError::ResponseBuild(_) => (
StatusCode::INTERNAL_SERVER_ERROR,
"failed to build response",
),
let status = match &self {
ProxyError::Upstream(_) => StatusCode::BAD_GATEWAY,
ProxyError::ResponseBuild(_) => StatusCode::INTERNAL_SERVER_ERROR,
};
let body = serde_json::json!({
"error": {
"message": message,
"message": self.to_string(),
"type": "proxy_error",
}
});
(status, axum::Json(body)).into_response()
}
}
// ── Per-request token metrics (#21) ─────────────────────────────────
//
// The proxy never buffers or re-serialises the upstream body — chunks
// are forwarded verbatim. For metrics it observes each chunk's arrival
// time and keeps a bounded tail of the body text, from which the final
// OpenAI `usage` object (present on the last SSE chunk and on
// non-streaming JSON bodies alike) yields engine-truth token counts.
//
// Emitted per request, labelled {model, node}:
// cortex_time_to_first_token_seconds (histogram) — first body chunk
// cortex_tokens_per_second (histogram) — completion tokens
// over the decode window (first→last chunk); falls back to the
// full request duration for single-chunk (non-streaming) bodies
// cortex_prompt_tokens_total / cortex_completion_tokens_total (counters)
/// Cap, in bytes, on the retained body tail (64 KiB). The usage object
/// rides on the final chunk, so a generous tail is plenty; the cap
/// bounds memory on huge non-streaming bodies.
const TAIL_CAP_BYTES: usize = 64 * 1024;
/// Return the integer that follows the last parseable `"key": <integer>`
/// pair in `tail`, or `None` when no occurrence carries one.
///
/// The search needle is the key wrapped in double quotes, so looking up
/// `completion_tokens` never matches `completion_tokens_details`.
/// Whitespace is tolerated on both sides of the colon. Occurrences whose
/// value is not a run of ASCII digits that fits in a `u64` (for example
/// `null`, or an over-long number) are ignored, so an earlier numeric
/// occurrence is still reported if a later one is non-numeric.
pub(crate) fn last_count_for(tail: &str, key: &str) -> Option<u64> {
    let needle = format!("\"{key}\"");
    tail.match_indices(needle.as_str())
        .filter_map(|(start, matched)| {
            let after_key = tail[start + matched.len()..].trim_start();
            let value = after_key.strip_prefix(':')?.trim_start();
            // ASCII digits are single bytes, so the count of leading
            // digit bytes is always a valid char boundary.
            let digit_len = value.bytes().take_while(u8::is_ascii_digit).count();
            value[..digit_len].parse::<u64>().ok()
        })
        .last()
}
/// Per-request accumulator for the token metrics: chunk arrival
/// timestamps plus a bounded tail of the body text.
struct TokenMetrics {
// Metric labels: `model` and `node`, in that order (see `new`).
labels: [(&'static str, String); 2],
// Instant supplied by the caller of `new`; used as the zero point
// for time-to-first-token.
request_start: Instant,
// Arrival time of the first observed chunk; `None` until `observe`
// has been called at least once.
first_chunk: Option<Instant>,
// Arrival time of the most recently observed chunk.
last_chunk: Option<Instant>,
// Lossily-decoded tail of the body, trimmed once it exceeds
// `TAIL_CAP_BYTES`.
tail: String,
// Set by `finish` so the metrics are emitted at most once.
finished: bool,
}
impl TokenMetrics {
    /// Create an empty accumulator labelled with `model_id` and
    /// `node_name`, timing everything relative to `request_start`.
    fn new(model_id: &str, node_name: &str, request_start: Instant) -> Self {
        let labels = [
            ("model", model_id.to_string()),
            ("node", node_name.to_string()),
        ];
        Self {
            labels,
            request_start,
            first_chunk: None,
            last_chunk: None,
            tail: String::new(),
            finished: false,
        }
    }

    /// Record the arrival of one body chunk: stamp first/last arrival
    /// times and append the (lossily decoded) text to the bounded tail.
    fn observe(&mut self, chunk: &[u8]) {
        let arrived = Instant::now();
        if self.first_chunk.is_none() {
            self.first_chunk = Some(arrived);
        }
        self.last_chunk = Some(arrived);

        self.tail.push_str(&String::from_utf8_lossy(chunk));
        if self.tail.len() <= TAIL_CAP_BYTES {
            return;
        }

        // Over the cap: drop everything but the newest half of the cap.
        // The cut is nudged forward to the next char boundary so the
        // `String` stays valid UTF-8.
        let mut cut = self.tail.len() - TAIL_CAP_BYTES / 2;
        while !self.tail.is_char_boundary(cut) {
            cut += 1;
        }
        self.tail.drain(..cut);
    }

    /// Emit the metrics for this request. Idempotent: only the first
    /// call records anything, so it is safe to call both at clean
    /// end-of-stream and again from a destructor.
    fn finish(&mut self) {
        if self.finished {
            return;
        }
        self.finished = true;

        // Without a single chunk there is no timing or usage to report.
        let Some(first) = self.first_chunk else {
            return;
        };

        metrics::histogram!("cortex_time_to_first_token_seconds", &self.labels)
            .record(first.duration_since(self.request_start).as_secs_f64());

        if let Some(prompt) = last_count_for(&self.tail, "prompt_tokens") {
            metrics::counter!("cortex_prompt_tokens_total", &self.labels).increment(prompt);
        }

        // A missing or zero completion count yields neither the counter
        // nor a throughput sample.
        let completion = match last_count_for(&self.tail, "completion_tokens") {
            Some(count) if count != 0 => count,
            _ => return,
        };
        metrics::counter!("cortex_completion_tokens_total", &self.labels).increment(completion);

        let last = self.last_chunk.unwrap_or(first);
        let decode_window = last.duration_since(first).as_secs_f64();
        // A window under 100 ms means the body arrived as (roughly) one
        // chunk; in that case the rate is taken over the whole request
        // instead of the first→last chunk window.
        let secs = if decode_window < 0.1 {
            last.duration_since(self.request_start).as_secs_f64()
        } else {
            decode_window
        };
        if secs > 0.0 {
            metrics::histogram!("cortex_tokens_per_second", &self.labels)
                .record(completion as f64 / secs);
        }
    }
}
/// Pass-through stream wrapper that feeds [`TokenMetrics`]. Emits on
/// clean end-of-stream; the Drop impl covers client disconnects.
struct TokenMetricsStream {
// The wrapped byte stream; items are yielded to the consumer unchanged.
inner: BoxStream<'static, Result<bytes::Bytes, reqwest::Error>>,
// Accumulator fed with every successfully received chunk.
metrics: TokenMetrics,
}
impl TokenMetricsStream {
/// Wrap `inner` so that each chunk it yields is also observed by
/// `metrics`.
fn new(
inner: BoxStream<'static, Result<bytes::Bytes, reqwest::Error>>,
metrics: TokenMetrics,
) -> Self {
Self { inner, metrics }
}
}
impl Stream for TokenMetricsStream {
    type Item = Result<bytes::Bytes, reqwest::Error>;

    /// Poll the wrapped stream and hand its result back unchanged,
    /// peeking at it on the way through: a data chunk is fed to the
    /// metrics accumulator, end-of-stream finalises the metrics, and
    /// errors or `Pending` are passed along without side effects.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        let polled = this.inner.as_mut().poll_next(cx);
        match &polled {
            Poll::Ready(Some(Ok(chunk))) => this.metrics.observe(chunk),
            Poll::Ready(None) => this.metrics.finish(),
            Poll::Ready(Some(Err(_))) | Poll::Pending => {}
        }
        polled
    }
}
impl Drop for TokenMetricsStream {
/// Finalise the metrics when the wrapper is dropped. `finish` is a
/// no-op if it already ran at end-of-stream, so nothing is recorded
/// twice.
fn drop(&mut self) {
self.metrics.finish();
}
}
// Unit tests for `last_count_for`, the pure text scanner that pulls
// token counts out of the retained body tail.
#[cfg(test)]
mod tests {
use super::last_count_for;
/// Streaming shape: the counts sit in the final `usage` SSE chunk,
/// just before the `[DONE]` sentinel.
#[test]
fn extracts_counts_from_final_sse_usage_chunk() {
let tail = concat!(
"data: {\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\n",
"data: {\"choices\":[],\"usage\":{\"prompt_tokens\":225,",
"\"completion_tokens\":42,\"total_tokens\":267}}\n\n",
"data: [DONE]\n\n"
);
assert_eq!(last_count_for(tail, "prompt_tokens"), Some(225));
assert_eq!(last_count_for(tail, "completion_tokens"), Some(42));
}
/// Non-streaming shape: a single JSON body, here with a space after
/// each colon to exercise the whitespace handling.
#[test]
fn extracts_counts_from_non_streaming_body() {
let tail = "{\"choices\":[{\"message\":{\"content\":\"hi\"}}],\
\"usage\":{\"prompt_tokens\": 12, \"completion_tokens\": 7}}";
assert_eq!(last_count_for(tail, "prompt_tokens"), Some(12));
assert_eq!(last_count_for(tail, "completion_tokens"), Some(7));
}
#[test]
fn ignores_details_variants_and_takes_last_occurrence() {
// completion_tokens_details must not shadow completion_tokens,
// and the LAST usage object wins (matters when content echoes
// a usage-shaped string earlier in the stream).
let tail = concat!(
"data: {\"usage\":{\"completion_tokens\":1}}\n\n",
"data: {\"usage\":{\"completion_tokens\":99,",
"\"completion_tokens_details\":{\"reasoning_tokens\":3}}}\n\n"
);
assert_eq!(last_count_for(tail, "completion_tokens"), Some(99));
}
/// Absent keys, empty input, and non-numeric values all yield `None`.
#[test]
fn absent_keys_yield_none() {
assert_eq!(
last_count_for("data: [DONE]\n\n", "completion_tokens"),
None
);
assert_eq!(last_count_for("", "prompt_tokens"), None);
// key present but non-numeric value
assert_eq!(
last_count_for("\"completion_tokens\": null", "completion_tokens"),
None
);
}
}

View File

@@ -2,21 +2,13 @@
//!
//! Given a model ID from an inbound request, determine which node should
//! handle it. Priority:
//! 1. Node where the model is currently `Loaded` → use it.
//! 2. Node where the model is `Unloaded` → use it; neuron's existing
//! lazy-load behaviour will reload before serving the request.
//! 3. Model is in the catalogue → pick a feasible neuron, call
//! `POST /models/load`, wait for the load to complete, then
//! proxy. First-request cold-load latency is acceptable per the
//! unified-endpoint contract.
//! 4. Not in catalogue, not loaded anywhere → 404.
//! 1. Node where the model is currently `Loaded`
//! 2. Node where the model is `Unloaded` (will lazy-load on request)
//! 3. Error: model not found on any node
use crate::state::CortexState;
use cortex_core::catalogue::ModelProfile;
use cortex_core::harness::ModelSpec;
use cortex_core::node::ModelStatus;
use std::sync::Arc;
use std::time::Duration;
/// The routing decision: which node endpoint to proxy the request to.
#[derive(Debug, Clone)]
@@ -24,345 +16,62 @@ pub struct RouteDecision {
pub node_name: String,
/// The inference endpoint to proxy to (from neuron's /models/{id}/endpoint).
pub endpoint: String,
/// Whether the model will need to load (cold start). Set to true
/// when we proxied to an `Unloaded` node (lazy load on neuron) or
/// when we just triggered an explicit cold-load via the catalogue
/// path.
/// Whether the model will need to load (cold start).
pub cold_start: bool,
/// The concrete model id we actually routed to. Equal to the
/// caller's requested id unless an alias was resolved (e.g. caller
/// asked for `helexa/small`, this carries `Qwen/Qwen3-1.7B`). The
/// handler uses this to rewrite the request body's `model` field
/// before proxying — neurons reject requests where the body's
/// model name doesn't match a loaded model.
pub resolved_model_id: String,
}
#[derive(Debug, thiserror::Error)]
pub enum RouteError {
#[error("model '{0}' not found on any node and not in catalogue")]
#[error("model '{0}' not found on any node")]
ModelNotFound(String),
#[error("no healthy nodes available")]
NoHealthyNodes,
#[error("failed to resolve inference endpoint for model '{0}' on node '{1}'")]
EndpointResolveFailed(String, String),
#[error(
"model '{model_id}' is in the catalogue but no healthy neuron's topology satisfies its constraints"
)]
NoFeasibleNeuron { model_id: String },
#[error("cold-load of '{model_id}' on '{node}' failed: {message}")]
ColdLoadFailed {
model_id: String,
node: String,
message: String,
},
#[error(
"model '{model_id}' is recovering on node '{node}' (device context rebuild in progress) — retry shortly"
)]
ModelRecovering { model_id: String, node: String },
}
impl RouteError {
    /// HTTP status code the gateway should reply with for this error:
    /// 503 for `ModelRecovering` (the transient "retry the same
    /// request" case) and 404 for every other variant.
    pub fn http_status(&self) -> u16 {
        if matches!(self, RouteError::ModelRecovering { .. }) {
            503
        } else {
            404
        }
    }
}
/// Resolve which node should serve a request for the given model.
/// Asks the neuron for the inference endpoint after selecting a node.
pub async fn resolve(
fleet: &Arc<CortexState>,
requested_model_id: &str,
model_id: &str,
) -> Result<RouteDecision, RouteError> {
// Alias resolution first — swap `helexa/small` (etc.) for the
// concrete id before any node lookups so the rest of routing,
// loading, and metrics deal in concrete ids only. `resolve_alias`
// returns the input verbatim when it isn't an alias.
let model_id = fleet.catalogue.resolve_alias(requested_model_id);
if model_id != requested_model_id {
tracing::debug!(
requested = requested_model_id,
resolved = model_id,
"alias resolved"
);
}
// Snapshot loaded / unloaded / recovering state from the poller cache.
let (loaded_route, unloaded_route, recovering_node, any_healthy) = {
let (node_name, neuron_endpoint, cold_start) = {
let nodes = fleet.nodes.read().await;
let mut loaded_route = None;
let mut unloaded_route = None;
let mut recovering_node = None;
let mut any_healthy = false;
let mut loaded_candidate = None;
let mut unloaded_candidate = None;
for node in nodes.values() {
if !node.healthy {
continue;
}
any_healthy = true;
if let Some(entry) = node.models.get(model_id) {
match entry.status {
ModelStatus::Loaded | ModelStatus::Reloading => {
loaded_route = Some((node.name.clone(), node.endpoint.clone(), false));
loaded_candidate = Some((node.name.clone(), node.endpoint.clone(), false));
break;
}
ModelStatus::Unloaded => {
if unloaded_route.is_none() {
unloaded_route = Some((node.name.clone(), node.endpoint.clone(), true));
if unloaded_candidate.is_none() {
unloaded_candidate =
Some((node.name.clone(), node.endpoint.clone(), true));
}
}
// Auto-recovering (#17/#20): the model is rebuilding
// its device context on this node. Hold the route —
// answer "retry shortly" rather than 404, and do NOT
// fall through to the catalogue cold-load, which
// would race a second placement (and a second copy's
// worth of VRAM) against the in-flight recovery.
ModelStatus::Recovering => {
if recovering_node.is_none() {
recovering_node = Some(node.name.clone());
}
}
// Loading is gateway-synthesised from neuron's
// activation snapshot; it never appears on the
// wire from neuron's `/models`. Skip — the model
// isn't actually servable yet. The pre-existing
// race (catalogue cold_load fires a parallel
// /models/load against the in-flight load) is no
// worse than before; fixing it needs neuron-side
// in-flight tracking on /models/load itself.
ModelStatus::Loading => {}
}
}
}
(loaded_route, unloaded_route, recovering_node, any_healthy)
};
if !any_healthy {
return Err(RouteError::NoHealthyNodes);
}
// Priority 1: already loaded.
if let Some((node_name, neuron_endpoint, cold_start)) = loaded_route {
return finish(fleet, &node_name, &neuron_endpoint, model_id, cold_start).await;
}
// Priority 2: recovering somewhere — transient hold, not a reroute.
if let Some(node) = recovering_node {
return Err(RouteError::ModelRecovering {
model_id: model_id.to_string(),
node,
});
}
// Priority 3: known to neuron but unloaded (neuron's lazy load).
if let Some((node_name, neuron_endpoint, cold_start)) = unloaded_route {
return finish(fleet, &node_name, &neuron_endpoint, model_id, cold_start).await;
}
// Priority 4: catalogue × topology cold-load.
if let Some(profile) = fleet.catalogue.get(model_id) {
let (node_name, neuron_endpoint) = pick_feasible_neuron(fleet, profile).await?;
cold_load(fleet, &node_name, &neuron_endpoint, profile).await?;
return finish(fleet, &node_name, &neuron_endpoint, model_id, true).await;
}
Err(RouteError::ModelNotFound(model_id.to_string()))
}
/// Choose the neuron that should receive a cold-load of `profile`.
///
/// Only nodes that are healthy, have a discovery record, and whose
/// discovered devices pass `profile.is_feasible_on` are considered.
/// Among those, a node named in `profile.pinned_on` beats an unpinned
/// one, and remaining ties are broken by ascending node name.
///
/// Returns `(node name, node endpoint)`, or
/// [`RouteError::NoFeasibleNeuron`] when no node qualifies.
async fn pick_feasible_neuron(
    fleet: &Arc<CortexState>,
    profile: &ModelProfile,
) -> Result<(String, String), RouteError> {
    let nodes = fleet.nodes.read().await;

    let best = nodes
        .values()
        .filter(|node| node.healthy)
        .filter(|node| match node.discovery.as_ref() {
            Some(disc) => profile.is_feasible_on(&node.name, &disc.devices),
            None => false,
        })
        .map(|node| {
            let is_pinned = profile.pinned_on.iter().any(|n| n == &node.name);
            (is_pinned, node)
        })
        // Smallest element under "pinned first, then name ascending";
        // `min_by` keeps the first of equal elements, matching what a
        // stable sort followed by taking the head would select.
        .min_by(|(l_pinned, l), (r_pinned, r)| {
            r_pinned.cmp(l_pinned).then_with(|| l.name.cmp(&r.name))
        });

    match best {
        Some((_, node)) => Ok((node.name.clone(), node.endpoint.clone())),
        None => Err(RouteError::NoFeasibleNeuron {
            model_id: profile.id.clone(),
        }),
    }
}
/// Ask one neuron to load `profile` by POSTing the derived spec to
/// `{neuron_endpoint}/models/load`, and wait for its answer.
///
/// Outcomes:
/// - a success status, or a non-success response whose body contains
///   "already loaded" (two requests racing the same model end in the
///   same state), counts as loaded. A `Loaded` entry is then written
///   into the cached state for `node_name`, if that node is known, so
///   a follow-up lookup sees the model before the next poll cycle.
/// - a transport failure, or any other non-success response, becomes
///   [`RouteError::ColdLoadFailed`] carrying the status and body.
async fn cold_load(
    fleet: &Arc<CortexState>,
    node_name: &str,
    neuron_endpoint: &str,
    profile: &ModelProfile,
) -> Result<(), RouteError> {
    let spec = profile_to_spec(fleet, node_name, profile).await;
    let url = format!("{neuron_endpoint}/models/load");
    tracing::info!(model = %profile.id, node = node_name, "cold-loading via /models/load");

    // Per-request ceiling of 30 minutes: downloading and materialising
    // a large model can run far longer than an ordinary API call.
    let resp = fleet
        .http_client
        .post(&url)
        .timeout(Duration::from_secs(1800))
        .json(&spec)
        .send()
        .await
        .map_err(|e| RouteError::ColdLoadFailed {
            model_id: profile.id.clone(),
            node: node_name.to_string(),
            message: format!("HTTP request failed: {e}"),
        })?;

    let status = resp.status();
    if status.is_success() {
        tracing::info!(model = %profile.id, node = node_name, "cold-load returned 200");
    } else {
        let body = resp.text().await.unwrap_or_default();
        if !body.contains("already loaded") {
            return Err(RouteError::ColdLoadFailed {
                model_id: profile.id.clone(),
                node: node_name.to_string(),
                message: format!("HTTP {status}: {body}"),
            });
        }
        tracing::info!(
            model = %profile.id,
            node = node_name,
            "cold-load saw 'already loaded' — treating as success"
        );
    }

    // Warm the local cache so the model is visible immediately.
    let mut nodes = fleet.nodes.write().await;
    if let Some(node) = nodes.get_mut(node_name) {
        let entry = cortex_core::node::ModelEntry {
            id: profile.id.clone(),
            status: ModelStatus::Loaded,
            last_accessed: Some(chrono::Utc::now()),
            vram_estimate_mb: profile.vram_mb,
            capabilities: Vec::new(),
        };
        node.models.insert(profile.id.clone(), entry);
    }
    Ok(())
}
/// Translate a `ModelProfile` to a `ModelSpec` neuron's /models/load
/// accepts. Devices are picked from the neuron's discovered topology —
/// the first `min_devices` indices that meet `min_device_vram_mb`.
async fn profile_to_spec(
fleet: &Arc<CortexState>,
node_name: &str,
profile: &ModelProfile,
) -> ModelSpec {
let devices = {
let nodes = fleet.nodes.read().await;
let mut picked: Vec<u32> = Vec::new();
if let Some(node) = nodes.get(node_name)
&& let Some(disc) = &node.discovery
{
let min_vram = profile.min_device_vram_mb.unwrap_or(0);
for d in &disc.devices {
if d.vram_total_mb >= min_vram {
picked.push(d.index);
if picked.len() as u32 >= profile.min_devices {
break;
}
}
}
}
if picked.is_empty() {
// Fall back to a 0..min_devices default; pick_feasible_neuron
// already verified the topology satisfies the constraints,
// so this only fires if discovery raced or was lost.
(0..profile.min_devices).collect()
} else {
picked
}
loaded_candidate.or(unloaded_candidate).ok_or_else(|| {
if nodes.values().any(|n| n.healthy) {
RouteError::ModelNotFound(model_id.to_string())
} else {
RouteError::NoHealthyNodes
}
})?
};
let tensor_parallel = if profile.min_devices > 1 {
Some(profile.min_devices)
} else {
None
};
ModelSpec {
model_id: qualified_model_id(profile),
harness: profile.harness.clone(),
quant: profile.quant.clone(),
tensor_parallel,
devices: Some(devices),
}
}
/// Build the model id to send in a load request for `profile`.
///
/// When the profile declares a non-empty `source`, the result is
/// `"{source}:{id}"`, which tells the receiver which registry to
/// resolve against. A missing or empty `source` leaves the bare id
/// untouched, so profiles without a source behave as they always have.
///
/// Kept as a free function so it can be unit-tested without building
/// any fleet state.
fn qualified_model_id(profile: &ModelProfile) -> String {
    profile
        .source
        .as_deref()
        .filter(|scheme| !scheme.is_empty())
        .map_or_else(
            || profile.id.clone(),
            |scheme| format!("{scheme}:{}", profile.id),
        )
}
/// Resolve neuron's `/models/{id}/endpoint` to its inference URL and
/// build the final `RouteDecision`. Shared by all three priority
/// branches above.
async fn finish(
fleet: &Arc<CortexState>,
node_name: &str,
neuron_endpoint: &str,
model_id: &str,
cold_start: bool,
) -> Result<RouteDecision, RouteError> {
// Ask the neuron for the inference endpoint for this model.
let endpoint_url = format!(
"{}/models/{}/endpoint",
neuron_endpoint,
@@ -380,119 +89,13 @@ async fn finish(
_ => None,
};
let raw = inference_endpoint.ok_or_else(|| {
RouteError::EndpointResolveFailed(model_id.to_string(), node_name.to_string())
let endpoint = inference_endpoint.ok_or_else(|| {
RouteError::EndpointResolveFailed(model_id.to_string(), node_name.clone())
})?;
// Rewrite loopback inference URLs to use the configured neuron host.
// Neuron's default bind_url is `http://localhost:13131` (it can't
// reliably know its own externally-resolvable name). Cortex sees a
// URL that's only meaningful from the neuron host's own perspective;
// proxying directly to localhost from a different cortex host would
// hit nothing. Keep neuron's port and path (a future harness could
// serve inference on a different port than the management API), but
// swap the host for the one in cortex.toml.
let endpoint = rewrite_loopback_host(&raw, neuron_endpoint).unwrap_or(raw);
Ok(RouteDecision {
node_name: node_name.to_string(),
node_name,
endpoint,
cold_start,
resolved_model_id: model_id.to_string(),
})
}
/// If `inference_url`'s host is a loopback name (`localhost`,
/// `127.0.0.1`, `0.0.0.0`, or the IPv6 loopback `[::1]`), return a copy
/// with the host replaced by `neuron_endpoint`'s host; scheme, port and
/// path of the inference URL are kept. A trailing `/` is stripped from
/// the result because callers treat it as a base string and append
/// paths to it.
///
/// Returns `None` when the inference host is not loopback, or when
/// either URL fails to parse or has no host; the caller then falls back
/// to the inference URL as-is.
fn rewrite_loopback_host(inference_url: &str, neuron_endpoint: &str) -> Option<String> {
    let inf = url::Url::parse(inference_url).ok()?;
    // `Url::host_str` serialises IPv6 literals *with* their brackets, so
    // the IPv6 loopback is spelled "[::1]" here; a bare "::1" can never
    // be returned and would silently leave IPv6 loopback URLs unrewritten.
    let is_loopback = matches!(
        inf.host_str()?,
        "localhost" | "127.0.0.1" | "0.0.0.0" | "[::1]"
    );
    if !is_loopback {
        return None;
    }
    let neuron = url::Url::parse(neuron_endpoint).ok()?;
    let new_host = neuron.host_str()?;
    let mut out = inf;
    out.set_host(Some(new_host)).ok()?;
    // url::Url::to_string normalises an empty path to "/", which then
    // breaks downstream callers that do format!("{endpoint}/v1/...")
    // and produce a double slash, so strip the trailing slash here.
    let s = out.to_string();
    Some(s.trim_end_matches('/').to_string())
}
// Unit tests for the two pure helpers in this module:
// `qualified_model_id` and `rewrite_loopback_host`.
#[cfg(test)]
mod tests {
use super::{ModelProfile, qualified_model_id, rewrite_loopback_host};
/// Minimal candle profile with only `id` and `source` varied; every
/// other field is set to its least-constraining value.
fn bare_profile(id: &str, source: Option<&str>) -> ModelProfile {
ModelProfile {
id: id.into(),
harness: "candle".into(),
quant: None,
vram_mb: None,
min_devices: 1,
min_device_vram_mb: None,
pinned_on: vec![],
source: source.map(String::from),
}
}
/// No `source` declared: the id is passed through unchanged.
#[test]
fn qualified_id_passes_through_when_source_absent() {
let p = bare_profile("Qwen/Qwen3-30B", None);
assert_eq!(qualified_model_id(&p), "Qwen/Qwen3-30B");
}
/// A declared `source` is prepended as `scheme:`.
#[test]
fn qualified_id_prefixes_when_source_set() {
let p = bare_profile("Helexa/Qwen3.6-27B-Uncensored", Some("helexa"));
assert_eq!(
qualified_model_id(&p),
"helexa:Helexa/Qwen3.6-27B-Uncensored"
);
}
#[test]
fn qualified_id_passes_through_when_source_is_empty_string() {
// An empty scheme is treated as absent — neuron's default_source
// substitution kicks in.
let p = bare_profile("Qwen/Qwen3-30B", Some(""));
assert_eq!(qualified_model_id(&p), "Qwen/Qwen3-30B");
}
/// `localhost` is swapped for the neuron host; the port is kept and
/// no trailing slash is left on the result.
#[test]
fn rewrites_localhost_keeps_port_and_path() {
let out = rewrite_loopback_host(
"http://localhost:13131",
"http://beast.hanzalova.internal:13131",
);
assert_eq!(
out.as_deref(),
Some("http://beast.hanzalova.internal:13131")
);
}
/// The inference URL's own port wins over the neuron endpoint's port.
#[test]
fn rewrites_loopback_with_distinct_inference_port() {
let out = rewrite_loopback_host("http://127.0.0.1:8080", "http://beast.lan:13131");
assert_eq!(out.as_deref(), Some("http://beast.lan:8080"));
}
/// A non-loopback inference host is reported as "no rewrite" (`None`).
#[test]
fn leaves_non_loopback_alone() {
let out = rewrite_loopback_host("http://other.host:1234", "http://beast.lan:13131");
assert_eq!(out, None);
}
/// Unparseable input yields `None` rather than a panic.
#[test]
fn malformed_inference_url_returns_none() {
let out = rewrite_loopback_host("not a url", "http://beast.lan:13131");
assert_eq!(out, None);
}
}

View File

@@ -26,8 +26,6 @@ impl CortexState {
models: HashMap::new(),
lifecycle_cycles: 0,
last_poll: None,
discovery: None,
activation: None,
},
);
}

View File

@@ -1,268 +0,0 @@
//! Alias resolution: a client request with `model: "helexa/small"`
//! routes to the concrete model id (e.g. `Qwen/Qwen3-1.7B`), with the
//! proxied request body rewritten so the upstream neuron sees a model
//! name that matches its loaded handle.
mod common;
use cortex_core::config::{
EvictionSettings, EvictionStrategy, GatewayConfig, GatewaySettings, NeuronEndpoint,
};
use cortex_core::node::{ModelEntry, ModelStatus};
use cortex_gateway::state::CortexState;
use serde_json::json;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::net::TcpListener;
/// Write a `models.toml` containing a single `[aliases]` entry that maps
/// `alias` to `target`, and return the path of the new file.
///
/// The file is created in [`std::env::temp_dir`]. Its name combines the
/// process id, the wall-clock time in nanoseconds, and a per-process
/// sequence number, so tests running in parallel inside one test binary
/// never share a path even if they read the same clock value. The file
/// is not removed afterwards; cleanup is left to whatever policy the
/// system applies to its temp directory. Hand-rolled to avoid pulling
/// in the `tempfile` crate.
///
/// # Panics
///
/// Panics if the file cannot be written.
fn write_models_toml(alias: &str, target: &str) -> PathBuf {
    use std::sync::atomic::{AtomicU64, Ordering};

    // Distinguishes files created by the same process; only uniqueness
    // matters, so relaxed ordering is sufficient.
    static NEXT_FILE_SEQ: AtomicU64 = AtomicU64::new(0);

    let contents = format!(
        r#"
[aliases]
"{alias}" = "{target}"
"#
    );
    let pid = std::process::id();
    // A clock set before the epoch just contributes 0; the sequence
    // number still keeps the name unique.
    let now = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|elapsed| elapsed.as_nanos())
        .unwrap_or_default();
    let seq = NEXT_FILE_SEQ.fetch_add(1, Ordering::Relaxed);

    let path = std::env::temp_dir().join(format!("cortex-test-models-{pid}-{now}-{seq}.toml"));
    std::fs::write(&path, contents).expect("write temp models.toml");
    path
}
/// End-to-end check that a chat-completions request addressed to an
/// alias reaches the backend under the alias's target id.
///
/// Steps: start a mock neuron, write a catalogue mapping
/// `helexa/small` → `test-model`, seed the fleet state by hand, serve
/// the gateway on an ephemeral port, then POST with the alias and
/// inspect the `model` field of the JSON that comes back.
#[tokio::test]
async fn test_alias_resolves_in_chat_completions() {
let mock_url = common::spawn_mock_neuron().await;
let models_path = write_models_toml("helexa/small", "test-model");
// Port 0 in the listen addresses: the test binds its own listener
// further down instead of using these values.
let config = GatewayConfig {
gateway: GatewaySettings {
listen: "127.0.0.1:0".into(),
metrics_listen: "127.0.0.1:0".into(),
},
eviction: EvictionSettings {
strategy: EvictionStrategy::Lru,
defrag_after_cycles: 0,
},
neurons: vec![NeuronEndpoint {
name: "mock-node".into(),
endpoint: mock_url,
}],
models_config: models_path.to_string_lossy().to_string(),
};
let fleet = Arc::new(CortexState::from_config(&config));
// Seed the node as healthy with the concrete model loaded under
// the target id. The poller doesn't run in this test; we just
// populate state manually.
{
let mut nodes = fleet.nodes.write().await;
let node = nodes.get_mut("mock-node").expect("node must exist");
node.healthy = true;
node.models.insert(
"test-model".into(),
ModelEntry {
id: "test-model".into(),
status: ModelStatus::Loaded,
last_accessed: None,
vram_estimate_mb: None,
capabilities: Vec::new(),
},
);
}
// Sanity: the catalogue actually picked up the alias.
assert_eq!(
fleet.catalogue.resolve_alias("helexa/small"),
"test-model",
"alias should resolve to target id"
);
// Spawn the gateway against this fleet.
let app = cortex_gateway::build_app(Arc::clone(&fleet));
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let gateway_addr = listener.local_addr().unwrap();
tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});
let gateway_url = format!("http://{gateway_addr}");
// Send a chat completion against the alias. The mock backend
// echoes back the `model` field it received — so a body whose
// model wasn't rewritten would come back as "helexa/small", and a
// properly-rewritten one as "test-model".
let client = reqwest::Client::new();
let resp = client
.post(format!("{gateway_url}/v1/chat/completions"))
.json(&json!({
"model": "helexa/small",
"messages": [{"role": "user", "content": "hi"}],
}))
.send()
.await
.expect("gateway should respond");
assert!(resp.status().is_success(), "gateway returned non-2xx");
let body: serde_json::Value = resp.json().await.expect("response is JSON");
assert_eq!(
body.get("model").and_then(|m| m.as_str()),
Some("test-model"),
"mock backend should have seen the resolved model id, not the alias"
);
}
/// End-to-end check that `GET /v1/models` lists an alias next to its
/// target, and that the alias entry mirrors the target's `loaded` flag
/// and `locations`.
///
/// Setup matches the chat-completions alias test: mock neuron, a
/// one-alias catalogue, hand-seeded fleet state, gateway served on an
/// ephemeral port.
#[tokio::test]
async fn test_aliases_surface_in_v1_models() {
let mock_url = common::spawn_mock_neuron().await;
let models_path = write_models_toml("helexa/small", "test-model");
let config = GatewayConfig {
gateway: GatewaySettings {
listen: "127.0.0.1:0".into(),
metrics_listen: "127.0.0.1:0".into(),
},
eviction: EvictionSettings {
strategy: EvictionStrategy::Lru,
defrag_after_cycles: 0,
},
neurons: vec![NeuronEndpoint {
name: "mock-node".into(),
endpoint: mock_url,
}],
models_config: models_path.to_string_lossy().to_string(),
};
let fleet = Arc::new(CortexState::from_config(&config));
// Seed the target as loaded so the alias's mirrored entry shows
// loaded=true.
{
let mut nodes = fleet.nodes.write().await;
let node = nodes.get_mut("mock-node").expect("node must exist");
node.healthy = true;
node.models.insert(
"test-model".into(),
ModelEntry {
id: "test-model".into(),
status: ModelStatus::Loaded,
last_accessed: None,
vram_estimate_mb: Some(2000),
capabilities: Vec::new(),
},
);
}
// Serve the gateway on an OS-assigned port in a background task.
let app = cortex_gateway::build_app(Arc::clone(&fleet));
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let gateway_addr = listener.local_addr().unwrap();
tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});
let gateway_url = format!("http://{gateway_addr}");
let resp = reqwest::get(format!("{gateway_url}/v1/models"))
.await
.expect("gateway should respond");
let body: serde_json::Value = resp.json().await.unwrap();
let entries = body
.get("data")
.and_then(|d| d.as_array())
.expect("data array");
// Both the alias and the target should be present.
let ids: Vec<&str> = entries
.iter()
.filter_map(|e| e.get("id").and_then(|v| v.as_str()))
.collect();
assert!(ids.contains(&"test-model"), "target should be listed");
assert!(ids.contains(&"helexa/small"), "alias should be listed");
// The alias's `loaded` flag and locations should mirror the target.
let alias_entry = entries
.iter()
.find(|e| e.get("id").and_then(|v| v.as_str()) == Some("helexa/small"))
.expect("alias entry");
assert_eq!(alias_entry.get("loaded"), Some(&json!(true)));
let locations = alias_entry
.get("locations")
.and_then(|l| l.as_array())
.expect("locations array");
// Exactly one location: the single seeded node.
assert_eq!(locations.len(), 1);
assert_eq!(
locations[0].get("node").and_then(|n| n.as_str()),
Some("mock-node")
);
}
/// A request naming a model that is not an alias must pass through alias
/// resolution unchanged, even when the catalogue defines unrelated aliases.
#[tokio::test]
async fn test_alias_falls_through_for_unmapped_model() {
    // The only alias configured here is helexa/large -> definitely-not-loaded.
    // The request below asks for plain "test-model", which is not an alias,
    // so resolution should be a no-op.
    let neuron_url = common::spawn_mock_neuron().await;
    let catalogue_file = write_models_toml("helexa/large", "definitely-not-loaded");

    let gateway = GatewaySettings {
        listen: "127.0.0.1:0".into(),
        metrics_listen: "127.0.0.1:0".into(),
    };
    let eviction = EvictionSettings {
        strategy: EvictionStrategy::Lru,
        defrag_after_cycles: 0,
    };
    let neurons = vec![NeuronEndpoint {
        name: "mock-node".into(),
        endpoint: neuron_url,
    }];
    let config = GatewayConfig {
        gateway,
        eviction,
        neurons,
        models_config: catalogue_file.to_string_lossy().to_string(),
    };
    let fleet = Arc::new(CortexState::from_config(&config));

    // Seed "test-model" as loaded on the healthy mock node.
    let loaded_entry = ModelEntry {
        id: "test-model".into(),
        status: ModelStatus::Loaded,
        last_accessed: None,
        vram_estimate_mb: None,
        capabilities: Vec::new(),
    };
    {
        let mut node_map = fleet.nodes.write().await;
        let mock_node = node_map.get_mut("mock-node").expect("node must exist");
        mock_node.healthy = true;
        mock_node.models.insert("test-model".into(), loaded_entry);
    }

    // Serve the gateway on an ephemeral port.
    let router = cortex_gateway::build_app(Arc::clone(&fleet));
    let socket = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let bound_addr = socket.local_addr().unwrap();
    tokio::spawn(async move {
        axum::serve(socket, router).await.unwrap();
    });

    let payload = json!({
        "model": "test-model",
        "messages": [{"role": "user", "content": "hi"}],
    });
    let reply = reqwest::Client::new()
        .post(format!("http://{bound_addr}/v1/chat/completions"))
        .json(&payload)
        .send()
        .await
        .unwrap();
    assert!(reply.status().is_success());

    let echoed: serde_json::Value = reply.json().await.unwrap();
    let seen_model = echoed.get("model").and_then(|m| m.as_str());
    assert_eq!(seen_model, Some("test-model"));
}

View File

@@ -123,124 +123,3 @@ async fn test_anthropic_invalid_request() {
assert_eq!(resp.status(), 400);
}
/// #24: streaming through `/v1/messages` must yield an Anthropic-shaped SSE
/// stream rather than forwarding raw OpenAI chunks. Checks the response
/// content type, the exact event order, the reassembled text, and the
/// fallback values carried by `message_delta`.
#[tokio::test]
async fn test_anthropic_streaming_sse_translation() {
    let chunk_delay = std::time::Duration::from_millis(20);
    let mock_url = common::spawn_streaming_mock_neuron(4, chunk_delay).await;
    let gw_url = common::spawn_gateway(&mock_url).await;

    let request = json!({
        "model": "test-model",
        "max_tokens": 64,
        "stream": true,
        "messages": [{"role": "user", "content": "Hi"}]
    });
    let resp = reqwest::Client::new()
        .post(format!("{gw_url}/v1/messages"))
        .header("content-type", "application/json")
        .json(&request)
        .send()
        .await
        .expect("request should succeed");

    assert_eq!(resp.status(), 200);
    let is_event_stream = resp
        .headers()
        .get("content-type")
        .and_then(|v| v.to_str().ok())
        .unwrap_or("")
        .starts_with("text/event-stream");
    assert!(is_event_stream, "anthropic stream must be SSE");

    let body = resp.text().await.expect("stream should complete");
    assert!(
        !body.contains("chat.completion.chunk"),
        "raw OpenAI frames must not leak through:\n{body}"
    );

    // One content_block_delta is expected per upstream chunk (four here).
    let mut expected_events = vec!["message_start", "content_block_start"];
    expected_events.extend(["content_block_delta"; 4]);
    expected_events.extend(["content_block_stop", "message_delta", "message_stop"]);

    let event_names: Vec<&str> = body
        .lines()
        .filter_map(|line| line.strip_prefix("event: "))
        .collect();
    assert_eq!(
        event_names, expected_events,
        "unexpected event sequence:\n{body}"
    );

    // Decode every `data:` payload that parses as JSON exactly once; both
    // checks below read from this list.
    let payloads: Vec<serde_json::Value> = body
        .lines()
        .filter_map(|line| line.strip_prefix("data: "))
        .filter_map(|raw| serde_json::from_str::<serde_json::Value>(raw).ok())
        .collect();

    // Reassemble the text deltas: the mock emits token0..token3.
    let text: String = payloads
        .iter()
        .filter(|frame| frame["type"] == "content_block_delta")
        .filter_map(|frame| frame["delta"]["text"].as_str())
        .collect();
    assert_eq!(text, "token0token1token2token3");

    // The mock sends no finish_reason — stop_reason defaults to
    // end_turn, and output_tokens falls back to the delta count.
    let message_delta = payloads
        .iter()
        .find(|frame| frame["type"] == "message_delta")
        .expect("message_delta event present");
    assert_eq!(message_delta["delta"]["stop_reason"], "end_turn");
    assert_eq!(message_delta["usage"]["output_tokens"], 4);
}
/// #24: when the upstream stream ends with a usage frame (the
/// `stream_options.include_usage` shape), its prompt/completion counts must
/// surface in `message_delta` as `input_tokens` / `output_tokens`.
#[tokio::test]
async fn test_anthropic_streaming_usage_propagation() {
    let chunk_delay = std::time::Duration::from_millis(10);
    let mock_url =
        common::spawn_streaming_mock_neuron_with_usage(3, chunk_delay, 225, 42).await;
    let gw_url = common::spawn_gateway(&mock_url).await;

    let request = json!({
        "model": "test-model",
        "max_tokens": 64,
        "stream": true,
        "messages": [{"role": "user", "content": "Hi"}]
    });
    let resp = reqwest::Client::new()
        .post(format!("{gw_url}/v1/messages"))
        .header("content-type", "application/json")
        .json(&request)
        .send()
        .await
        .expect("request should succeed");
    let body = resp.text().await.expect("stream should complete");

    // Pick out the first `data:` payload whose type is message_delta.
    let message_delta = body
        .lines()
        .filter_map(|line| line.strip_prefix("data: "))
        .filter_map(|raw| serde_json::from_str::<serde_json::Value>(raw).ok())
        .find(|frame| frame["type"] == "message_delta")
        .expect("message_delta event present");

    assert_eq!(message_delta["usage"]["output_tokens"], 42);
    assert_eq!(message_delta["usage"]["input_tokens"], 225);
}

View File

@@ -22,7 +22,6 @@ use tokio::net::TcpListener;
/// - GET /models/:id/endpoint (returns the inference URL)
/// - POST /models/unload (accepts unload requests)
/// - GET /v1/chat/completions + POST /v1/chat/completions (inference)
///
/// Returns the neuron base URL.
pub async fn spawn_mock_neuron() -> String {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
@@ -44,7 +43,6 @@ pub async fn spawn_mock_neuron() -> String {
post(|Json(_body): Json<Value>| async { Json(json!({"status": "unloaded"})) }),
)
.route("/v1/chat/completions", post(mock_chat_completions))
.route("/v1/responses", post(mock_responses))
.route("/v1/models", get(mock_v1_models));
tokio::spawn(async move {
@@ -56,7 +54,7 @@ pub async fn spawn_mock_neuron() -> String {
async fn mock_neuron_list_models() -> Json<Value> {
Json(json!([
{"id": "test-model", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": 8000}
{"id": "test-model", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": 8000}
]))
}
@@ -94,39 +92,6 @@ async fn mock_chat_completions(Json(body): Json<Value>) -> Json<Value> {
}))
}
/// Mock handler for `POST /v1/responses`.
///
/// Echoes the request's `model` field (or `"unknown"` when it is absent or
/// not a string) inside a small, fixed Responses-style body, so a gateway
/// test only has to assert that the proxy round-tripped it.
async fn mock_responses(Json(request): Json<Value>) -> Json<Value> {
    let model = match request.get("model").and_then(|v| v.as_str()) {
        Some(id) => id,
        None => "unknown",
    };

    // A single assistant message holding one output_text part.
    let assistant_message = json!({
        "type": "message",
        "id": "msg-test-001",
        "role": "assistant",
        "content": [{
            "type": "output_text",
            "text": "Hello from mock backend",
            "annotations": []
        }],
        "status": "completed"
    });
    let usage = json!({
        "input_tokens": 5,
        "output_tokens": 5,
        "total_tokens": 10
    });

    Json(json!({
        "id": "resp-test-001",
        "object": "response",
        "created_at": 1700000000_u64,
        "status": "completed",
        "model": model,
        "output": [assistant_message],
        "usage": usage
    }))
}
/// Spawns a mock neuron that returns SSE streaming responses for chat completions.
pub async fn spawn_streaming_mock_neuron(chunk_count: usize, chunk_delay: Duration) -> String {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
@@ -196,120 +161,8 @@ pub async fn spawn_streaming_mock_neuron(chunk_count: usize, chunk_delay: Durati
base_url
}
/// Like `spawn_streaming_mock_neuron`, but the stream ends with an
/// OpenAI `stream_options.include_usage`-style final chunk (empty
/// choices + usage object) before `[DONE]` — the shape the gateway's
/// token metrics (#21) extract counts from.
///
/// The mock binds an ephemeral localhost port and serves `/models`,
/// `/models/{model_id}/endpoint` and `POST /v1/chat/completions`.
///
/// - `chunk_count`: number of content chunks (`token0`, `token1`, …)
///   emitted before the usage chunk.
/// - `chunk_delay`: sleep applied before every emitted frame, including
///   the usage chunk and `[DONE]`.
/// - `prompt_tokens` / `completion_tokens`: values reported in the usage
///   chunk; `total_tokens` is their sum.
///
/// Returns the mock's base URL.
pub async fn spawn_streaming_mock_neuron_with_usage(
chunk_count: usize,
chunk_delay: Duration,
prompt_tokens: u64,
completion_tokens: u64,
) -> String {
// Port 0 lets the OS pick a free port; the real address is read back below.
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let base_url = format!("http://{addr}");
// Separate copy for the endpoint-lookup closure, which takes ownership.
let inference_url = base_url.clone();
let app = Router::new()
.route("/models", get(mock_neuron_list_models))
.route(
"/models/{model_id}/endpoint",
// Every model id resolves to this same mock's base URL.
get(move |Path(_model_id): Path<String>| {
let url = inference_url.clone();
async move { Json(json!({"url": url})) }
}),
)
.route(
"/v1/chat/completions",
post(move |Json(body): Json<Value>| async move {
// Echo the requested model id back in every chunk.
let model = body
.get("model")
.and_then(|v| v.as_str())
.unwrap_or("unknown")
.to_string();
// Pre-render all content frames as SSE `data:` lines.
let mut chunks: Vec<String> = (0..chunk_count)
.map(|i| {
let chunk = json!({
"id": "chatcmpl-stream-002",
"object": "chat.completion.chunk",
"created": 1700000000_u64,
"model": model,
"choices": [{
"index": 0,
"delta": { "content": format!("token{i}") },
"finish_reason": null
}]
});
format!("data: {chunk}\n\n")
})
.collect();
// Final usage frame: empty `choices`, token counts in `usage`.
let usage_chunk = json!({
"id": "chatcmpl-stream-002",
"object": "chat.completion.chunk",
"created": 1700000000_u64,
"model": model,
"choices": [],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens
}
});
chunks.push(format!("data: {usage_chunk}\n\n"));
chunks.push("data: [DONE]\n\n".to_string());
let delay = chunk_delay;
// Emit the frames one at a time, sleeping `delay` before each.
let stream = stream::iter(chunks).then(move |chunk| async move {
tokio::time::sleep(delay).await;
Ok::<_, std::convert::Infallible>(chunk)
});
Response::builder()
.header(header::CONTENT_TYPE, "text/event-stream")
.header(header::CACHE_CONTROL, "no-cache")
.body(Body::from_stream(stream))
.unwrap()
}),
);
// Serve in a background task; the caller only needs the URL.
tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});
base_url
}
/// Spawns a mock neuron that serves the given models list, paired with the
/// default `/health` body from `default_health_response`.
pub async fn spawn_mock_neuron_with_models(models_response: Value) -> String {
    let health = default_health_response();
    spawn_mock_neuron_with_models_and_health(models_response, health).await
}
/// `/health` body for mocks that have no interest in activation: zero
/// uptime, no devices, and an activation section in the `ready` state with
/// nothing pending, in progress, completed or failed.
pub fn default_health_response() -> Value {
    let idle_activation = json!({
        "state": "ready",
        "pending": [],
        "in_progress": null,
        "completed": [],
        "failed": []
    });
    json!({
        "uptime_secs": 0,
        "devices": [],
        "activation": idle_activation
    })
}
/// Variant of `spawn_mock_neuron_with_models` that also serves a
/// `/health` body. Used by tests that drive the gateway's activation
/// surface (poller reading /health, /v1/models synthesising Loading
/// locations from in_progress / pending).
pub async fn spawn_mock_neuron_with_models_and_health(
models_response: Value,
health_response: Value,
) -> String {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let base_url = format!("http://{addr}");
@@ -323,13 +176,6 @@ pub async fn spawn_mock_neuron_with_models_and_health(
async move { Json(resp) }
}),
)
.route(
"/health",
get(move || {
let resp = health_response.clone();
async move { Json(resp) }
}),
)
.route(
"/models/{model_id}/endpoint",
get(move |Path(_model_id): Path<String>| {
@@ -390,7 +236,6 @@ pub async fn spawn_gateway_with_state(mock_url: &str) -> (Arc<CortexState>, Stri
status: ModelStatus::Loaded,
last_accessed: None,
vram_estimate_mb: Some(8000),
capabilities: Vec::new(),
},
);
}

View File

@@ -91,7 +91,6 @@ async fn test_evict_lru_model() {
status: ModelStatus::Loaded,
last_accessed: Some(Utc::now() - chrono::Duration::hours(2)),
vram_estimate_mb: Some(8000),
capabilities: Vec::new(),
},
);
node.models.insert(
@@ -101,7 +100,6 @@ async fn test_evict_lru_model() {
status: ModelStatus::Loaded,
last_accessed: Some(Utc::now()),
vram_estimate_mb: Some(8000),
capabilities: Vec::new(),
},
);
}
@@ -165,7 +163,6 @@ async fn test_eviction_increments_lifecycle_cycles() {
status: ModelStatus::Loaded,
last_accessed: None,
vram_estimate_mb: None,
capabilities: Vec::new(),
},
);
}

View File

@@ -1,26 +1,20 @@
mod common;
use serde_json::json;
use std::sync::OnceLock;
/// The metrics recorder is a process-wide global; both tests in this
/// binary run against one shared install. Assertions must therefore be
/// order-independent (presence of names / monotonic counters, not
/// "empty before").
fn recorder() -> &'static metrics_exporter_prometheus::PrometheusHandle {
static HANDLE: OnceLock<metrics_exporter_prometheus::PrometheusHandle> = OnceLock::new();
HANDLE.get_or_init(|| {
cortex_gateway::metrics::install_test_recorder().expect("recorder should install")
})
}
#[tokio::test]
async fn test_metrics_emitted_after_proxy() {
let handle = recorder();
let handle = cortex_gateway::metrics::install_test_recorder().expect("recorder should install");
let mock_url = common::spawn_mock_neuron().await;
let gw_url = common::spawn_gateway(&mock_url).await;
let before = handle.render();
assert!(
!before.contains("cortex_requests_total"),
"no request metrics before any requests"
);
let client = reqwest::Client::new();
let resp = client
.post(format!("{gw_url}/v1/chat/completions"))
@@ -50,72 +44,3 @@ async fn test_metrics_emitted_after_proxy() {
"no errors expected for a successful request"
);
}
/// #21: a streamed chat completion that ends with a usage chunk must
/// produce the TTFT and tok/s histograms plus the prompt/completion token
/// counters for the model.
#[tokio::test]
async fn test_token_metrics_emitted_for_streamed_request() {
    // The recorder is global per process and shared with
    // test_metrics_emitted_after_proxy in this binary. Whichever test runs
    // first, both render from the same recorder, so this test asserts on
    // metric names and lower bounds rather than exact totals.
    let handle = recorder();

    let chunk_delay = std::time::Duration::from_millis(40);
    let mock_url =
        common::spawn_streaming_mock_neuron_with_usage(5, chunk_delay, 225, 42).await;
    let gw_url = common::spawn_gateway(&mock_url).await;

    let request = json!({
        "model": "test-model",
        "messages": [{"role": "user", "content": "Hi"}],
        "stream": true
    });
    let resp = reqwest::Client::new()
        .post(format!("{gw_url}/v1/chat/completions"))
        .header("content-type", "application/json")
        .json(&request)
        .send()
        .await
        .expect("request should succeed");
    assert_eq!(resp.status(), 200);

    // Drain the whole stream before rendering the metrics.
    let body = resp.text().await.expect("stream should complete");
    assert!(body.contains("[DONE]"));

    let rendered = handle.render();
    let histogram_names = [
        "cortex_time_to_first_token_seconds",
        "cortex_tokens_per_second",
    ];
    for needle in histogram_names {
        assert!(
            rendered.contains(needle),
            "{needle} should be present.\nMetrics:\n{rendered}"
        );
    }

    // Reads the trailing number from the first rendered line that starts
    // with `name` and carries the test-model label; panics if none parses.
    let counter_value = |name: &str| -> u64 {
        let sample = rendered
            .lines()
            .find(|line| line.starts_with(name) && line.contains(r#"model="test-model""#));
        let parsed = sample
            .and_then(|line| line.rsplit(' ').next())
            .and_then(|raw| raw.parse::<u64>().ok());
        match parsed {
            Some(value) => value,
            None => panic!("{name} should be present.\nMetrics:\n{rendered}"),
        }
    };

    // Counters are lower bounds because of the shared recorder: this
    // request contributed prompt=225 / completion=42.
    let prompt_total = counter_value("cortex_prompt_tokens_total");
    assert!(
        prompt_total >= 225,
        "prompt token counter should include this request's 225.\nMetrics:\n{rendered}"
    );
    let completion_total = counter_value("cortex_completion_tokens_total");
    assert!(
        completion_total >= 42,
        "completion token counter should include this request's 42.\nMetrics:\n{rendered}"
    );
}

View File

@@ -12,8 +12,8 @@ use std::sync::Arc;
async fn test_poller_discovers_models() {
// Mock neuron reports 2 models via /models endpoint (neuron format).
let mock_url = common::spawn_mock_neuron_with_models(json!([
{"id": "model-a", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": 8000},
{"id": "model-b", "harness": "candle", "status": "unloaded", "devices": [], "vram_used_mb": null}
{"id": "model-a", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": 8000},
{"id": "model-b", "harness": "mistralrs", "status": "unloaded", "devices": [], "vram_used_mb": null}
]))
.await;
@@ -63,8 +63,8 @@ async fn test_poller_discovers_models() {
#[tokio::test]
async fn test_poller_updates_gateway_models_endpoint() {
let mock_url = common::spawn_mock_neuron_with_models(json!([
{"id": "model-x", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": null},
{"id": "model-y", "harness": "candle", "status": "loaded", "devices": [1], "vram_used_mb": null}
{"id": "model-x", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": null},
{"id": "model-y", "harness": "mistralrs", "status": "loaded", "devices": [1], "vram_used_mb": null}
]))
.await;
@@ -118,87 +118,6 @@ async fn test_poller_updates_gateway_models_endpoint() {
}
}
/// C3: when two neurons hold the same model but advertise different
/// capability sets, `/v1/models` must report the de-duplicated union. A
/// model that is text-only on one node and text+vision on another is
/// vision-capable to the fleet.
#[tokio::test]
async fn test_models_endpoint_unions_capabilities_across_nodes() {
    let text_only = json!([
        {"id": "shared-model", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": null, "capabilities": ["text"]}
    ]);
    let text_and_vision = json!([
        {"id": "shared-model", "harness": "candle", "status": "loaded", "devices": [1], "vram_used_mb": null, "capabilities": ["text", "vision"]}
    ]);
    let node_a = common::spawn_mock_neuron_with_models(text_only).await;
    let node_b = common::spawn_mock_neuron_with_models(text_and_vision).await;

    let neurons = vec![
        NeuronEndpoint {
            name: "node-a".into(),
            endpoint: node_a,
        },
        NeuronEndpoint {
            name: "node-b".into(),
            endpoint: node_b,
        },
    ];
    let config = GatewayConfig {
        gateway: GatewaySettings {
            listen: "127.0.0.1:0".into(),
            metrics_listen: "127.0.0.1:0".into(),
        },
        eviction: EvictionSettings {
            strategy: EvictionStrategy::Lru,
            defrag_after_cycles: 0,
        },
        neurons,
        models_config: "/dev/null".into(),
    };

    // One poll populates fleet state from both mocks.
    let fleet = Arc::new(CortexState::from_config(&config));
    cortex_gateway::poller::poll_once(&fleet).await;

    let router = cortex_gateway::build_app(Arc::clone(&fleet));
    let socket = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = socket.local_addr().unwrap();
    tokio::spawn(async move {
        axum::serve(socket, router).await.unwrap();
    });

    let resp = reqwest::Client::new()
        .get(format!("http://{addr}/v1/models"))
        .send()
        .await
        .expect("request should succeed");
    let body: serde_json::Value = resp.json().await.unwrap();

    let listed = body["data"].as_array().expect("data array");
    let model = listed
        .iter()
        .find(|m| m["id"] == "shared-model")
        .expect("shared-model should be present");

    let caps: Vec<&str> = model["capabilities"]
        .as_array()
        .expect("capabilities array")
        .iter()
        .filter_map(|c| c.as_str())
        .collect();
    assert!(caps.contains(&"text"), "union must include text: {caps:?}");
    assert!(
        caps.contains(&"vision"),
        "union must include vision: {caps:?}"
    );
    assert_eq!(caps.len(), 2, "union must not duplicate text: {caps:?}");

    // Both nodes hold the model, so two locations regardless of caps.
    let location_count = model["locations"].as_array().unwrap().len();
    assert_eq!(location_count, 2);
}
#[tokio::test]
async fn test_poller_marks_unreachable_node_unhealthy() {
let config = GatewayConfig {
@@ -233,8 +152,8 @@ async fn test_poller_marks_unreachable_node_unhealthy() {
#[tokio::test]
async fn test_poller_removes_stale_models() {
let mock_url = common::spawn_mock_neuron_with_models(json!([
{"id": "keep-me", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": null},
{"id": "drop-me", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": null}
{"id": "keep-me", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": null},
{"id": "drop-me", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": null}
]))
.await;
@@ -264,7 +183,7 @@ async fn test_poller_removes_stale_models() {
// New mock with only one model.
let new_mock_url = common::spawn_mock_neuron_with_models(json!([
{"id": "keep-me", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": null}
{"id": "keep-me", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": null}
]))
.await;
@@ -297,7 +216,6 @@ async fn test_poller_removes_stale_models() {
status: ModelStatus::Loaded,
last_accessed: None,
vram_estimate_mb: None,
capabilities: Vec::new(),
},
);
node.models.insert(
@@ -307,7 +225,6 @@ async fn test_poller_removes_stale_models() {
status: ModelStatus::Loaded,
last_accessed: None,
vram_estimate_mb: None,
capabilities: Vec::new(),
},
);
}
@@ -320,94 +237,3 @@ async fn test_poller_removes_stale_models() {
assert!(node.models.contains_key("keep-me"));
assert!(!node.models.contains_key("drop-me"));
}
/// The poller must capture the `activation` section of a neuron's `/health`
/// body even when `/models` reports no models at all.
#[tokio::test]
async fn test_poller_captures_activation_from_health() {
    // Mock neuron is mid-prewarm: /models reports nothing (the loading
    // model hasn't been inserted into the harness map yet), while /health
    // says model-x is in_progress and model-y is queued behind it.
    let no_models = json!([]);
    let prewarming_health = json!({
        "uptime_secs": 30,
        "devices": [],
        "activation": {
            "state": "pre_warming",
            "pending": ["Qwen/model-y"],
            "in_progress": "Qwen/model-x",
            "completed": [],
            "failed": []
        }
    });
    let mock_url =
        common::spawn_mock_neuron_with_models_and_health(no_models, prewarming_health).await;

    let neurons = vec![NeuronEndpoint {
        name: "prewarm-node".into(),
        endpoint: mock_url,
    }];
    let config = GatewayConfig {
        gateway: GatewaySettings {
            listen: "127.0.0.1:0".into(),
            metrics_listen: "127.0.0.1:0".into(),
        },
        eviction: EvictionSettings {
            strategy: EvictionStrategy::Lru,
            defrag_after_cycles: 0,
        },
        neurons,
        models_config: "/dev/null".into(),
    };
    let fleet = Arc::new(CortexState::from_config(&config));
    cortex_gateway::poller::poll_once(&fleet).await;

    let node_map = fleet.nodes.read().await;
    let prewarm_node = node_map.get("prewarm-node").unwrap();
    assert!(prewarm_node.healthy);
    // /models was empty — no entries in the per-node model map.
    assert!(prewarm_node.models.is_empty());

    // But /health's activation should be captured.
    let activation = prewarm_node
        .activation
        .as_ref()
        .expect("activation should be populated after /health poll");
    assert_eq!(activation.in_progress.as_deref(), Some("Qwen/model-x"));
    assert_eq!(activation.pending, vec!["Qwen/model-y".to_string()]);
}
/// #20: a neuron reporting a model with status "recovering" must be stored
/// in gateway state as `ModelStatus::Recovering`.
#[tokio::test]
async fn test_poller_parses_recovering_status() {
    // A model auto-recovering on a neuron (poisoned → unload → reload, #17)
    // is reported as "recovering". It must map to the dedicated status and
    // not fall through the parser's catch-all to Loaded.
    let recovering_models = json!([
        {"id": "model-r", "harness": "candle", "status": "recovering", "devices": [0, 1], "vram_used_mb": null}
    ]);
    let mock_url = common::spawn_mock_neuron_with_models(recovering_models).await;

    let neurons = vec![NeuronEndpoint {
        name: "test-node".into(),
        endpoint: mock_url,
    }];
    let config = GatewayConfig {
        gateway: GatewaySettings {
            listen: "127.0.0.1:0".into(),
            metrics_listen: "127.0.0.1:0".into(),
        },
        eviction: EvictionSettings {
            strategy: EvictionStrategy::Lru,
            defrag_after_cycles: 0,
        },
        neurons,
        models_config: "/dev/null".into(),
    };
    let fleet = Arc::new(CortexState::from_config(&config));
    cortex_gateway::poller::poll_once(&fleet).await;

    let node_map = fleet.nodes.read().await;
    let test_node = node_map.get("test-node").unwrap();
    let model_r = test_node
        .models
        .get("model-r")
        .expect("model-r should exist");
    assert_eq!(model_r.status, ModelStatus::Recovering);
}

View File

@@ -171,64 +171,3 @@ async fn test_missing_model_field() {
let body: serde_json::Value = resp.json().await.unwrap();
assert!(body["error"]["message"].as_str().unwrap().contains("model"));
}
/// #20: while a model auto-recovers on a neuron, the gateway must answer
/// inference requests with a transient 503 ("retry shortly") and keep the
/// model listed on `/v1/models`. A 404 "not found on any node" would make a
/// recovering model look evicted.
#[tokio::test]
async fn test_recovering_model_returns_503_and_stays_listed() {
    let mock_url = common::spawn_mock_neuron().await;
    let (fleet, gw_url) = common::spawn_gateway_with_state(&mock_url).await;

    // Inject a model in the Recovering state straight into fleet state.
    let recovering_entry = cortex_core::node::ModelEntry {
        id: "recovering-model".into(),
        status: cortex_core::node::ModelStatus::Recovering,
        last_accessed: None,
        vram_estimate_mb: Some(8000),
        capabilities: Vec::new(),
    };
    {
        let mut node_map = fleet.nodes.write().await;
        let mock_node = node_map.get_mut("mock-node").expect("node must exist");
        mock_node
            .models
            .insert("recovering-model".into(), recovering_entry);
    }

    let client = reqwest::Client::new();
    let request = json!({
        "model": "recovering-model",
        "messages": [{"role": "user", "content": "Hi"}]
    });
    let resp = client
        .post(format!("{gw_url}/v1/chat/completions"))
        .header("content-type", "application/json")
        .json(&request)
        .send()
        .await
        .expect("request should succeed");
    assert_eq!(resp.status(), 503);

    let error_body: serde_json::Value = resp.json().await.unwrap();
    let message = error_body["error"]["message"].as_str().unwrap();
    let explains_retry = message.contains("recovering") && message.contains("retry");
    assert!(
        explains_retry,
        "503 body must say recovering/retry, got: {message}"
    );

    // The model must still be visible on the unified models endpoint.
    let models_resp = client
        .get(format!("{gw_url}/v1/models"))
        .send()
        .await
        .expect("models request should succeed");
    let models: serde_json::Value = models_resp.json().await.unwrap();
    let listed = models["data"]
        .as_array()
        .unwrap()
        .iter()
        .any(|m| m["id"] == "recovering-model");
    assert!(listed, "recovering model must stay listed on /v1/models");
}

View File

@@ -1,91 +0,0 @@
//! Integration tests for the `/v1/responses` proxy route.
//!
//! The gateway forwards the request body to whichever neuron has the
//! model loaded. These tests exercise the routing decision (200 on a
//! known model, 404 on an unknown model, 400 on a missing model
//! field) and confirm the response body round-trips verbatim.
mod common;
use serde_json::json;
/// Happy path: a `/v1/responses` request for a loaded model is routed to
/// the neuron holding it, and the neuron's response body reaches the
/// client unchanged.
#[tokio::test]
async fn test_responses_proxy() {
    let mock_url = common::spawn_mock_neuron().await;
    let gw_url = common::spawn_gateway(&mock_url).await;

    let request = json!({
        "model": "test-model",
        "input": "Hi"
    });
    let resp = reqwest::Client::new()
        .post(format!("{gw_url}/v1/responses"))
        .header("content-type", "application/json")
        .json(&request)
        .send()
        .await
        .expect("request should succeed");
    assert_eq!(resp.status(), 200);

    let body: serde_json::Value = resp.json().await.expect("valid JSON response");
    assert_eq!(body["id"], "resp-test-001");
    assert_eq!(body["object"], "response");
    assert_eq!(body["model"], "test-model");
    assert_eq!(body["status"], "completed");

    let first_text = &body["output"][0]["content"][0]["text"];
    assert_eq!(*first_text, "Hello from mock backend");

    // Usage shape is the Responses-specific one (input/output_tokens),
    // not the chat-completions one (prompt/completion_tokens). This
    // confirms the proxy didn't route through the wrong handler.
    let usage = &body["usage"];
    assert_eq!(usage["total_tokens"], 10);
    assert!(usage.get("input_tokens").is_some());
}
/// Requesting a model that is not in the catalogue yields 404 from the
/// router. The chat-completions handler behaves the same way (same error
/// path, same status code), so clients can share retry logic across both
/// routes.
#[tokio::test]
async fn test_responses_model_not_found() {
    let mock_url = common::spawn_mock_neuron().await;
    let gw_url = common::spawn_gateway(&mock_url).await;

    let request = json!({
        "model": "not-in-catalogue",
        "input": "Hi"
    });
    let resp = reqwest::Client::new()
        .post(format!("{gw_url}/v1/responses"))
        .json(&request)
        .send()
        .await
        .unwrap();

    assert_eq!(resp.status(), 404);
}
/// Without a `model` field the request cannot be routed, so the gateway
/// answers 400 before any backend is contacted. This matches the
/// chat-completions handler, which extracts the model through the same
/// `extract_model` helper.
#[tokio::test]
async fn test_responses_missing_model_field() {
    let mock_url = common::spawn_mock_neuron().await;
    let gw_url = common::spawn_gateway(&mock_url).await;

    let request = json!({
        "input": "Hi"
    });
    let resp = reqwest::Client::new()
        .post(format!("{gw_url}/v1/responses"))
        .json(&request)
        .send()
        .await
        .unwrap();

    assert_eq!(resp.status(), 400);
}

View File

@@ -51,18 +51,18 @@ async fn test_streaming_sse_passthrough() {
}
assert!(
chunks.len() > chunk_count,
"expected more than {} chunks (got {}): {:?}",
chunk_count,
chunks.len() >= chunk_count + 1,
"expected at least {} chunks (got {}): {:?}",
chunk_count + 1,
chunks.len(),
chunks,
);
assert_eq!(chunks.last().unwrap(), "[DONE]");
for (i, chunk) in chunks.iter().enumerate().take(chunk_count) {
for i in 0..chunk_count {
let chunk_json: serde_json::Value =
serde_json::from_str(chunk).expect("chunk should be valid JSON");
serde_json::from_str(&chunks[i]).expect("chunk should be valid JSON");
assert_eq!(
chunk_json["choices"][0]["delta"]["content"],
format!("token{i}")

View File

@@ -1,48 +0,0 @@
[package]
name = "helexa-acp"
version = "0.1.16"
edition = "2024"
license = "Apache-2.0"
repository = "https://git.lair.cafe/helexa/helexa"
description = """
Agent Client Protocol bridge for the helexa self-hosted LLM stack.
Speaks ACP to ACP-compatible editor clients (Zed, etc.) and forwards
the conversation to any OpenAI-compatible HTTP endpoint — defaulting
to cortex (helexa's reverse-proxy / fleet gateway).
"""
# This crate is intentionally self-contained — no dependencies on other
# workspace crates (cortex-core, cortex-gateway, neuron). The goal is
# a painless migration to a dedicated GitHub repo in the future if the
# project grows beyond helexa's needs. All deps are crates.io.
[dependencies]
# `unstable_session_model` flips on the SessionModelState type and the
# session/set_model RPC the model-picker dropdown in Zed needs. The
# feature is upstream-marked unstable; we accept that risk because the
# model picker is core UX and the alternative (rolling our own
# extension method) drifts further from spec each time it moves.
agent-client-protocol = { version = "0.12", features = ["unstable_session_model"] }
tokio = { version = "1", features = ["rt-multi-thread", "macros", "sync", "io-util", "process", "signal"] }
reqwest = { version = "0.12", features = ["json", "stream", "rustls-tls"], default-features = false }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
toml = "0.8"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
anyhow = "1"
thiserror = "2"
async-trait = "0.1"
futures = "0.3"
tokio-stream = "0.1"
tokio-util = { version = "0.7", features = ["rt"] }
eventsource-stream = "0.2"
async-stream = "0.3"
url = { version = "2", features = ["serde"] }
# Already transitively pulled via the ACP SDK; declared directly so we
# can format ISO 8601 timestamps for `SessionInfo.updated_at` in the
# session/list response.
chrono = { version = "0.4", default-features = false, features = ["std"] }
[[bin]]
name = "helexa-acp"
path = "src/main.rs"

View File

@@ -1,546 +0,0 @@
# helexa-acp
ACP (Agent Client Protocol) bridge for editors like
[Zed](https://zed.dev). Lets you point your editor's agent panel at
**any combination** of OpenAI-compatible, OpenAI Responses, and
Anthropic Messages endpoints — public APIs, private LAN deployments,
local Ollama / LM Studio — and switch between them per session via a
model dropdown.
The "missing ACP binary" for users who don't want to be locked into
one vendor's agent client.
```
┌───────────────────────────────────┐
│ Zed (or any ACP editor client) │
└────────────┬──────────────────────┘
│ stdio JSON-RPC (ACP)
┌─────────────────┐
│ helexa-acp │ ← one binary, multi-endpoint
└─────┬───────────┘
│ HTTP / SSE
┌────────┼─────────────┬──────────────┬──────────────┐
▼ ▼ ▼ ▼ ▼
cortex/ OpenAI Anthropic OpenRouter LM Studio
neuron Responses Messages
(self- (gpt-5,…) (Claude)
hosted)
```
## What it does
- **Speaks ACP** over stdio to editor clients (Zed today; any future
ACP client tomorrow).
- **Multi-endpoint** — one config file lists every LLM endpoint
you want available; pick one per session via the model dropdown
(`endpoint:model` selector).
- **Three wire formats**: `openai-chat` (the broadly compatible
default), `openai-responses` (newer OpenAI surface), and
`anthropic-messages` (Claude). Each is a separate provider impl
in `src/provider/`; adding a fourth (Gemini, Ollama native, …) is
one file plus a `WireApi` enum variant.
- **Built-in tools**: `read_file`, `write_file`, `edit_file`,
`list_dir`, `bash`. Permission-gated by default; the editor user
approves writes/shell per-call.
- **Three session modes**: Default (gated), Bypass Permissions
(auto-allow), and Plan (write-only-to-plan-dir, no shell).
- **Vision** — drag-drop images into the agent panel against any
vision-capable model.
- **Session resume** — multi-day conversations survive editor
restarts via on-disk transcript persistence.
- **Context compaction** — rolling history stays inside the model's
context window automatically so long sessions on small-context
local models don't fall over.
## Install
### From source
```sh
git clone https://git.lair.cafe/helexa/helexa.git
cd helexa
cargo install --path crates/helexa-acp
# Binary lands at ~/.cargo/bin/helexa-acp
```
### Pre-built RPM (Fedora 43)
```sh
dnf copr enable helexa/helexa
dnf install helexa-acp
```
The COPR project bundles helexa-acp alongside the cortex gateway
and helexa-neuron flavours; install only the package(s) you need.
## Quick start
The fastest path: env-var single-endpoint config.
```sh
export HELEXA_ACP_BASE_URL=http://hanzalova.internal:31313/v1
export HELEXA_ACP_MODEL=Qwen/Qwen3.6-27B
helexa-acp # speaks ACP over stdin/stdout; not interactive
```
Then in Zed (`~/.config/zed/settings.json`):
```jsonc
{
"agent_servers": {
"helexa": {
"command": "helexa-acp",
"args": []
}
}
}
```
Restart Zed → open the agent panel → pick "helexa" → start
chatting. Tool calls (file reads, writes, bash) prompt for
permission per-call in Default mode.
That's the minimum. The full config story below is what unlocks
the multi-endpoint dropdown.
## Multi-endpoint config
Copy `helexa-acp.example.toml` from this repo to
`$XDG_CONFIG_HOME/helexa-acp/config.toml` (typically
`~/.config/helexa-acp/config.toml`) and edit:
```toml
default_endpoint = "helexa"
[[endpoints]]
name = "helexa"
base_url = "http://hanzalova.internal:31313/v1"
wire_api = "openai-chat"
default_model = "Qwen/Qwen3.6-27B"
max_tokens = 8192
context_window = 32768
[[endpoints]]
name = "openrouter"
base_url = "https://openrouter.ai/api/v1"
wire_api = "openai-chat"
api_key_env = "OPENROUTER_API_KEY"
default_model = "anthropic/claude-opus-4"
[[endpoints]]
name = "anthropic"
base_url = "https://api.anthropic.com/v1"
wire_api = "anthropic-messages"
api_key_env = "ANTHROPIC_API_KEY"
default_model = "claude-opus-4"
```
Restart Zed. The model dropdown lists every model from every
configured endpoint with the `endpoint:model` selector
(`helexa:Qwen/Qwen3.6-27B`, `openrouter:anthropic/claude-opus-4`,
…). Switch mid-session; the next prompt routes to the new endpoint.
When only one endpoint is configured the prefix is dropped (model
ids appear bare).
### Selector syntax
The `model` field on every internal request is parsed as
`<endpoint>:<model>`:
- `openrouter:gpt-4o` → routes to the `openrouter` endpoint,
model `gpt-4o`.
- `helexa/large` → no colon → falls through to whichever endpoint
is named in `default_endpoint`, model `helexa/large`.
- `:gpt-5` → leading colon → also falls through to default.
## Endpoint cookbook
Copy-pasteable blocks. Mix and match.
### cortex / neuron (self-hosted)
```toml
[[endpoints]]
name = "helexa"
base_url = "http://hanzalova.internal:31313/v1"
wire_api = "openai-chat"
default_model = "Qwen/Qwen3.6-27B"
max_tokens = 8192
context_window = 32768
```
Use `openai-responses` instead of `openai-chat` once cortex 0.1.16+
is deployed and you want the Responses API surface (vision item
shape, structured reasoning items, etc.).
### OpenAI directly
```toml
[[endpoints]]
name = "openai"
base_url = "https://api.openai.com/v1"
wire_api = "openai-responses"
api_key_env = "OPENAI_API_KEY"
default_model = "gpt-5"
```
`openai-responses` is the right choice for current OpenAI models;
`openai-chat` works against legacy GPT-3.5/4 deployments and
anything labelled "chat completions".
### Anthropic directly
```toml
[[endpoints]]
name = "anthropic"
base_url = "https://api.anthropic.com/v1"
wire_api = "anthropic-messages"
api_key_env = "ANTHROPIC_API_KEY"
default_model = "claude-opus-4"
```
helexa-acp sends `x-api-key` + `anthropic-version: 2023-06-01`
automatically. The `api_key_env` indirection keeps your key out of
the config file.
### OpenRouter (multi-vendor proxy)
```toml
[[endpoints]]
name = "openrouter"
base_url = "https://openrouter.ai/api/v1"
wire_api = "openai-chat"
api_key_env = "OPENROUTER_API_KEY"
default_model = "anthropic/claude-opus-4"
```
OpenRouter speaks OpenAI-compat for every model it fronts, so
`openai-chat` is the right wire format regardless of the
underlying vendor.
### LM Studio (local)
```toml
[[endpoints]]
name = "lmstudio"
base_url = "http://localhost:1234/v1"
wire_api = "openai-chat"
default_model = "auto"
```
LM Studio's "auto" model id picks whatever's loaded. Same shape
works for Ollama in compat mode (`http://localhost:11434/v1`) and
vLLM.
### Multiple cortex deployments
```toml
[[endpoints]]
name = "lan"
base_url = "http://hanzalova.internal:31313/v1"
wire_api = "openai-chat"
default_model = "Qwen/Qwen3.6-27B"
[[endpoints]]
name = "cloud"
base_url = "https://cortex.example.com/v1"
wire_api = "openai-chat"
api_key_env = "CLOUD_CORTEX_KEY"
default_model = "Qwen/Qwen3-VL-8B"
```
Use the `endpoint:model` selector to switch between them mid-session.
## Zed setup
`~/.config/zed/settings.json`:
```jsonc
{
"agent_servers": {
"helexa": {
"command": "helexa-acp"
}
}
}
```
Optional environment overrides for the binary:
```jsonc
{
"agent_servers": {
"helexa": {
"command": "helexa-acp",
"env": {
"HELEXA_ACP_LOG_FILE": "/tmp/helexa-acp.log",
"RUST_LOG": "helexa_acp=debug"
}
}
}
}
```
`HELEXA_ACP_LOG_FILE` is the one you actually want — Zed doesn't
surface the agent's stderr, so without that env var debug output is
invisible. Point it at a file you can `tail -f`.
After restarting Zed: ⌘+? (or wherever your "Open Agent Panel"
binding is) → select "helexa" → the model dropdown populates from
your config → start prompting.
## Modes
Three session modes ship; the user picks via Zed's mode dropdown
on the agent panel.
| Mode | Reads | Writes | Bash | Permission prompts |
|------|-------|--------|------|--------------------|
| **Default** | ✓ | with prompt | with prompt | per call |
| **Bypass Permissions** | ✓ | ✓ | ✓ | never |
| **Plan** | ✓ | only into plan dir | disabled | never (plan-dir writes auto-allow) |
### Default
Reads are always allowed (`read_file`, `list_dir` are
unrestricted). Writes and shell commands prompt the user before
running. The intended baseline for any session where the agent
might do something you'd rather review first.
### Bypass Permissions
Auto-allow every tool call. Use for agentic loops you trust — bulk
edits across many files, scripted workflows, prepared session
templates. Never for code the agent hasn't seen before.
### Plan
The "draft an implementation plan before you write code" mode.
Available tools:
- `read_file`, `list_dir`: unrestricted (read the codebase).
- `write_file`, `edit_file`: allowed *only* under
`$XDG_DATA_HOME/helexa-acp/plans/<project-id>/`. Any path
outside that returns "plan mode: writes are restricted to …"
back to the model so it self-corrects.
- `bash`: disabled outright. Returns "plan mode: shell execution
is disabled" if attempted.
When the plan is complete, the model presents a 3-option menu:
1. **Bypass Permissions** — implement the plan now, no prompts.
2. **Default** — implement now with per-tool prompts.
3. **Plan** (stay here) — refine the plan with more guidance.
Switch the mode dropdown to your preference and reply to proceed.
## Tools
Five tools, defined in `src/tools.rs`:
| Tool | Args | Gated in Default? |
|------|------|-------------------|
| `read_file` | `path`, `line?`, `limit?` | no |
| `list_dir` | `path` | no |
| `write_file` | `path`, `content` | yes |
| `edit_file` | `path`, `old_text`, `new_text` | yes |
| `bash` | `command`, `cwd?` | yes |
### Path handling
`~`, `~/`, `$HOME`, and `$HOME/` are expanded server-side before
the path reaches ACP or local fs. Lets the model emit
`~/git/repo/file.rs` and have it Just Work.
`read_file` first tries the editor's filesystem (ACP's
`fs/read_text_file` — respects open buffers, workspace overlays,
etc.). If that fails — typically because the path is outside Zed's
workspace boundary — it falls back to `std::fs::read_to_string`.
This lets the agent pull in shared material like
`~/git/architecture/generic.md` from a different project's
session.
The fallback is logged at warn level so you can see when it kicks
in.
### Tool dispatch
Tool descriptions reach the model through a Qwen3 Hermes-format
`# Tools` block injected into the system prompt — cortex/neuron
pass the OpenAI `tools` request field through to the encoder
unread, so we coax the model into emitting `<tool_call>{json}</tool_call>`
markers that helexa-acp then parses out of the content stream. This applies to
the helexa wire format; OpenAI / Anthropic endpoints with native
tool support would use their own paths once they're wired in.
The parser is tolerant: malformed JSON (trailing braces, missing
`name`, name nested in `arguments`) gets a repair pass; if that
fails the call surfaces as a "Malformed tool call" card in Zed and
the model gets a synthetic error result so it can self-correct.
## Session resume
helexa-acp persists every session to
`$XDG_DATA_HOME/helexa-acp/sessions/<id>.json`. Zed's `session/list`
RPC asks helexa-acp to enumerate them on workspace open;
`session/load` rehydrates and replays the transcript as
`session/update` notifications so the agent panel renders the
prior conversation.
Behaviour:
- Persisted per-round, so a mid-turn agent stall (long bash, wedged
ACP roundtrip) doesn't lose earlier rounds.
- Survives editor restart and the helexa-acp binary upgrading
between versions.
- Project-scoped: only sessions whose `cwd` matches the workspace
are listed.
To wipe history: `rm -rf $XDG_DATA_HOME/helexa-acp/sessions/`.
## Context compaction
When an endpoint sets `context_window`, helexa-acp projects the
rolling history into a token budget before each request — old
`ToolResult` content (read_file payloads are the worst offenders)
gets elided to one-line markers, preserving `tool_call_id` pairing
so the wire schema stays valid.
System prompts, user turns, and the most recent ~4 messages are
never elided. The full history stays on disk; compaction is a
per-request projection, not a destructive edit.
Set `context_window = 32768` for a 32 K Qwen3, `131072` for a
modern Claude, etc. With `max_tokens` also set, the budget is
`context_window - max_tokens - 512_safety`.
## Troubleshooting
### "default endpoint 'helexa' has no usable provider — check config"
The named default endpoint failed to construct. Usually:
- `api_key_env` references a variable that isn't set in the env
Zed launched helexa-acp with.
- The TOML's `wire_api` is misspelled (only `openai-chat`,
`openai-responses`, `anthropic-messages` are accepted).
Test by running `helexa-acp` directly from a shell — startup
errors land on stderr.
### Model dropdown is empty
Each provider's `list_models` failed at startup. Look at
`HELEXA_ACP_LOG_FILE` for "list_models failed; this endpoint's
models won't appear in the picker". Likely the endpoint URL is
wrong, the API key is invalid, or the upstream `/v1/models`
endpoint isn't responding.
The agent still works against `default_model` even when the
dropdown is empty — list-models is for picking, not routing.
### "prompt_too_long" / agent stalls mid-conversation
You hit the model's context window. Set `context_window` on the
endpoint and helexa-acp will compact before sending. The log line
`context compaction applied` confirms it's running; if it fires
but the upstream still rejects, the compaction heuristic
under-counted and the budget needs tuning down.
### Reading files outside the workspace returns "not found"
Zed's `fs/read_text_file` is workspace-scoped. helexa-acp falls
back to local `std::fs` automatically when that fails — look for
`fs/read_text_file failed; falling back to local std::fs` in the
log. If even local read fails, the file genuinely doesn't exist
or the user process lacks permissions.
### Tool calls render as text instead of structured cards
The model is emitting `<tool_call>` markers that the parser can't
decode. Two common causes:
1. The system prompt isn't reaching the model (cortex/neuron's
tool-block injection didn't fire). Confirm with
`RUST_LOG=helexa_acp=debug` and look at the outgoing
`POST /chat/completions` body.
2. The model itself is too small / undertrained to follow the
Hermes format reliably. helexa-acp has shape-based name
inference and JSON repair, but there's a floor below which
nothing helps.
### Plan-mode writes refused even inside the plan dir
The path comparison is byte-for-byte. If the model emits a path
with `~` and the plan_dir has the expanded form, expansion runs
*before* the comparison — but resolved-vs-symlinked-path
mismatches can still bite. The error message names the attempted
path and the expected prefix so you can compare directly.
## Architecture
Source layout under `crates/helexa-acp/src/`:
| File | Responsibility |
|------|----------------|
| `main.rs` | tokio + Stdio transport. Builds providers, hands off to `agent::Agent` |
| `config.rs` | TOML + env-fallback config, endpoint resolver |
| `agent.rs` | ACP handlers (initialize, session/new, session/prompt, session/cancel, session/set_mode, session/set_model, session/load, session/list), prompt loop with tool-call recursion |
| `session.rs` | Per-session state map (Arc<RwLock<HashMap<…>>>) |
| `store.rs` | On-disk session persistence, plan-dir resolution |
| `prompt.rs` | System-prompt assembly, plan-mode addendum |
| `tools.rs` | Tool schemas + shape-based name inference |
| `tool_runner.rs` | Dispatch a single tool call through ACP client RPCs; permission gate |
| `qwen3.rs` | Qwen3 Hermes tool-format parser (`<tool_call>` / `<think>` markers) |
| `compaction.rs` | Token-budget compaction for the rolling history |
| `path_util.rs` | `~` / `$HOME` expansion shared across every path-taking tool |
| `provider/openai_chat.rs` | OpenAI chat completions provider |
| `provider/openai_responses.rs` | OpenAI Responses API provider |
| `provider/anthropic_messages.rs` | Anthropic Messages API provider |
### Adding a new wire format
1. New file under `src/provider/` implementing the `Provider`
trait (encoder + SSE decoder).
2. Add a `WireApi` variant in `config.rs`.
3. Wire it into `build_provider` in `main.rs`.
4. Done — every other module is wire-format-agnostic.
### Concurrency
- `Arc<RwLock<HashMap<SessionId, Arc<Mutex<SessionState>>>>>`:
each session has its own mutex, so concurrent requests across
sessions don't contend; the map's RwLock is read-mostly.
- Every tool call dispatched serially within a session (parallel
dispatch would require Zed to handle interleaved permission
prompts).
- Provider streams are back-pressured by the consumer (bounded
mpsc channels).
### Self-contained
The crate has no workspace-internal dependencies (no
`cortex-core`, no `cortex-gateway`). Migration to a dedicated
GitHub repo for cross-platform CI / cargo-dist binaries is
Cargo.toml-only.
## Status
- Stages 1–6 shipped: scaffold, agent loop, tools, modes, session
resume, image input, model picker, three wire formats.
- Stage 8 (RPM + multi-platform CI) tracked in the canonical plan;
Linux x86_64 RPM ships today via the cortex monorepo's Gitea
Actions.
## Contributing
Repository: https://git.lair.cafe/helexa/helexa (`crates/helexa-acp/`).
Issues / PRs welcome. The canonical staged plan is in
`~/.claude/plans/plan-the-per-device-worker-abstract-micali.md` on
the maintainer's machine; the substages 3a–3e and 6a/6b that the
canonical plan didn't anticipate are documented in commit messages.
CI: `cargo fmt --check --all`, `cargo clippy --workspace -- -D
warnings`, `cargo test --workspace` must all pass before merge.

File diff suppressed because it is too large Load Diff

View File

@@ -1,425 +0,0 @@
//! Rolling-conversation compaction for small-context local models.
//!
//! The tool-call loop in [`crate::agent`] grows the message vec it
//! sends upstream every round. On a frontier model that's fine; on a
//! 32 K Qwen3 the first few `read_file` results can push the prompt
//! past the model's context window, at which point cortex/neuron
//! refuses with `prompt_too_long` and the whole turn dies. Long-form
//! local agents are unusable without something here.
//!
//! Strategy (intentionally simple — no LLM-summarization round-trip,
//! no tokenizer dependency):
//!
//! 1. **Protect** the things the model cannot reason without:
//! - The system prompt (idx 0).
//! - Every `Role::User` turn (the user's intent — irreplaceable).
//! - The last [`KEEP_TAIL`] messages (most recent rounds stay
//! verbatim so the model can keep working on what it just
//! observed).
//! 2. **Elide** older `Role::Assistant` prose and older `Role::Tool`
//! result content. The structure stays — `tool_call_id`s, tool
//! names, and argument JSON survive intact — so OpenAI's strict
//! `tool_calls` ↔ `tool` pairing schema remains satisfied. Only
//! the *payload* shrinks to a one-line marker.
//! 3. Walk oldest→newest, recomputing the budget after each elision.
//! Stop as soon as we fit; we don't compact more than necessary.
//! 4. If we still exceed budget after eliding everything we're
//! allowed to, return what we have. The upstream will surface a
//! `prompt_too_long` error and the user can intervene; that's
//! better than silently dropping content the model needs.
//!
//! Token estimation uses a `chars / 3.5` heuristic — conservative
//! (over-estimates tokens slightly) so we compact a touch early
//! rather than a touch late.
use crate::provider::{Message, MessageContent, MessagePart, Role};
/// Most-recent N messages that are never elided. Roughly "the
/// current tool round in flight" — assistant turn that called the
/// tools + each tool result + a bit of slack.
const KEEP_TAIL: usize = 4;
/// Below this content size we don't bother eliding — the savings
/// don't outweigh the loss of detail. Roughly 60–80 tokens.
/// Compared against `String::len`, i.e. a UTF-8 byte count.
const ELIDE_MIN_CHARS: usize = 256;
/// Roughly tokens-per-character for English + code mixed in. The
/// actual per-tokenizer ratio varies (GPT-4o ≈ 4 chars/token on
/// English prose, ≈ 3 chars/token on code-heavy text). We pick a
/// value on the conservative end so the budget check fires *before*
/// the upstream tokenizer says no.
const CHARS_PER_TOKEN: f32 = 3.5;
/// Per-message envelope overhead (role + JSON framing). Comes out
/// to a few tokens; tiny but it adds up across long histories.
const ENVELOPE_TOKENS: usize = 8;
/// Rough per-image token cost used by the budget estimator. Real
/// vision tokenizers vary widely (256–1024 tokens for typical
/// resolutions on Qwen3-VL, OpenAI's `low`/`high` detail toggles
/// pick between ~85 and ~1000+). 512 is a defensible middle that
/// keeps compaction from treating images as free.
const IMAGE_TOKENS_APPROX: usize = 512;
/// Outcome summary produced by [`compact_to_budget`], intended for
/// the caller's logging. Every figure is derived from
/// [`estimate_tokens`], so treat the numbers as estimates — they are
/// not comparable to exact upstream-reported token counts.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CompactionStats {
    /// Estimated token total of the messages as they were passed in.
    pub original_tokens: usize,
    /// Estimated token total once compaction finished. Identical to
    /// `original_tokens` when nothing had to be compacted.
    pub final_tokens: usize,
    /// How many messages had their content elided. Zero on the hot
    /// path, where the input already fits.
    pub elided_messages: usize,
}

impl CompactionStats {
    /// Stats for an input that needed no compaction: both token
    /// counts equal `tokens` and no message was elided.
    fn unchanged(tokens: usize) -> Self {
        Self {
            original_tokens: tokens,
            final_tokens: tokens,
            ..Self::default()
        }
    }
}
/// Approximate token count for one message.
///
/// Sums the UTF-8 byte length (`str::len`) of the message's textual
/// payload, divides by [`CHARS_PER_TOKEN`], and adds
/// [`ENVELOPE_TOKENS`] for the role / JSON framing. Counting bytes
/// rather than chars only ever pushes the estimate upwards, which is
/// the conservative direction for a budget check.
///
/// Each image part contributes exactly [`IMAGE_TOKENS_APPROX`]
/// tokens regardless of its size.
///
/// Cheap (no allocation), so safe to call once per message per round.
pub fn estimate_tokens(msg: &Message) -> usize {
    // Character-equivalent weight of one image, chosen so the final
    // division by CHARS_PER_TOKEN lands on IMAGE_TOKENS_APPROX. The
    // multiplication has to happen in f32: `as` binds tighter than
    // `*`, so `IMAGE_TOKENS_APPROX * CHARS_PER_TOKEN as usize`
    // truncates 3.5 to 3 first and under-counts every image
    // (≈438 tokens instead of 512).
    let image_chars = (IMAGE_TOKENS_APPROX as f32 * CHARS_PER_TOKEN) as usize;

    let chars: usize = match &msg.content {
        MessageContent::Text { text } => text.len(),
        MessageContent::MultiPart { parts } => parts
            .iter()
            .map(|part| match part {
                MessagePart::Text { text } => text.len(),
                // The upstream tokenizer decides the real cost (and
                // it varies wildly by model); a fixed middle
                // estimate keeps the budget tracker from pretending
                // images are free.
                MessagePart::Image(_) => image_chars,
            })
            .sum::<usize>(),
        MessageContent::ToolCalls { text, calls } => {
            let prose = text.as_deref().map_or(0, str::len);
            // Ids, names and argument JSON all travel on the wire,
            // so they all count towards the prompt.
            let calls_size: usize = calls
                .iter()
                .map(|c| c.name.len() + c.arguments.len() + c.id.len())
                .sum();
            prose + calls_size
        }
        MessageContent::ToolResult {
            tool_call_id,
            content,
        } => tool_call_id.len() + content.len(),
    };
    ((chars as f32 / CHARS_PER_TOKEN) as usize) + ENVELOPE_TOKENS
}
/// Sum of [`estimate_tokens`] across all messages.
pub fn total_tokens(messages: &[Message]) -> usize {
messages.iter().map(estimate_tokens).sum()
}
/// Project `messages` into a vec whose estimated token count fits in
/// `budget` tokens. Returns the projection plus stats about what
/// was done. When the input already fits, the projection is a clone
/// of the input and stats report zero elisions.
///
/// See module docs for the strategy and protected set.
pub fn compact_to_budget(messages: &[Message], budget: usize) -> (Vec<Message>, CompactionStats) {
let original = total_tokens(messages);
if original <= budget {
return (messages.to_vec(), CompactionStats::unchanged(original));
}
let mut out = messages.to_vec();
let len = out.len();
let tail_start = len.saturating_sub(KEEP_TAIL);
let mut elided = 0usize;
// Two passes. First pass: ToolResult contents (largest savings
// per elision — read_file payloads land here). Second pass: long
// Assistant prose. We don't interleave because eliding a long
// assistant turn before a really old read_file would do less
// good per elision; oldest-first ordering is enforced *within*
// each pass instead.
for pass in 0..2 {
for i in 1..tail_start {
if matches!(out[i].role, Role::User) {
continue;
}
let target_pass_2 = matches!(
&out[i].content,
MessageContent::Text { .. } | MessageContent::ToolCalls { .. }
);
let target_pass_1 = matches!(&out[i].content, MessageContent::ToolResult { .. });
let in_pass = (pass == 0 && target_pass_1) || (pass == 1 && target_pass_2);
if !in_pass {
continue;
}
if elide_in_place(&mut out[i]) {
elided += 1;
if total_tokens(&out) <= budget {
let final_tokens = total_tokens(&out);
return (
out,
CompactionStats {
original_tokens: original,
final_tokens,
elided_messages: elided,
},
);
}
}
}
}
let final_tokens = total_tokens(&out);
(
out,
CompactionStats {
original_tokens: original,
final_tokens,
elided_messages: elided,
},
)
}
/// Replace one message's payload with a short marker while leaving
/// its structure (role, `tool_call_id`, tool calls) alone, so
/// call/result pairing survives. Returns `true` only if the message
/// was modified.
///
/// - `ToolResult.content` → `(elided: N bytes of tool result)`
/// - `ToolCalls.text` → `(elided: N bytes of assistant prose)`
/// - `Text` → `(elided: N bytes of assistant prose)`
///
/// `N` is the payload's original byte length. Payloads shorter than
/// [`ELIDE_MIN_CHARS`] bytes are left alone — swapping a 50-byte
/// string for the marker would make it *longer*.
fn elide_in_place(msg: &mut Message) -> bool {
    match &mut msg.content {
        MessageContent::ToolResult { content, .. } if content.len() >= ELIDE_MIN_CHARS => {
            let marker = format!("(elided: {} bytes of tool result)", content.len());
            *content = marker;
            true
        }
        MessageContent::ToolCalls {
            text: Some(prose), ..
        } if prose.len() >= ELIDE_MIN_CHARS => {
            let marker = format!("(elided: {} bytes of assistant prose)", prose.len());
            *prose = marker;
            true
        }
        MessageContent::Text { text } if text.len() >= ELIDE_MIN_CHARS => {
            let marker = format!("(elided: {} bytes of assistant prose)", text.len());
            *text = marker;
            true
        }
        // Everything that reaches this arm is left untouched: payloads
        // under the size threshold, tool-call turns with no prose, and
        // MultiPart messages. MultiPart currently only shows up on
        // User turns, which `compact_to_budget` already protects, so
        // declining here simply keeps that path benign should another
        // role ever start carrying MultiPart content.
        MessageContent::ToolResult { .. }
        | MessageContent::ToolCalls { .. }
        | MessageContent::Text { .. }
        | MessageContent::MultiPart { .. } => false,
    }
}
// Unit tests for the compaction strategy: protected-set handling,
// pass ordering, tool_call_id pairing, and best-effort behaviour
// when the budget cannot be met.
#[cfg(test)]
mod tests {
use super::*;
use crate::provider::ToolCall;
// --- Fixture builders: one per message shape used below. ---
// System message carrying plain text.
fn sys(text: &str) -> Message {
Message {
role: Role::System,
content: MessageContent::Text { text: text.into() },
}
}
// User turn carrying plain text.
fn user(text: &str) -> Message {
Message {
role: Role::User,
content: MessageContent::Text { text: text.into() },
}
}
// Assistant turn that is prose only (no tool calls).
fn assistant_text(text: &str) -> Message {
Message {
role: Role::Assistant,
content: MessageContent::Text { text: text.into() },
}
}
// Assistant turn issuing exactly one tool call, with optional
// accompanying prose.
fn assistant_calls(text: Option<&str>, name: &str, args: &str, id: &str) -> Message {
Message {
role: Role::Assistant,
content: MessageContent::ToolCalls {
text: text.map(|s| s.to_string()),
calls: vec![ToolCall {
id: id.into(),
name: name.into(),
arguments: args.into(),
}],
},
}
}
// Tool-role message answering the call identified by `id`.
fn tool_result(id: &str, body: &str) -> Message {
Message {
role: Role::Tool,
content: MessageContent::ToolResult {
tool_call_id: id.into(),
content: body.into(),
},
}
}
#[test]
fn under_budget_is_a_no_op_clone() {
let msgs = vec![sys("you are an agent"), user("hi"), assistant_text("hello")];
let (out, stats) = compact_to_budget(&msgs, 10_000);
assert_eq!(stats.elided_messages, 0);
assert_eq!(stats.original_tokens, stats.final_tokens);
assert_eq!(out.len(), msgs.len());
// Strings unchanged.
match &out[2].content {
MessageContent::Text { text } => assert_eq!(text, "hello"),
other => panic!("expected Text, got {other:?}"),
}
}
#[test]
fn elides_old_tool_result_before_old_assistant_prose() {
// History: sys, user, assistant_calls, big_tool_result,
// assistant_with_big_text, user, assistant_calls,
// small_tool_result.
// KEEP_TAIL=4 protects the last four; the big tool result
// sits in the prunable range and should go first because
// pass 0 (tool results) runs before pass 1 (prose).
// Note: with 8 messages the tail starts at index 4, so the big
// prose turn is itself inside the protected tail here.
let big_result = "X".repeat(4096);
let big_prose = "Y".repeat(2048);
let msgs = vec![
sys("preamble"),
user("first ask"),
assistant_calls(None, "read_file", r#"{"path":"/a"}"#, "c0"),
tool_result("c0", &big_result),
assistant_text(&big_prose),
user("follow up"),
assistant_calls(None, "read_file", r#"{"path":"/b"}"#, "c1"),
tool_result("c1", "short result body"),
];
let before = total_tokens(&msgs);
// Force compaction by setting budget well below current.
let budget = before / 2;
let (out, stats) = compact_to_budget(&msgs, budget);
assert!(
stats.elided_messages >= 1,
"expected at least one elision, got {stats:?}"
);
// The big tool result must be elided (oldest fat target).
match &out[3].content {
MessageContent::ToolResult { content, .. } => {
assert!(
content.starts_with("(elided:"),
"tool result not elided: {content:?}"
);
}
other => panic!("expected ToolResult, got {other:?}"),
}
// Spot-check the protected tail: the final message must be
// untouched (only the last element is asserted here).
assert!(matches!(
&out[out.len() - 1].content,
MessageContent::ToolResult { content, .. } if content == "short result body"
));
}
#[test]
fn never_elides_system_or_user_turns() {
let big_user = "U".repeat(8192);
let msgs = vec![sys("preamble"), user(&big_user), assistant_text("ok")];
let budget = 10; // way below — forces all possible elision
let (out, _stats) = compact_to_budget(&msgs, budget);
// System unchanged.
match &out[0].content {
MessageContent::Text { text } => assert_eq!(text, "preamble"),
other => panic!("expected Text, got {other:?}"),
}
// User unchanged even though it's huge.
match &out[1].content {
MessageContent::Text { text } => assert_eq!(text.len(), big_user.len()),
other => panic!("expected Text, got {other:?}"),
}
}
#[test]
fn preserves_tool_call_id_pairing_after_elision() {
// OpenAI strict mode rejects a tool-result whose tool_call_id
// doesn't match a preceding assistant tool_call. Elision
// must not break that linkage.
let big = "Z".repeat(4096);
let msgs = vec![
sys("preamble"),
user("first"),
assistant_calls(None, "read_file", r#"{"path":"/a"}"#, "call_42"),
tool_result("call_42", &big),
// Tail messages.
user("next"),
assistant_calls(None, "read_file", r#"{"path":"/b"}"#, "call_43"),
tool_result("call_43", "ok"),
assistant_text("done"),
];
let budget = total_tokens(&msgs) / 3;
let (out, _stats) = compact_to_budget(&msgs, budget);
// The assistant call and its result both carry call_42.
let call_id = match &out[2].content {
MessageContent::ToolCalls { calls, .. } => calls[0].id.clone(),
other => panic!("expected ToolCalls, got {other:?}"),
};
match &out[3].content {
MessageContent::ToolResult { tool_call_id, .. } => {
assert_eq!(tool_call_id, &call_id, "pairing broken");
}
other => panic!("expected ToolResult, got {other:?}"),
}
}
#[test]
fn estimate_tokens_grows_with_content() {
// Coarse monotonicity check: a 10 000-byte payload must estimate
// far higher than a 2-byte one.
let small = sys("hi");
let large = sys(&"x".repeat(10_000));
assert!(estimate_tokens(&large) > estimate_tokens(&small) * 100);
}
#[test]
fn elide_in_place_skips_short_content() {
// "tiny" is below ELIDE_MIN_CHARS, so the call must report no
// change and leave the payload as-is.
let mut m = tool_result("c0", "tiny");
assert!(!elide_in_place(&mut m));
match m.content {
MessageContent::ToolResult { content, .. } => assert_eq!(content, "tiny"),
other => panic!("expected ToolResult, got {other:?}"),
}
}
#[test]
fn returns_best_effort_when_budget_unmeetable() {
// Single huge user message that cannot be elided. Budget 10.
// We don't error — we return what we have and let upstream
// refuse the prompt with its own error.
let big_user = "U".repeat(100_000);
let msgs = vec![sys("preamble"), user(&big_user)];
let (out, stats) = compact_to_budget(&msgs, 10);
assert_eq!(out.len(), msgs.len());
assert!(stats.final_tokens > 10, "still over budget by design");
}
}

View File

@@ -1,424 +0,0 @@
//! Configuration for the helexa-acp bridge.
//!
//! Loaded from `$XDG_CONFIG_HOME/helexa-acp/config.toml` (or
//! `~/.config/helexa-acp/config.toml` as a fallback). If no config file
//! exists, falls back to building a single anonymous endpoint from env
//! vars — that keeps "just point at one cortex" frictionless without
//! requiring a config file on disk.
//!
//! The design goal is "the missing ACP binary for users with multiple
//! API endpoints (possibly on a private LAN, possibly mixing wire
//! types)". Hence: every endpoint is named, has its own wire API, and
//! has its own default model. The agent's selected model id can be
//! prefixed `endpoint:model` to route across endpoints; a bare
//! `model` falls through to the configured `default_endpoint`.
//!
//! ### Example TOML
//!
//! ```toml
//! default_endpoint = "helexa"
//!
//! [[endpoints]]
//! name = "helexa"
//! base_url = "http://hanzalova.internal:31313/v1"
//! wire_api = "openai-chat"
//! default_model = "helexa/large"
//!
//! [[endpoints]]
//! name = "openrouter"
//! base_url = "https://openrouter.ai/api/v1"
//! wire_api = "openai-chat"
//! api_key_env = "OPENROUTER_API_KEY"
//! default_model = "anthropic/claude-opus-4"
//!
//! [[endpoints]]
//! name = "lmstudio"
//! base_url = "http://localhost:1234/v1"
//! wire_api = "openai-chat"
//! default_model = "auto"
//! ```
use anyhow::{Context, anyhow};
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use url::Url;
/// Base URL used by `Config::from_env` when `HELEXA_ACP_BASE_URL` is
/// not set.
const DEFAULT_BASE_URL: &str = "http://hanzalova.internal:31313/v1";
/// Model id used by `Config::from_env` when `HELEXA_ACP_MODEL` is not
/// set.
const DEFAULT_MODEL: &str = "helexa/large";
/// Name given to the single endpoint that `Config::from_env`
/// synthesises; also recorded as that config's `default_endpoint`.
const DEFAULT_ENDPOINT_NAME: &str = "default";
/// Top-level bridge configuration: the set of named endpoints, which
/// of them is the default, and an optional system-prompt override.
///
/// Every field carries `#[serde(default)]`, so an empty TOML document
/// deserialises successfully; `Config::validate` is what rejects a
/// config with no endpoints and fills in `default_endpoint`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
/// Name of the endpoint used when a request doesn't pick one
/// explicitly. Must reference an entry in `endpoints`. Defaults to
/// the first endpoint declared if unset.
#[serde(default)]
pub default_endpoint: Option<String>,
/// Per-endpoint configuration. At least one entry is required.
#[serde(default)]
pub endpoints: Vec<EndpointConfig>,
/// Optional path to a system-prompt file. When unset, the built-in
/// default prompt from `prompt.rs` is used.
#[serde(default)]
pub system_prompt_path: Option<PathBuf>,
}
/// Configuration for one named LLM endpoint (one `[[endpoints]]` table
/// in the TOML file). Only `name` and `base_url` are mandatory; every
/// other field is `#[serde(default)]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EndpointConfig {
/// Short identifier used in `endpoint:model` routing and in logs.
pub name: String,
/// Base URL of the OpenAI-compatible API. Must include the `/v1`
/// (or equivalent) suffix — paths like `chat/completions` and
/// `models` are joined onto this.
pub base_url: Url,
/// Wire protocol the endpoint speaks. Defaults to
/// [`WireApi::OpenAiChat`] when omitted.
/// NOTE(review): this used to say only `openai-chat` was supported,
/// but `build_provider` in `main.rs` constructs a provider for every
/// `WireApi` variant — confirm the other two are production-ready.
#[serde(default)]
pub wire_api: WireApi,
/// Model to use when the client hasn't picked one via
/// `session/set_model`.
#[serde(default)]
pub default_model: Option<String>,
/// Static API key to send as `Authorization: Bearer …`. Prefer
/// `api_key_env` for anything sensitive — keys in plain TOML are a
/// liability. Takes precedence over `api_key_env` when both are
/// set (see `resolve_api_key`).
#[serde(default)]
pub api_key: Option<String>,
/// Env var name to read for the API key. Resolved at startup so a
/// missing env var yields a clear error rather than silent
/// unauthenticated calls.
#[serde(default)]
pub api_key_env: Option<String>,
/// Cap on the model's output tokens per turn. `None` lets the
/// upstream pick its own default (cortex/neuron's default is
/// often small enough to trip Zed's "Output Limit Reached" on
/// long responses). Set to e.g. `32768` to let the model
/// produce longer turns. Goes into the OpenAI `max_tokens`
/// request field.
#[serde(default)]
pub max_tokens: Option<u64>,
/// Model context window in tokens (prompt + response). When set,
/// the agent compacts conversation history before each completion
/// so the prompt fits within `context_window - max_tokens - safety`
/// tokens — long sessions on small-context local models (Qwen3 at
/// 32 K) survive past the first few tool-call rounds rather than
/// dying with `prompt_too_long`. `None` disables compaction.
#[serde(default)]
pub context_window: Option<usize>,
}
/// Wire protocol spoken by an endpoint. Serialised as the kebab-case
/// strings given in the `#[serde(rename = …)]` attributes;
/// [`WireApi::OpenAiChat`] is the `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum WireApi {
/// `POST {base}/chat/completions` returning OpenAI-format SSE.
/// Compatible with cortex, LM Studio, Ollama (compat mode),
/// OpenRouter, OpenAI itself.
#[default]
#[serde(rename = "openai-chat")]
OpenAiChat,
/// `POST {base}/responses` — OpenAI's newer Responses API.
/// NOTE(review): previously documented as "not implemented yet",
/// but `build_provider` in `main.rs` maps this variant to
/// `OpenAIResponsesProvider` — confirm its support level.
#[serde(rename = "openai-responses")]
OpenAiResponses,
/// `POST {base}/messages` — Anthropic format.
/// NOTE(review): previously documented as "reserved", but
/// `build_provider` in `main.rs` maps this variant to
/// `AnthropicMessagesProvider` — confirm its support level.
#[serde(rename = "anthropic-messages")]
AnthropicMessages,
}
impl EndpointConfig {
    /// Work out which API key, if any, this endpoint should send.
    ///
    /// A literal `api_key` wins over `api_key_env`. With neither set
    /// the result is `Ok(None)`. When `api_key_env` is consulted and
    /// the variable cannot be read, the error names both the endpoint
    /// and the variable.
    pub fn resolve_api_key(&self) -> anyhow::Result<Option<String>> {
        match (self.api_key.as_ref(), self.api_key_env.as_ref()) {
            (Some(literal), _) => Ok(Some(literal.clone())),
            (None, Some(var)) => {
                let key = std::env::var(var).with_context(|| {
                    format!(
                        "endpoint '{}' references missing env var {}",
                        self.name, var
                    )
                })?;
                Ok(Some(key))
            }
            (None, None) => Ok(None),
        }
    }

    /// URL of the chat-completions route: `{base_url}/chat/completions`.
    pub fn chat_completions_url(&self) -> Url {
        let route = ["chat", "completions"];
        join_segments(&self.base_url, &route)
    }

    /// URL of the OpenAI Responses API route: `{base_url}/responses`.
    pub fn responses_url(&self) -> Url {
        let route = ["responses"];
        join_segments(&self.base_url, &route)
    }

    /// URL of the model-listing route: `{base_url}/models`. Called from
    /// `Provider::list_models`, which Stage 4 wires into the
    /// model-picker dropdown; until then nothing in-tree calls it,
    /// hence the `dead_code` allowance.
    #[allow(dead_code)]
    pub fn models_url(&self) -> Url {
        let route = ["models"];
        join_segments(&self.base_url, &route)
    }
}
impl Config {
    /// Load from TOML at the standard config path, or build from env
    /// vars if no file exists there. The env fallback yields a single
    /// endpoint named `"default"`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Config::from_file`] or
    /// [`Config::from_env`].
    pub fn load() -> anyhow::Result<Self> {
        match config_path() {
            Some(path) if path.exists() => Self::from_file(&path),
            _ => Self::from_env(),
        }
    }

    /// Read environment variable `name`, treating an unset, non-UTF-8
    /// or empty value as absent.
    fn non_empty_env(name: &str) -> Option<String> {
        std::env::var(name).ok().filter(|value| !value.is_empty())
    }

    /// Read and parse a numeric environment variable. Absent or empty
    /// yields `Ok(None)`; a value that does not parse yields an error
    /// naming the variable and the offending text.
    fn env_number<T>(name: &str) -> anyhow::Result<Option<T>>
    where
        T: std::str::FromStr,
        T::Err: std::error::Error + Send + Sync + 'static,
    {
        Self::non_empty_env(name)
            .map(|raw| {
                raw.parse::<T>()
                    .with_context(|| format!("{name} is not a positive integer ({raw})"))
            })
            .transpose()
    }

    /// Single-endpoint config constructed from `HELEXA_ACP_BASE_URL`,
    /// `HELEXA_ACP_MODEL`, `HELEXA_ACP_API_KEY`,
    /// `HELEXA_ACP_SYSTEM_PROMPT_PATH`, `HELEXA_ACP_MAX_TOKENS` and
    /// `HELEXA_ACP_CONTEXT_WINDOW`.
    ///
    /// Every variable is optional, and an empty value counts as unset.
    /// The base URL and model fall back to the built-in defaults; the
    /// others fall back to `None`.
    ///
    /// # Errors
    ///
    /// Fails when the base URL does not parse as a URL, or when either
    /// numeric variable is set but is not a valid integer.
    pub fn from_env() -> anyhow::Result<Self> {
        let base_url = Self::non_empty_env("HELEXA_ACP_BASE_URL")
            .unwrap_or_else(|| DEFAULT_BASE_URL.into());
        let base_url = Url::parse(&base_url)
            .with_context(|| format!("HELEXA_ACP_BASE_URL is not a valid URL ({base_url})"))?;
        // An empty model id would be forwarded upstream verbatim, so
        // it is treated as "not set".
        let default_model =
            Self::non_empty_env("HELEXA_ACP_MODEL").unwrap_or_else(|| DEFAULT_MODEL.into());
        let api_key = Self::non_empty_env("HELEXA_ACP_API_KEY");
        let system_prompt_path =
            Self::non_empty_env("HELEXA_ACP_SYSTEM_PROMPT_PATH").map(PathBuf::from);
        let max_tokens = Self::env_number::<u64>("HELEXA_ACP_MAX_TOKENS")?;
        let context_window = Self::env_number::<usize>("HELEXA_ACP_CONTEXT_WINDOW")?;
        Ok(Self {
            default_endpoint: Some(DEFAULT_ENDPOINT_NAME.into()),
            endpoints: vec![EndpointConfig {
                name: DEFAULT_ENDPOINT_NAME.into(),
                base_url,
                wire_api: WireApi::OpenAiChat,
                default_model: Some(default_model),
                api_key,
                api_key_env: None,
                max_tokens,
                context_window,
            }],
            system_prompt_path,
        })
    }

    /// Read, parse and validate the TOML config at `path`.
    ///
    /// # Errors
    ///
    /// Fails when the file cannot be read, is not valid TOML for this
    /// schema, or does not pass validation.
    pub fn from_file(path: &Path) -> anyhow::Result<Self> {
        let text = std::fs::read_to_string(path)
            .with_context(|| format!("read config {}", path.display()))?;
        let mut cfg: Self =
            toml::from_str(&text).with_context(|| format!("parse config {}", path.display()))?;
        cfg.validate()?;
        Ok(cfg)
    }

    /// Check the invariants the rest of the crate relies on, and fill
    /// in `default_endpoint` with the first endpoint's name when the
    /// file did not name one.
    ///
    /// Rejects: an empty endpoint list; an endpoint whose name is
    /// empty, contains `:` or repeats an earlier endpoint's name; and
    /// a `default_endpoint` that names no declared endpoint.
    fn validate(&mut self) -> anyhow::Result<()> {
        if self.endpoints.is_empty() {
            return Err(anyhow!("config has no [[endpoints]] entries"));
        }
        for (i, ep) in self.endpoints.iter().enumerate() {
            if ep.name.is_empty() {
                return Err(anyhow!("endpoints[{i}] has empty name"));
            }
            if ep.name.contains(':') {
                return Err(anyhow!(
                    "endpoints[{i}].name '{}' contains ':' which would clash \
                     with the endpoint:model selector syntax",
                    ep.name
                ));
            }
            // `endpoint()` returns the first match by name, so a
            // repeated name would make the later entry unreachable.
            // Endpoint lists are tiny; a linear scan is fine.
            if self.endpoints[..i]
                .iter()
                .any(|earlier| earlier.name == ep.name)
            {
                return Err(anyhow!(
                    "endpoints[{i}].name '{}' is declared more than once",
                    ep.name
                ));
            }
        }
        // Pick a default endpoint if none was named. The list is
        // non-empty (checked above), so index 0 exists.
        let first_name = &self.endpoints[0].name;
        let default_name = self
            .default_endpoint
            .get_or_insert_with(|| first_name.clone());
        if !self.endpoints.iter().any(|e| e.name == *default_name) {
            return Err(anyhow!(
                "default_endpoint '{default_name}' is not declared in [[endpoints]]"
            ));
        }
        Ok(())
    }

    /// Look up an endpoint by name. Returns `None` if not configured.
    pub fn endpoint(&self, name: &str) -> Option<&EndpointConfig> {
        self.endpoints.iter().find(|e| e.name == name)
    }

    /// The default endpoint.
    ///
    /// # Panics
    ///
    /// Panics when `default_endpoint` is unset or names no declared
    /// endpoint. Neither happens for a config built by
    /// [`Config::from_env`] or [`Config::from_file`].
    pub fn default_endpoint(&self) -> &EndpointConfig {
        let name = self
            .default_endpoint
            .as_deref()
            .expect("default_endpoint set by validate");
        self.endpoint(name)
            .expect("default_endpoint resolves after validate")
    }
}
/// Split an ACP-side `model` string into an optional endpoint name and
/// the endpoint-local model id.
///
/// - `helexa:helexa/large` → (`Some("helexa")`, `"helexa/large"`)
/// - `helexa/large` → (`None`, `"helexa/large"`)
///
/// Only the FIRST colon is treated as the separator, so everything
/// after it (further colons included) belongs to the model id. When
/// either side of that colon is empty, the input is not a selector and
/// is returned whole as the model id. A model id that itself contains
/// `:` therefore has to be written with an explicit endpoint prefix.
pub fn parse_model_selector(input: &str) -> (Option<&str>, &str) {
    if let Some((prefix, rest)) = input.split_once(':') {
        let well_formed = !prefix.is_empty() && !rest.is_empty();
        if well_formed {
            return (Some(prefix), rest);
        }
    }
    (None, input)
}
/// Resolve where the config file is expected to live.
///
/// Precedence:
///
/// 1. `HELEXA_ACP_CONFIG_PATH`, used verbatim.
/// 2. `$XDG_CONFIG_HOME/helexa-acp/config.toml`.
/// 3. `$HOME/.config/helexa-acp/config.toml`.
///
/// A variable that is unset, not valid UTF-8, or empty is skipped, the
/// same way the other `HELEXA_ACP_*` variables are handled. This means
/// an empty override no longer hides the XDG config file, and an empty
/// `HOME` no longer produces a relative path. Returns `None` when none
/// of the three yields a location. The path is not checked for
/// existence.
fn config_path() -> Option<PathBuf> {
    let non_empty = |name: &str| std::env::var(name).ok().filter(|value| !value.is_empty());

    if let Some(override_path) = non_empty("HELEXA_ACP_CONFIG_PATH") {
        return Some(PathBuf::from(override_path));
    }
    let base = non_empty("XDG_CONFIG_HOME")
        .map(PathBuf::from)
        .or_else(|| non_empty("HOME").map(|home| PathBuf::from(home).join(".config")))?;
    Some(base.join("helexa-acp").join("config.toml"))
}
/// Return a copy of `base` with `segments` appended to its path.
///
/// A trailing empty path segment on `base` is dropped first. If the
/// URL does not expose mutable path segments, the copy is returned
/// unmodified.
fn join_segments(base: &Url, segments: &[&str]) -> Url {
    let mut joined = base.clone();
    match joined.path_segments_mut() {
        Ok(mut path) => {
            path.pop_if_empty();
            path.extend(segments.iter().copied());
        }
        Err(_) => {}
    }
    joined
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Joining route segments gives the same URL whether or not the
    /// configured base URL ends in a slash.
    #[test]
    fn url_join_handles_trailing_slash() {
        let ep = EndpointConfig {
            name: "x".into(),
            base_url: Url::parse("http://h.internal:31313/v1").unwrap(),
            wire_api: WireApi::OpenAiChat,
            default_model: None,
            api_key: None,
            api_key_env: None,
            max_tokens: None,
            context_window: None,
        };
        assert_eq!(
            ep.chat_completions_url().as_str(),
            "http://h.internal:31313/v1/chat/completions"
        );
        assert_eq!(
            ep.models_url().as_str(),
            "http://h.internal:31313/v1/models"
        );

        // The case the test is named for: a trailing slash on the base
        // must not produce an empty segment (`/v1//models`).
        let with_slash = EndpointConfig {
            base_url: Url::parse("http://h.internal:31313/v1/").unwrap(),
            ..ep.clone()
        };
        assert_eq!(
            with_slash.chat_completions_url().as_str(),
            "http://h.internal:31313/v1/chat/completions"
        );
        assert_eq!(
            with_slash.models_url().as_str(),
            "http://h.internal:31313/v1/models"
        );
    }

    #[test]
    fn parses_model_selector() {
        assert_eq!(
            parse_model_selector("helexa:helexa/large"),
            (Some("helexa"), "helexa/large")
        );
        assert_eq!(parse_model_selector("helexa/large"), (None, "helexa/large"));
        assert_eq!(parse_model_selector("gpt-5"), (None, "gpt-5"));
        // Edge case: a leading colon → no endpoint.
        assert_eq!(parse_model_selector(":gpt-5"), (None, ":gpt-5"));
    }

    #[test]
    fn env_fallback_builds_single_endpoint() {
        // Clear every variable `from_env` reads, so a value inherited
        // from the developer's shell cannot change the outcome or make
        // a numeric parse fail. Nothing is *set* here.
        //
        // SAFETY: no other test in this module writes these
        // `HELEXA_ACP_*` variables.
        // NOTE(review): not serialised against tests in other modules —
        // confirm none of them touch these names.
        unsafe {
            std::env::remove_var("HELEXA_ACP_BASE_URL");
            std::env::remove_var("HELEXA_ACP_MODEL");
            std::env::remove_var("HELEXA_ACP_API_KEY");
            std::env::remove_var("HELEXA_ACP_SYSTEM_PROMPT_PATH");
            std::env::remove_var("HELEXA_ACP_MAX_TOKENS");
            std::env::remove_var("HELEXA_ACP_CONTEXT_WINDOW");
        }
        let cfg = Config::from_env().unwrap();
        assert_eq!(cfg.endpoints.len(), 1);
        assert_eq!(cfg.endpoints[0].name, "default");
        assert_eq!(cfg.endpoints[0].base_url.as_str(), DEFAULT_BASE_URL);
        assert_eq!(
            cfg.endpoints[0].default_model.as_deref(),
            Some(DEFAULT_MODEL)
        );
    }

    #[test]
    fn toml_parses_multi_endpoint() {
        let toml_text = r#"
default_endpoint = "helexa"
[[endpoints]]
name = "helexa"
base_url = "http://hanzalova.internal:31313/v1"
default_model = "helexa/large"
[[endpoints]]
name = "openrouter"
base_url = "https://openrouter.ai/api/v1"
wire_api = "openai-chat"
api_key_env = "OPENROUTER_API_KEY"
default_model = "anthropic/claude-opus-4"
"#;
        let mut cfg: Config = toml::from_str(toml_text).unwrap();
        cfg.validate().unwrap();
        assert_eq!(cfg.endpoints.len(), 2);
        assert_eq!(cfg.default_endpoint().name, "helexa");
        // `wire_api` was omitted on the first endpoint and must
        // default to the chat-completions protocol.
        assert_eq!(cfg.endpoints[0].wire_api, WireApi::OpenAiChat);
        assert_eq!(
            cfg.endpoints[1].api_key_env.as_deref(),
            Some("OPENROUTER_API_KEY")
        );
    }

    #[test]
    fn validate_rejects_colon_in_endpoint_name() {
        let toml_text = r#"
[[endpoints]]
name = "bad:name"
base_url = "http://x/v1"
"#;
        let mut cfg: Config = toml::from_str(toml_text).unwrap();
        let err = cfg.validate().unwrap_err();
        assert!(format!("{err}").contains("clash"));
    }
}

View File

@@ -1,145 +0,0 @@
//! helexa-acp — Agent Client Protocol bridge for multi-endpoint LLM
//! setups (helexa, LM Studio, Ollama, OpenRouter, OpenAI, Anthropic,
//! …) with a clean per-endpoint wire-format selector.
//!
//! Speaks ACP over stdio to an editor client (Zed today). Every
//! configured endpoint produces a wire-format-specific
//! [`provider::Provider`] implementation; the agent loop in
//! [`agent::Agent`] is provider-agnostic, so adding e.g. an Anthropic
//! /v1/messages provider doesn't touch `agent.rs`.
//!
//! Config: `$XDG_CONFIG_HOME/helexa-acp/config.toml` for the multi-
//! endpoint case; env vars (`HELEXA_ACP_BASE_URL`, etc.) for the
//! single-endpoint case when no config file exists.
use agent_client_protocol::{Result, Stdio};
use std::sync::Arc;
mod agent;
mod compaction;
mod config;
mod path_util;
mod prompt;
mod provider;
mod qwen3;
mod session;
mod store;
mod tool_runner;
mod tools;
use agent::Agent;
use config::{Config, EndpointConfig, WireApi};
use provider::{
Provider, anthropic_messages::AnthropicMessagesProvider, openai_chat::OpenAIChatProvider,
openai_responses::OpenAIResponsesProvider,
};
/// Install the global tracing subscriber.
///
/// Output goes to stderr unless `HELEXA_ACP_LOG_FILE` names a file, in
/// which case log lines are appended to it (created if missing) with
/// ANSI colours turned off. Stdout is never used — it carries the
/// JSON-RPC stream. A file is the practical way to capture debug
/// output when an editor (Zed, etc.) launches the agent and hides its
/// stderr.
///
/// Levels come from `RUST_LOG` (e.g. `helexa_acp=debug`), defaulting
/// to `info` when it is unset or invalid. If the log file cannot be
/// opened, logging falls back to stderr and a warning is emitted there
/// rather than the agent going silent.
fn init_tracing() {
    /// Install a subscriber that writes to stderr.
    fn init_stderr(filter: tracing_subscriber::EnvFilter) {
        tracing_subscriber::fmt()
            .with_writer(std::io::stderr)
            .with_env_filter(filter)
            .init();
    }

    let filter = tracing_subscriber::EnvFilter::try_from_default_env()
        .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info"));
    let requested = std::env::var("HELEXA_ACP_LOG_FILE")
        .ok()
        .filter(|s| !s.is_empty());

    let Some(path) = requested else {
        init_stderr(filter);
        return;
    };

    let opened = std::fs::OpenOptions::new()
        .create(true)
        .append(true)
        .open(&path);
    match opened {
        Ok(file) => {
            tracing_subscriber::fmt()
                .with_writer(std::sync::Mutex::new(file))
                .with_env_filter(filter)
                .with_ansi(false)
                .init();
        }
        Err(e) => {
            // A typo'd log path must not silence the agent: log to
            // stderr instead and say why.
            init_stderr(filter);
            tracing::warn!(
                path = %path,
                error = %e,
                "HELEXA_ACP_LOG_FILE could not be opened; using stderr"
            );
        }
    }
}
/// Construct the provider that matches `endpoint`'s declared
/// `wire_api`.
///
/// This is the single place that maps a [`WireApi`] variant to a
/// concrete [`Provider`] implementation, so a new wire type (Ollama
/// native, …) only needs a new arm here — callers are unaffected.
/// Any error from the chosen provider's constructor is returned
/// unchanged.
fn build_provider(endpoint: EndpointConfig) -> anyhow::Result<Arc<dyn Provider>> {
    let provider: Arc<dyn Provider> = match endpoint.wire_api {
        WireApi::OpenAiChat => Arc::new(OpenAIChatProvider::new(endpoint)?),
        WireApi::OpenAiResponses => Arc::new(OpenAIResponsesProvider::new(endpoint)?),
        WireApi::AnthropicMessages => Arc::new(AnthropicMessagesProvider::new(endpoint)?),
    };
    Ok(provider)
}
/// Entry point: initialise logging, load the configuration, build one
/// provider per endpoint, then serve ACP over stdio.
///
/// A configuration or agent-construction failure is reported as an ACP
/// internal error. An endpoint whose provider cannot be built is
/// logged and skipped, not treated as fatal.
#[tokio::main]
async fn main() -> Result<()> {
    init_tracing();

    let cfg = Config::load()
        .map_err(|e| agent_client_protocol::util::internal_error(format!("config: {e:#}")))?;
    let default = cfg.default_endpoint();
    tracing::info!(
        endpoints = cfg.endpoints.len(),
        default_endpoint = %default.name,
        default_model = ?default.default_model,
        "helexa-acp starting"
    );

    // Providers are built eagerly. That is cheap (a reqwest::Client
    // plus API-key resolution) and means config mistakes — a missing
    // API-key env var, an unsupported wire_api — show up in the log
    // before the editor has even sent `initialize`.
    let mut providers: Vec<Arc<dyn Provider>> = Vec::with_capacity(cfg.endpoints.len());
    for endpoint in &cfg.endpoints {
        let provider = match build_provider(endpoint.clone()) {
            Ok(provider) => provider,
            Err(e) => {
                tracing::warn!(
                    endpoint = %endpoint.name,
                    error = %format!("{e:#}"),
                    "skipping endpoint with invalid config"
                );
                continue;
            }
        };
        tracing::info!(
            endpoint = %endpoint.name,
            base_url = %endpoint.base_url,
            wire_api = ?endpoint.wire_api,
            "registered provider"
        );
        providers.push(provider);
    }

    let agent = Agent::new(&cfg, providers)
        .await
        .map_err(|e| agent_client_protocol::util::internal_error(format!("agent: {e:#}")))?;
    agent.serve(Stdio::new()).await
}

View File

@@ -1,192 +0,0 @@
//! Path expansion shared across every tool that takes a path.
//!
//! Models often emit shell-style paths like `~/git/repo/file.rs` or
//! `$HOME/notes.md`. ACP's `fs/read_text_file` and friends — and our
//! own local `std::fs` reads — both want a real absolute path; the
//! `~` / `$HOME` forms reach them as literal strings and the open
//! fails. The tool schemas already document "absolute path" but in
//! practice the model slips up often enough that handling it
//! server-side is the difference between "works" and "the agent is
//! brittle".
//!
//! Scope is deliberately small:
//!
//! - `~` and `~/` (current user only — `~user` lookups would require
//! pulling in passwd parsing).
//! - `$HOME` and `$HOME/`.
//!
//! Any other shell variable (`$PWD`, `${HOME}`, …) passes through
//! unchanged. The shell already expands them inside `bash` tool
//! commands; for the file-tool argument fields, we deliberately
//! limit the set so the behaviour is predictable.
//!
//! Falls back to the input path verbatim when `HOME` is unset
//! (stripped-down container env). That preserves the "no surprise
//! mutations" rule — never invent a path the caller didn't ask for.
use std::path::{Path, PathBuf};
/// Process-global lock for tests that mutate `HOME`. Anyone in the
/// crate touching `HOME` must hold this for the duration of the
/// read-modify-restore window — otherwise concurrent `cargo test`
/// workers race and flake.
///
/// Declared `pub(crate)` so test modules elsewhere in the crate can
/// take the same lock.
///
/// Only built into the test binaries (`#[cfg(test)]`). Production code
/// never mutates env vars.
#[cfg(test)]
pub(crate) static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
/// Expand `~`, `~/`, `$HOME`, and `$HOME/` prefixes against the
/// current user's home directory. All other inputs pass through
/// unchanged.
///
/// The input is returned verbatim when:
///
/// - it is not valid UTF-8 (no prefix can be recognised);
/// - `HOME` is unset or empty — an empty home would turn `~/x` into
///   the relative path `x`, a path the caller never asked for;
/// - it has none of the four recognised prefixes (`~user`, `$PWD`,
///   `${HOME}`, … are deliberately out of scope).
///
/// `HOME` is read with `var_os`, so a home directory that is not valid
/// UTF-8 still expands. Extra slashes after the prefix are dropped
/// (`~//a` → `<home>/a`) so the remainder is always treated as
/// relative to home; `PathBuf::join` would otherwise replace the base
/// with the absolute-looking remainder.
pub fn expand_path(input: &Path) -> PathBuf {
    let Some(text) = input.to_str() else {
        return input.to_path_buf();
    };
    let home = match std::env::var_os("HOME") {
        Some(value) if !value.is_empty() => PathBuf::from(value),
        _ => return input.to_path_buf(),
    };
    if text == "~" || text == "$HOME" {
        return home;
    }
    let remainder = text
        .strip_prefix("~/")
        .or_else(|| text.strip_prefix("$HOME/"));
    match remainder {
        // Leading slashes are stripped so `join` cannot discard `home`.
        Some(rest) => home.join(rest.trim_start_matches('/')),
        None => input.to_path_buf(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// RAII override of `HOME` for one test.
    ///
    /// Holds the crate-wide [`ENV_LOCK`] for its whole lifetime and
    /// restores the previous `HOME` on drop. Drop also runs while a
    /// failed assertion unwinds, so one failing test cannot leave
    /// `HOME` modified for the tests that follow.
    struct HomeOverride {
        /// Value of `HOME` before the override (`None` = was unset).
        prior: Option<std::ffi::OsString>,
        /// Declared last on purpose: `Drop::drop` restores `HOME`
        /// first, and only then is this field dropped, releasing the
        /// lock.
        _lock: std::sync::MutexGuard<'static, ()>,
    }

    impl HomeOverride {
        /// Set `HOME` to `home`, or remove it when `home` is `None`.
        fn apply(home: Option<&str>) -> Self {
            // A test that panicked while holding the lock poisons it.
            // The lock guards no data, so recovering the guard is
            // sound and stops one failure cascading into all others.
            let lock = ENV_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            let prior = std::env::var_os("HOME");
            // SAFETY: env mutation is process-global. `ENV_LOCK`
            // serialises every test in this crate that touches `HOME`;
            // no other threads are expected to read or write it.
            unsafe {
                match home {
                    Some(value) => std::env::set_var("HOME", value),
                    None => std::env::remove_var("HOME"),
                }
            }
            Self { prior, _lock: lock }
        }
    }

    impl Drop for HomeOverride {
        fn drop(&mut self) {
            // SAFETY: `_lock` is still held here, so access to `HOME`
            // is still serialised.
            unsafe {
                match self.prior.take() {
                    Some(value) => std::env::set_var("HOME", value),
                    None => std::env::remove_var("HOME"),
                }
            }
        }
    }

    /// Run `body` with `HOME` set to `home`. Tests using this run
    /// serially under the crate-wide [`ENV_LOCK`] because env mutation
    /// isn't thread-safe — `cargo test` parallel workers would race
    /// without it.
    fn with_home<F: FnOnce()>(home: &str, body: F) {
        let _guard = HomeOverride::apply(Some(home));
        body();
    }

    #[test]
    fn expands_tilde_slash() {
        with_home("/home/me", || {
            assert_eq!(
                expand_path(Path::new("~/git/repo/file.rs")),
                PathBuf::from("/home/me/git/repo/file.rs")
            );
        });
    }

    #[test]
    fn expands_bare_tilde() {
        with_home("/home/me", || {
            assert_eq!(expand_path(Path::new("~")), PathBuf::from("/home/me"));
        });
    }

    #[test]
    fn expands_dollar_home_slash() {
        with_home("/home/me", || {
            assert_eq!(
                expand_path(Path::new("$HOME/notes.md")),
                PathBuf::from("/home/me/notes.md")
            );
        });
    }

    #[test]
    fn expands_bare_dollar_home() {
        with_home("/home/me", || {
            assert_eq!(expand_path(Path::new("$HOME")), PathBuf::from("/home/me"));
        });
    }

    #[test]
    fn absolute_path_passes_through() {
        with_home("/home/me", || {
            assert_eq!(
                expand_path(Path::new("/etc/hostname")),
                PathBuf::from("/etc/hostname")
            );
        });
    }

    #[test]
    fn relative_path_passes_through() {
        with_home("/home/me", || {
            assert_eq!(
                expand_path(Path::new("src/main.rs")),
                PathBuf::from("src/main.rs")
            );
        });
    }

    #[test]
    fn tilde_user_form_not_expanded() {
        // ~other is shell sugar for /home/other and would require
        // passwd parsing to resolve. Out of scope — pass it
        // through and let the open fail with a clear error.
        with_home("/home/me", || {
            assert_eq!(
                expand_path(Path::new("~other/x")),
                PathBuf::from("~other/x")
            );
        });
    }

    #[test]
    fn no_home_env_passes_through() {
        // Uses the same guard (and so the same crate-wide lock) as
        // `with_home` — otherwise a parallel test setting HOME would
        // race this clear-and-assert window.
        let _guard = HomeOverride::apply(None);
        assert_eq!(
            expand_path(Path::new("~/git/repo")),
            PathBuf::from("~/git/repo")
        );
    }

    #[test]
    fn dollar_other_var_not_expanded() {
        with_home("/home/me", || {
            assert_eq!(
                expand_path(Path::new("$PWD/file")),
                PathBuf::from("$PWD/file")
            );
            assert_eq!(
                expand_path(Path::new("${HOME}/file")),
                PathBuf::from("${HOME}/file")
            );
        });
    }
}

View File

@@ -1,274 +0,0 @@
//! System prompt assembly.
//!
//! The system message has two parts:
//!
//! 1. A short human-readable preamble (working directory, style
//! instructions). Either the built-in [`DEFAULT_PROMPT`] or a
//! user-supplied file at `HELEXA_ACP_SYSTEM_PROMPT_PATH` /
//! `system_prompt_path`. `{cwd}` is substituted in both.
//! 2. A `# Tools` block in Qwen3 Hermes format (see [`crate::qwen3`])
//! describing the available functions. This is what makes the
//! model actually call them — neuron/cortex don't honour the
//! OpenAI `tools` API field, so the tool list has to live in the
//! prompt itself.
use agent_client_protocol::schema::SessionModeId;
use anyhow::Context;
use std::path::Path;
use crate::provider::ToolSpec;
use crate::qwen3;
use crate::session::MODE_PLAN;
/// Built-in system-prompt preamble, used by `build_system_prompt` when
/// no override template path is supplied. The `{cwd}` placeholder is
/// replaced with the session working directory at build time.
const DEFAULT_PROMPT: &str = "\
You are helexa-acp, a coding assistant working inside an editor.
Working directory: {cwd}
Use the tools described below whenever the user's request involves
looking at or modifying files, or running commands. Do not ask the
user to paste file contents you could read yourself. All file paths
must be absolute. Writes and shell commands may prompt the user for
permission depending on the session mode.
Be concise; the user is reading your output in an editor pane.";
/// Assemble the full system prompt for a session.
///
/// The result is built in three steps, in this order:
///
/// 1. The preamble — the file at `override_path` when given, otherwise
///    the built-in [`DEFAULT_PROMPT`] — with every `{cwd}` replaced by
///    `cwd`.
/// 2. The tool block rendered from `tools` by
///    [`qwen3::render_tool_block`]. It follows the preamble so that a
///    custom template still gets the tool descriptions the model
///    needs.
/// 3. Only when `mode` is [`MODE_PLAN`]: the plan-mode addendum, placed
///    last so it is the final thing the model reads before user input.
///
/// `plan_dir` is consulted only in plan mode. `None` means the plan
/// directory could not be resolved; the addendum is still rendered,
/// with a placeholder telling the model to surface the problem to the
/// user rather than guess a path.
///
/// # Errors
///
/// Fails when `override_path` is given but the file cannot be read.
pub fn build_system_prompt(
    cwd: &Path,
    override_path: Option<&Path>,
    tools: &[ToolSpec],
    mode: &SessionModeId,
    plan_dir: Option<&Path>,
) -> anyhow::Result<String> {
    let preamble = if let Some(path) = override_path {
        std::fs::read_to_string(path)
            .with_context(|| format!("read system prompt from {}", path.display()))?
    } else {
        DEFAULT_PROMPT.to_string()
    };
    let cwd_text = cwd.display().to_string();
    let mut assembled = preamble.replace("{cwd}", &cwd_text);

    let tool_block = qwen3::render_tool_block(tools);
    assembled.push_str(&tool_block);

    let in_plan_mode = mode.0.as_ref() == MODE_PLAN;
    if in_plan_mode {
        let addendum = render_plan_mode_block(plan_dir);
        assembled.push_str(&addendum);
    }
    Ok(assembled)
}
/// Render the plan-mode addendum appended to the system prompt.
///
/// The text tells the model three things:
///
/// 1. Where writes are allowed — only under `plan_dir`.
/// 2. What is forbidden — `bash`, and any write outside `plan_dir`
///    (which the runtime refuses).
/// 3. How to end the turn — with a fixed three-option menu, so the
///    user can switch mode to start implementing (with or without
///    permission prompts) or stay in plan mode and keep refining.
///
/// When `plan_dir` is `None`, a placeholder is substituted for the
/// path, instructing the model to tell the user.
fn render_plan_mode_block(plan_dir: Option<&Path>) -> String {
    let plan_path = match plan_dir {
        Some(dir) => dir.display().to_string(),
        None => String::from("<plan directory could not be resolved — tell the user>"),
    };
    format!(
        "\n\n# Plan mode\n\
         \n\
         You are in **plan mode**. Your task is to draft a written\n\
         implementation plan for the user; you must NOT modify any\n\
         project files or run shell commands.\n\
         \n\
         Rules in plan mode:\n\
         \n\
         - `read_file` and `list_dir` are unrestricted — use them to\n\
         explore the codebase as needed.\n\
         - `write_file` and `edit_file` are allowed ONLY under the\n\
         plan directory: `{plan_path}`. The runtime will refuse any\n\
         write outside it.\n\
         - `bash` is disabled. Do not call it.\n\
         \n\
         Write the plan as one or more Markdown files under\n\
         `{plan_path}`. Use descriptive filenames\n\
         (`01-overview.md`, `02-data-model.md`, etc.). It is fine to\n\
         iterate — overwrite the file when you refine a section.\n\
         \n\
         When the plan is complete, do NOT begin implementation.\n\
         Instead, end your turn with this menu, verbatim, so the\n\
         user can choose how to proceed:\n\
         \n\
         ---\n\
         **Plan complete.** To proceed, switch the session mode in\n\
         the agent dropdown and send a follow-up message:\n\
         \n\
         1. **Bypass Permissions** — implement the plan now, skipping\n\
         per-tool permission prompts.\n\
         2. **Default** — implement the plan now, prompting before\n\
         each write or shell command.\n\
         3. **Plan** (stay here) — refine the plan; reply with the\n\
         change you want and I will revise it.\n\
         ---\n"
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::session::{MODE_DEFAULT, MODE_PLAN};

    fn default_mode() -> SessionModeId {
        SessionModeId::new(MODE_DEFAULT)
    }

    fn plan_mode() -> SessionModeId {
        SessionModeId::new(MODE_PLAN)
    }

    /// Tiny temp-file helper that doesn't pull in the `tempfile` crate.
    ///
    /// The file is created under the directory named by the
    /// `CARGO_TARGET_TMPDIR` environment variable when that is set at
    /// run time, otherwise under the system temp directory. It is
    /// removed when the handle is dropped — including during a panic
    /// unwind, so a failed assertion does not leave it behind.
    struct TempFile {
        path: std::path::PathBuf,
    }

    impl TempFile {
        /// Create `helexa-acp-<pid>-<name>` holding `contents`.
        fn with_contents(name: &str, contents: &str) -> Self {
            let base = std::env::var("CARGO_TARGET_TMPDIR")
                .ok()
                .map(std::path::PathBuf::from)
                .unwrap_or_else(std::env::temp_dir);
            let _ = std::fs::create_dir_all(&base);
            let pid = std::process::id();
            let path = base.join(format!("helexa-acp-{pid}-{name}"));
            std::fs::write(&path, contents).expect("create temp file");
            Self { path }
        }

        fn path(&self) -> &Path {
            &self.path
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best-effort cleanup; a leftover file is harmless.
            let _ = std::fs::remove_file(&self.path);
        }
    }

    #[test]
    fn default_prompt_substitutes_cwd() {
        let prompt =
            build_system_prompt(Path::new("/home/me/proj"), None, &[], &default_mode(), None)
                .unwrap();
        assert!(
            prompt.contains("/home/me/proj"),
            "cwd not interpolated: {prompt}"
        );
        assert!(prompt.contains("helexa-acp"));
        assert!(
            !prompt.contains("{cwd}"),
            "left-over placeholder in default prompt"
        );
        // With no tools, the # Tools block is absent.
        assert!(!prompt.contains("# Tools"));
        // Default mode does not get the plan-mode addendum.
        assert!(!prompt.contains("# Plan mode"));
    }

    #[test]
    fn tools_are_appended_in_hermes_format() {
        let spec = ToolSpec {
            name: "read_file".into(),
            description: "Read a file.".into(),
            parameters: serde_json::json!({"type":"object","properties":{}, "required":[]}),
        };
        let prompt =
            build_system_prompt(Path::new("/x"), None, &[spec], &default_mode(), None).unwrap();
        assert!(prompt.contains("# Tools"));
        assert!(prompt.contains("<tools>"));
        assert!(prompt.contains("\"name\":\"read_file\""));
        assert!(prompt.contains("<tool_call>"));
    }

    #[test]
    fn override_path_is_read_and_templated() {
        let template = TempFile::with_contents("prompt.txt", "custom prompt for {cwd} only");
        let prompt = build_system_prompt(
            Path::new("/etc"),
            Some(template.path()),
            &[],
            &default_mode(),
            None,
        )
        .expect("read override");
        assert_eq!(prompt, "custom prompt for /etc only");
    }

    #[test]
    fn missing_override_path_errors() {
        let err = build_system_prompt(
            Path::new("/tmp"),
            Some(Path::new("/definitely/not/a/real/path")),
            &[],
            &default_mode(),
            None,
        )
        .unwrap_err();
        assert!(format!("{err:#}").contains("read system prompt"));
    }

    #[test]
    fn plan_mode_addendum_includes_plan_dir_and_menu() {
        let plan_dir = Path::new("/home/me/.local/share/helexa-acp/plans/proj-deadbeef");
        let prompt = build_system_prompt(
            Path::new("/home/me/proj"),
            None,
            &[],
            &plan_mode(),
            Some(plan_dir),
        )
        .unwrap();
        assert!(prompt.contains("# Plan mode"));
        assert!(
            prompt.contains(plan_dir.to_str().unwrap()),
            "plan dir not interpolated: {prompt}"
        );
        // The 3-option menu must be present so the model emits it verbatim.
        assert!(prompt.contains("Bypass Permissions"));
        assert!(prompt.contains("**Default**"));
        assert!(prompt.contains("3. **Plan**"));
        // Bash disabled instruction must be present.
        assert!(prompt.contains("`bash` is disabled"));
    }

    #[test]
    fn plan_mode_addendum_handles_unresolved_plan_dir() {
        let prompt =
            build_system_prompt(Path::new("/home/me/proj"), None, &[], &plan_mode(), None).unwrap();
        assert!(prompt.contains("# Plan mode"));
        assert!(prompt.contains("could not be resolved"));
    }
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,230 +0,0 @@
//! Provider trait — the seam between the ACP-side agent loop and
//! whatever wire protocol an endpoint actually speaks.
//!
//! Every concrete provider (OpenAI chat completions, OpenAI Responses,
//! Anthropic /v1/messages, Ollama native, …) implements
//! [`Provider`]. The agent constructs a [`CompletionRequest`] using
//! provider-agnostic types and consumes a stream of
//! [`CompletionEvent`]s — neither end knows which wire format is on
//! the other side of the trait.
//!
//! Day-1 provider: [`openai_chat::OpenAIChatProvider`]. Day-N
//! providers slot in without touching `agent.rs`.
use async_trait::async_trait;
use futures::stream::BoxStream;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio_util::sync::CancellationToken;
pub mod anthropic_messages;
pub mod openai_chat;
pub mod openai_responses;
/// Provider-agnostic LLM endpoint. Implementations translate between
/// [`CompletionRequest`] / [`CompletionEvent`] and whatever wire
/// format their endpoint speaks.
///
/// The `Send + Sync` supertraits let implementations be shared as
/// `Arc<dyn Provider>`, which is how `main.rs` holds them.
#[async_trait]
pub trait Provider: Send + Sync {
/// Endpoint name as configured by the user (e.g. `"helexa"`,
/// `"openrouter"`). Used in logs and in the `endpoint:model`
/// selector.
fn name(&self) -> &str;
/// List models available at this endpoint. Used to build the
/// model-picker dropdown in editor clients (Stage 4). Should
/// return quickly (cache if necessary).
///
/// Marked `dead_code`-allowed because, per the note above, the
/// consumer only arrives with Stage 4.
#[allow(dead_code)]
async fn list_models(&self) -> anyhow::Result<Vec<ModelInfo>>;
/// Run a chat completion. Returns a stream of provider-agnostic
/// events. The stream stops when the upstream finishes, when
/// `cancel` is fired, or when the stream is dropped.
///
/// There are two failure channels: the outer `Result` (no stream
/// was produced) and each stream item's own `Result`.
/// NOTE(review): presumably the outer error covers request setup
/// and the inner ones mid-stream failures — confirm against the
/// provider implementations.
async fn complete(
&self,
request: CompletionRequest,
cancel: CancellationToken,
) -> anyhow::Result<BoxStream<'static, anyhow::Result<CompletionEvent>>>;
}
/// One model exposed by a provider. Constructed by `list_models` —
/// Stage 4 is when the agent loop starts consuming it for the
/// model-picker dropdown, hence the `dead_code` allowance until then.
#[allow(dead_code)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
/// Model identifier.
/// NOTE(review): presumably the endpoint-local id, without an
/// `endpoint:` prefix — confirm against `list_models` callers.
pub id: String,
/// Human-friendly name, if the endpoint exposes one. Otherwise
/// `id` is used as the display name. Optional when deserialising
/// (`#[serde(default)]`).
#[serde(default)]
pub display_name: Option<String>,
}
/// Inputs to a completion. Provider-agnostic — concrete providers
/// translate this into their wire format.
#[derive(Debug, Clone)]
pub struct CompletionRequest {
/// Endpoint-local model id (without the `endpoint:` prefix).
pub model: String,
/// Conversation history, in order. May include system, user,
/// assistant and tool-result turns.
pub messages: Vec<Message>,
/// Tools the model is allowed to call. Empty list means no tool
/// support advertised.
pub tools: Vec<ToolSpec>,
/// Sampling temperature. `None` means "not specified"; the OpenAI
/// Responses encoder omits the field from the request body then.
pub temperature: Option<f64>,
/// Nucleus-sampling cutoff. `None` means "not specified"; the
/// OpenAI Responses encoder omits the field from the request body
/// then.
pub top_p: Option<f64>,
/// Cap on the response length in tokens. The OpenAI Responses
/// encoder sends it as `max_output_tokens`; `None` omits it.
pub max_tokens: Option<u64>,
}
/// One turn of conversation history: who produced it (`role`) and
/// what it carries (`content`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
/// Author of the turn.
pub role: Role,
/// Payload of the turn; see [`MessageContent`] for the shapes.
pub content: MessageContent,
}
/// Author of a [`Message`]. Serialised in snake_case (`"system"`,
/// `"user"`, `"assistant"`, `"tool"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Role {
/// System prompt / instructions turn.
System,
/// Turn authored by the human user.
User,
/// Turn authored by the model.
Assistant,
/// Tool result message. Provider impls turn this into whatever
/// shape the upstream wire format wants (OpenAI uses
/// `role: "tool"` + `tool_call_id`; Anthropic uses content blocks).
/// Stage 3 (tools) constructs this; Stage 2 never does.
Tool,
}
/// Payload of a [`Message`]. Internally tagged for serde: the JSON
/// object carries a `type` field holding the snake_case variant name
/// (`"text"`, `"multi_part"`, `"tool_calls"`, `"tool_result"`)
/// alongside the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MessageContent {
/// Plain text turn (system / user / assistant). Struct variant
/// rather than newtype so the persisted JSON has an explicit
/// `text` field — that lets us use internal tagging on the
/// enum, which is incompatible with newtype-of-primitive
/// variants.
Text { text: String },
/// Mixed text + image user turn. Stage 5 introduces this when
/// Zed sends an `ImageContent` block alongside the user's prompt.
/// Providers that don't support vision should down-convert by
/// dropping image parts and concatenating text parts.
MultiPart { parts: Vec<MessagePart> },
/// Assistant turn that called one or more tools. Stage 3 starts
/// constructing this when the provider stream yields a
/// `ToolCallStart` / `ToolCallArgsDelta` sequence.
ToolCalls {
/// Optional text the assistant said alongside the tool calls.
text: Option<String>,
/// The calls made in this turn, in the order they were made.
calls: Vec<ToolCall>,
},
/// Tool result. `tool_call_id` matches the assistant's call id.
/// Stage 3 constructs this after the tool runner finishes.
ToolResult {
/// Id of the [`ToolCall`] this result answers.
tool_call_id: String,
/// Tool output as text.
content: String,
},
}
/// One part of a [`MessageContent::MultiPart`] message.
///
/// Internally tagged for serde: a `type` field of `"text"` or
/// `"image"` sits next to the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MessagePart {
/// A run of plain text.
Text { text: String },
/// An inline image; the fields of [`ImageData`] are serialised at
/// the same level as the `type` tag.
Image(ImageData),
}
/// Inline image attachment. `data` is base64-encoded raw image
/// bytes; the encoder constructs an `image_url` data URI from it
/// at request time. `uri` carries any pointer the client supplied
/// (e.g. `file:///tmp/x.png`) — we keep it on the message for
/// debugging / future providers but the OpenAI encoder ignores it
/// when `data` is present (data wins, since it round-trips through
/// every wire format).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageData {
/// MIME type of the image (e.g. `image/png`). The OpenAI Responses
/// encoder interpolates it into the `data:<mime>;base64,<data>` URI.
pub mime_type: String,
/// Base64-encoded image bytes (no `data:` prefix, no padding
/// stripped — exactly what `ImageContent.data` carried).
pub data: String,
/// Optional client-supplied pointer to the image. Defaults to
/// `None` when absent and is left out of serialised output when
/// `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub uri: Option<String>,
}
/// A single tool invocation requested by the assistant, as stored in
/// [`MessageContent::ToolCalls`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
/// Provider-assigned id that ties the call to its result. The
/// Qwen3 wire format we use today doesn't carry this on the
/// model side (calls and results are matched positionally inside
/// a turn), so the field looks unused in the prod build — but it
/// flows through to `MessageContent::ToolResult.tool_call_id` for
/// history bookkeeping and a future strict-OpenAI backend will
/// consume it directly.
#[allow(dead_code)]
pub id: String,
/// Name of the tool being called.
pub name: String,
/// JSON-encoded arguments. Kept as a string because providers
/// stream argument bytes incrementally and only validate at the
/// end; the agent decodes once the call is complete.
pub arguments: String,
}
/// Declaration of one tool the model is allowed to call; carried in
/// [`CompletionRequest::tools`].
#[derive(Debug, Clone)]
pub struct ToolSpec {
/// Tool name.
pub name: String,
/// Free-text description of the tool.
pub description: String,
/// JSON Schema of the arguments object.
pub parameters: Value,
}
/// Events emitted by a provider during a streaming completion.
#[derive(Debug, Clone)]
pub enum CompletionEvent {
/// Incremental visible text from the assistant.
TextDelta(String),
/// Incremental "reasoning" / thought text, if the model emits one
/// (e.g. Qwen3 with `<think>` tags surfaced as a separate stream,
/// or OpenAI reasoning models).
ReasoningDelta(String),
/// A new tool call has started. Stage 2 ignores the payload; the
/// agent loop in Stage 3 reads `index` to correlate with
/// [`Self::ToolCallArgsDelta`], `id` for the eventual tool-result
/// turn, and `name` to dispatch the runner.
#[allow(dead_code)]
ToolCallStart {
/// Per-stream slot of this call; later argument deltas carry
/// the same value.
index: usize,
/// Id that the eventual tool-result turn must reference.
id: String,
/// Name of the tool to run.
name: String,
},
/// More argument bytes for a tool call already announced via
/// [`Self::ToolCallStart`]. Stage 2 ignores; Stage 3 accumulates
/// the bytes by `index` until the call's arguments are complete.
#[allow(dead_code)]
ToolCallArgsDelta { index: usize, args_delta: String },
/// A `<tool_call>` block whose JSON couldn't be parsed even with
/// the qwen3 module's repair attempts. The agent surfaces this
/// as a Failed `SessionUpdate::ToolCall` card with the raw body
/// visible (so the editor renders structured failure UI rather
/// than dumping the body inline in the message pane), and feeds
/// a synthetic tool-error message back into history so the
/// model can self-correct on the next round.
MalformedToolCall { raw: String },
/// Stream finished. Carries the upstream `finish_reason` if it
/// gave one (`"stop"`, `"length"`, `"tool_calls"`, …).
Finish { reason: Option<String> },
/// Final usage stats, if the provider supplied them. Stage 2
/// matches the variant to drop it; Stage 6b (token metrics) is
/// when the payload starts being read.
#[allow(dead_code)]
Usage(UsageStats),
}
/// Token accounting reported by the provider at the end of a stream.
/// Stage 2 doesn't surface usage anywhere — the stable `PromptResponse`
/// has no usage field, and the unstable variant is gated. Stage 6b
/// turns these on with Prometheus metrics.
///
/// `Default` yields all-zero counters.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, Default)]
pub struct UsageStats {
/// Tokens in the request input (the OpenAI Responses provider fills
/// this from the wire field `input_tokens`).
pub prompt_tokens: u64,
/// Tokens generated in the response (wire field `output_tokens` on
/// the OpenAI Responses provider).
pub completion_tokens: u64,
/// Total as reported by the upstream; copied through, not
/// recomputed from the two fields above.
pub total_tokens: u64,
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,987 +0,0 @@
//! OpenAI Responses API (`POST /v1/responses`) provider.
//!
//! Mirror image of [`super::openai_chat`]: same `Provider` trait
//! impl, same back-pressured SSE decoder, but speaking OpenAI's
//! newer Responses surface instead of chat completions.
//!
//! Differences from the chat provider, all contained in this file:
//!
//! - **Request encoding**: history flattens into an `input` array
//! of typed items (`message`, `function_call`, `function_call_output`)
//! plus a top-level `instructions` field for the system prompt.
//! Multi-part user content stays in the same `[{type:"input_text"},
//! {type:"input_image"}]` shape neuron's `request_to_chat` already
//! accepts.
//! - **Streaming decoder**: events are named (`response.created`,
//! `response.output_text.delta`, `response.completed`, …) carried
//! on the SSE `event:` line. The chat path's `[DONE]` terminator
//! doesn't apply; the stream ends after `response.completed`.
//! - **Tool calls** plumb through the `response.output_item.added`
//! (item type `function_call`) → `response.function_call_arguments.delta`
//! → `response.function_call_arguments.done` event sequence. The
//! neuron candle harness doesn't synthesize these yet (tracked as
//! issue #6), but the decoder is wired so the day the upstream
//! does, downstream `CompletionEvent::ToolCall*` plumbing just
//! works.
//!
//! Tool-name handling: the model knows its tool descriptions via
//! the [`crate::qwen3`] system-prompt block exactly the way the chat
//! provider does. We don't echo them in the request body because
//! neuron currently ignores `tools` on /v1/responses (same as on
//! /v1/chat/completions). Once neuron honours request-side tool
//! definitions, both providers add them in the same place.
use async_trait::async_trait;
use eventsource_stream::Eventsource;
use futures::{Stream, StreamExt, stream::BoxStream};
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use std::collections::HashMap;
use tokio_util::sync::CancellationToken;
use super::{
CompletionEvent, CompletionRequest, Message, MessageContent, MessagePart, ModelInfo, Provider,
Role, UsageStats,
};
use crate::config::EndpointConfig;
/// [`Provider`] implementation that talks to an OpenAI Responses
/// (`/v1/responses`) endpoint.
pub struct OpenAIResponsesProvider {
/// Endpoint configuration; supplies the name and the request URLs.
endpoint: EndpointConfig,
/// Bearer token resolved once at construction time; `None` means
/// requests are sent without an `Authorization` header.
#[allow(dead_code)] // Read in `complete()`'s HTTP path; tests don't stand up a server.
api_key: Option<String>,
/// Shared HTTP client used by both `list_models` and `complete`.
#[allow(dead_code)]
http: reqwest::Client,
}
impl OpenAIResponsesProvider {
    /// Build a provider for `endpoint`.
    ///
    /// Resolves the API key first, then constructs the HTTP client.
    ///
    /// # Errors
    ///
    /// Fails when the API key cannot be resolved or when the HTTP
    /// client cannot be built.
    pub fn new(endpoint: EndpointConfig) -> anyhow::Result<Self> {
        let api_key = endpoint.resolve_api_key()?;

        // Deliberately generous, matching the chat provider: cortex may
        // have to cold-load a model before the first chunk arrives,
        // which can take tens of seconds. Early termination is the job
        // of cancellation, not of this timeout.
        let request_timeout = std::time::Duration::from_secs(600);
        let client = reqwest::Client::builder()
            .timeout(request_timeout)
            .build()?;

        Ok(Self {
            endpoint,
            api_key,
            http: client,
        })
    }
}
#[async_trait]
impl Provider for OpenAIResponsesProvider {
    /// Configured endpoint name.
    fn name(&self) -> &str {
        &self.endpoint.name
    }

    /// Fetch the endpoint's model list and map each entry to a
    /// [`ModelInfo`] (`display_name` is always `None`).
    ///
    /// # Errors
    ///
    /// Fails on transport errors, on a non-success HTTP status (the
    /// response body is included in the message), or when the body is
    /// not the expected JSON.
    async fn list_models(&self) -> anyhow::Result<Vec<ModelInfo>> {
        let request = self.http.get(self.endpoint.models_url());
        let request = match &self.api_key {
            Some(key) => request.bearer_auth(key),
            None => request,
        };

        let resp = request
            .send()
            .await
            .map_err(|e| anyhow::anyhow!("{} list_models: {e}", self.endpoint.name))?;

        let status = resp.status();
        if !status.is_success() {
            // Best effort: an unreadable error body degrades to "".
            let body = resp.text().await.unwrap_or_default();
            anyhow::bail!(
                "{} list_models returned {}: {}",
                self.endpoint.name,
                status,
                body
            );
        }

        let listing = resp.json::<WireModelsResponse>().await?;
        let mut models = Vec::with_capacity(listing.data.len());
        for entry in listing.data {
            models.push(ModelInfo {
                id: entry.id,
                display_name: None,
            });
        }
        Ok(models)
    }

    /// POST the encoded request to the Responses URL and return the
    /// decoded event stream.
    ///
    /// # Errors
    ///
    /// The outer `Result` fails on transport errors and on a
    /// non-success HTTP status. Once streaming has started, problems
    /// are reported as `Err` items of the stream.
    async fn complete(
        &self,
        request: CompletionRequest,
        cancel: CancellationToken,
    ) -> anyhow::Result<BoxStream<'static, anyhow::Result<CompletionEvent>>> {
        let payload = encode_request(&request);
        tracing::debug!(
            endpoint = %self.endpoint.name,
            url = %self.endpoint.responses_url(),
            body = %serde_json::to_string(&payload).unwrap_or_else(|_| "<unserializable>".into()),
            "POST /responses"
        );

        let outgoing = self.http.post(self.endpoint.responses_url()).json(&payload);
        let outgoing = match &self.api_key {
            Some(key) => outgoing.bearer_auth(key),
            None => outgoing,
        };

        let resp = outgoing
            .send()
            .await
            .map_err(|e| anyhow::anyhow!("{} responses send: {e}", self.endpoint.name))?;

        let status = resp.status();
        if !status.is_success() {
            let body = resp.text().await.unwrap_or_default();
            anyhow::bail!(
                "{} responses returned {}: {}",
                self.endpoint.name,
                status,
                body
            );
        }

        let sse = resp.bytes_stream().eventsource();
        Ok(decode_stream(sse, cancel).boxed())
    }
}
// ── Request encoding ─────────────────────────────────────────────────
/// Build the JSON body for `POST /responses` from a provider-agnostic
/// [`CompletionRequest`].
///
/// - System text messages are lifted out of history and joined with a
///   blank line into the top-level `instructions` string (the field is
///   omitted when there are none).
/// - Every other message becomes zero or more typed `input` items, in
///   history order.
/// - `stream` is always `true`.
/// - `temperature`, `top_p` and `max_tokens` are only sent when set;
///   `max_tokens` goes out under the name `max_output_tokens`.
fn encode_request(req: &CompletionRequest) -> Value {
    let mut instructions: Vec<&str> = Vec::new();
    let mut input_items: Vec<Value> = Vec::new();
    for msg in &req.messages {
        // The Responses API wants system prompts in `instructions`,
        // not inline as an `input` item.
        if msg.role == Role::System
            && let MessageContent::Text { text } = &msg.content
        {
            instructions.push(text.as_str());
            continue;
        }
        input_items.extend(encode_message_as_input_items(msg));
    }

    let mut body = json!({
        "model": req.model,
        "input": input_items,
        "stream": true,
    });
    if let Value::Object(map) = &mut body {
        if !instructions.is_empty() {
            // Joining in history order keeps the relative ordering of
            // multiple system messages.
            map.insert(
                "instructions".into(),
                Value::String(instructions.join("\n\n")),
            );
        }
        if let Some(t) = req.temperature {
            map.insert("temperature".into(), json!(t));
        }
        if let Some(p) = req.top_p {
            map.insert("top_p".into(), json!(p));
        }
        if let Some(m) = req.max_tokens {
            // Responses calls it `max_output_tokens`; preserve the
            // semantic (response cap) when we translate.
            map.insert("max_output_tokens".into(), json!(m));
        }
    }
    body
}

/// Encode assistant prose as a Responses `message` item holding a
/// single `output_text` content part.
fn assistant_text_item(text: &str) -> Value {
    json!({
        "type": "message",
        "role": "assistant",
        "content": [{
            "type": "output_text",
            "text": text,
            "annotations": [],
        }],
    })
}

/// Encode one history message as the Responses `input` item(s) it
/// maps to.
///
/// Most messages map to exactly one item. An assistant turn that
/// called tools maps to an optional prose `message` followed by one
/// `function_call` per call, so every later `function_call_output`
/// has a matching call in the request. System messages and unexpected
/// `(role, content)` combinations map to no items; the latter are
/// logged at warn level.
fn encode_message_as_input_items(msg: &Message) -> Vec<Value> {
    match (msg.role, &msg.content) {
        // Handled out-of-band as `instructions`.
        (Role::System, _) => Vec::new(),
        (Role::User, MessageContent::Text { text }) => vec![json!({
            "type": "message",
            "role": "user",
            "content": text,
        })],
        (Role::User, MessageContent::MultiPart { parts }) => vec![json!({
            "type": "message",
            "role": "user",
            "content": encode_user_parts(parts),
        })],
        (Role::Assistant, MessageContent::Text { text }) => vec![assistant_text_item(text)],
        (Role::Assistant, MessageContent::ToolCalls { text, calls }) => {
            // Prose that accompanied the calls goes first, mirroring
            // the order the model produced it in. Blank prose is not
            // worth an item of its own.
            let prose = text
                .as_deref()
                .filter(|t| !t.trim().is_empty())
                .map(assistant_text_item);
            // One `function_call` item per call: dropping any of them
            // would leave its tool result without a matching call_id.
            let call_items = calls.iter().map(|call| {
                json!({
                    "type": "function_call",
                    "call_id": call.id,
                    "name": call.name,
                    "arguments": call.arguments,
                })
            });
            prose.into_iter().chain(call_items).collect()
        }
        (
            Role::Tool,
            MessageContent::ToolResult {
                tool_call_id,
                content,
            },
        ) => vec![json!({
            "type": "function_call_output",
            "call_id": tool_call_id,
            "output": content,
        })],
        (role, content) => {
            tracing::warn!(
                ?role,
                ?content,
                "openai_responses: unexpected (role, content) shape"
            );
            Vec::new()
        }
    }
}
/// Encode the parts of a multi-part user turn as a JSON array of
/// Responses content parts: text becomes `input_text`, images become
/// `input_image` with a base64 `data:` URI in `image_url`.
fn encode_user_parts(parts: &[MessagePart]) -> Value {
    let mut encoded: Vec<Value> = Vec::with_capacity(parts.len());
    for part in parts {
        let item = match part {
            MessagePart::Text { text } => json!({"type": "input_text", "text": text}),
            MessagePart::Image(image) => {
                let data_uri = format!("data:{};base64,{}", image.mime_type, image.data);
                json!({
                    "type": "input_image",
                    "image_url": data_uri,
                })
            }
        };
        encoded.push(item);
    }
    Value::Array(encoded)
}
// ── Wire types ──────────────────────────────────────────────────────
/// Body of the models-listing response. Only the `data` array is
/// modelled; other fields in the JSON are ignored on deserialise.
#[allow(dead_code)] // fields read only when list_models runs against a real endpoint
#[derive(Debug, Deserialize)]
struct WireModelsResponse {
/// One entry per model.
data: Vec<WireModelObject>,
}
/// One entry of [`WireModelsResponse::data`]; only the id is read.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct WireModelObject {
/// Model id; becomes `ModelInfo::id`.
id: String,
}
// SSE event payload shapes. We only model the fields we care about;
// `#[serde(default)]` + `Option` everywhere else lets the upstream
// add optional fields without breaking deserialise.
/// Payload of a `response.output_item.added` SSE event.
#[derive(Debug, Deserialize, Serialize)]
struct OutputItemAddedEvent {
/// Position of the item in the response output; defaults to 0 when
/// the field is missing. The decoder keys tool calls on it.
#[serde(default)]
output_index: u32,
/// The item that was added.
item: OutputItem,
}
/// An output item carried by [`OutputItemAddedEvent`], discriminated
/// by its `type` field (snake_case variant names).
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum OutputItem {
/// `type: "message"` — a text message item. The decoder takes no
/// action on it; text arrives via separate delta events.
Message {
#[serde(default)]
id: Option<String>,
},
/// `type: "function_call"` — the start of a tool call.
FunctionCall {
/// Item id assigned by the upstream, if any.
#[serde(default)]
id: Option<String>,
/// Id that tool results are paired with; the decoder prefers it
/// over `id`.
#[serde(default)]
call_id: Option<String>,
/// Tool name; the decoder substitutes an empty string when
/// missing.
#[serde(default)]
name: Option<String>,
/// Some upstreams populate `arguments` already on the
/// `output_item.added` event for a fully-buffered tool call
/// (i.e. when the model finalised the call before the SSE
/// flush). Capture it so we can emit a single args delta.
#[serde(default)]
arguments: Option<String>,
},
/// `reasoning`, `web_search_call`, etc. We capture-and-ignore
/// any item we don't model; the decoder still emits the
/// outer events correctly.
#[serde(other)]
Unknown,
}
/// Payload of a `response.output_text.delta` SSE event. Every field
/// is optional on the wire and falls back to its default.
#[derive(Debug, Deserialize, Serialize)]
struct OutputTextDeltaEvent {
/// Id of the item the text belongs to; not read by the decoder.
#[serde(default)]
item_id: Option<String>,
/// Output position of that item; not read by the decoder.
#[serde(default)]
output_index: u32,
/// The new text. Empty deltas are dropped by the decoder.
#[serde(default)]
delta: String,
}
/// Payload of a `response.function_call_arguments.delta` SSE event.
/// Every field is optional on the wire and falls back to its default.
#[derive(Debug, Deserialize, Serialize)]
struct FunctionCallArgumentsDeltaEvent {
/// Id of the function-call item; not read by the decoder.
#[serde(default)]
item_id: Option<String>,
/// Output position of the function-call item; the decoder uses it
/// to find the tool-call slot announced earlier.
#[serde(default)]
output_index: u32,
/// The new chunk of JSON-encoded arguments. Empty deltas are
/// dropped by the decoder.
#[serde(default)]
delta: String,
}
/// Payload of the terminal `response.completed` SSE event.
#[derive(Debug, Deserialize, Serialize)]
struct ResponseCompletedEvent {
/// The subset of the final response object that the decoder reads.
response: ResponseShell,
}
/// The parts of the final response object that the decoder cares
/// about; everything else in the JSON is ignored.
#[derive(Debug, Deserialize, Serialize)]
struct ResponseShell {
/// Final status string. The decoder maps `"incomplete"` to a
/// `"length"` finish reason and anything else to `"stop"`.
#[serde(default)]
status: Option<String>,
/// Token accounting, when the upstream supplies it.
#[serde(default)]
usage: Option<WireUsage>,
}
/// Token counts as named on the Responses wire. Missing fields
/// deserialise to 0.
#[derive(Debug, Deserialize, Serialize)]
struct WireUsage {
/// Becomes `UsageStats::prompt_tokens`.
#[serde(default)]
input_tokens: u64,
/// Becomes `UsageStats::completion_tokens`.
#[serde(default)]
output_tokens: u64,
/// Becomes `UsageStats::total_tokens`.
#[serde(default)]
total_tokens: u64,
}
// ── Streaming decoder ───────────────────────────────────────────────
/// Translate the named-event Responses SSE into the provider-agnostic
/// [`CompletionEvent`] stream the agent loop expects. The decoder
/// holds per-stream state — output_index → tool-call-index plus
/// the next available tool-call slot — so it can fire
/// `ToolCallStart` exactly once per item.
///
/// Termination:
///
/// - `response.completed` / `response.incomplete` → optional `Usage`,
///   then `Finish` (`"length"` for an incomplete response, `"stop"`
///   otherwise), then the stream ends.
/// - `response.failed` / `error` → one `Err` item carrying the raw
///   payload, then the stream ends.
/// - SSE transport error → one `Err` item, then the stream ends.
/// - `cancel` fired, or the upstream closes without a terminal event →
///   the stream ends without a `Finish`.
///
/// Payloads that fail to parse are logged at warn level and skipped;
/// they never surface as stream errors.
fn decode_stream<S>(
    sse: S,
    cancel: CancellationToken,
) -> impl Stream<Item = anyhow::Result<CompletionEvent>>
where
    S: Stream<
            Item = Result<
                eventsource_stream::Event,
                eventsource_stream::EventStreamError<reqwest::Error>,
            >,
        > + Send
        + 'static,
{
    async_stream::stream! {
        let mut sse = Box::pin(sse);
        // Maps an output_index that's a function_call to the tool-call
        // slot we hand downstream. Lets us correlate later
        // `function_call_arguments.delta` events back to the index
        // we already announced on `output_item.added`.
        let mut tool_index_by_output: HashMap<u32, usize> = HashMap::new();
        let mut next_tool_index: usize = 0;
        loop {
            tokio::select! {
                // `biased` polls cancellation first, so a fired token
                // wins over an event that is ready at the same time.
                biased;
                _ = cancel.cancelled() => {
                    tracing::debug!("openai_responses: cancellation requested, ending stream");
                    break;
                }
                next = sse.next() => {
                    let Some(event) = next else { break };
                    let event = match event {
                        Ok(e) => e,
                        Err(e) => {
                            yield Err(anyhow::anyhow!("SSE transport: {e}"));
                            break;
                        }
                    };
                    // Event name lives on `event.event`; data is JSON.
                    let event_name = event.event.as_str();
                    let data = event.data.as_str();
                    match event_name {
                        "response.output_text.delta" => {
                            match serde_json::from_str::<OutputTextDeltaEvent>(data) {
                                Ok(d) if !d.delta.is_empty() => {
                                    yield Ok(CompletionEvent::TextDelta(d.delta));
                                }
                                Ok(_) => {}
                                Err(e) => {
                                    tracing::warn!(
                                        error = %e,
                                        raw = %data,
                                        "openai_responses: failed to parse output_text.delta; skipping"
                                    );
                                }
                            }
                        }
                        "response.output_item.added" => {
                            match serde_json::from_str::<OutputItemAddedEvent>(data) {
                                Ok(ev) => {
                                    if let OutputItem::FunctionCall {
                                        id,
                                        call_id,
                                        name,
                                        arguments,
                                    } = ev.item
                                    {
                                        let idx = next_tool_index;
                                        next_tool_index += 1;
                                        tool_index_by_output.insert(ev.output_index, idx);
                                        // Prefer the user-facing `call_id`
                                        // (what gets paired with tool
                                        // results) over the internal item
                                        // `id` when both are present. Falls
                                        // back to a synthetic id so history
                                        // bookkeeping never breaks.
                                        let final_id = call_id
                                            .or(id)
                                            .unwrap_or_else(|| format!("call_{idx}"));
                                        let final_name = name.unwrap_or_default();
                                        yield Ok(CompletionEvent::ToolCallStart {
                                            index: idx,
                                            id: final_id,
                                            name: final_name,
                                        });
                                        // Some upstreams attach the
                                        // fully-buffered arguments on the
                                        // `output_item.added` event itself
                                        // (rare; happens when the model
                                        // finalised before the SSE flush).
                                        // Emit as a single args delta if
                                        // present.
                                        if let Some(args) = arguments
                                            && !args.is_empty()
                                        {
                                            yield Ok(CompletionEvent::ToolCallArgsDelta {
                                                index: idx,
                                                args_delta: args,
                                            });
                                        }
                                    }
                                }
                                Err(e) => {
                                    tracing::warn!(
                                        error = %e,
                                        raw = %data,
                                        "openai_responses: failed to parse output_item.added; skipping"
                                    );
                                }
                            }
                        }
                        "response.function_call_arguments.delta" => {
                            match serde_json::from_str::<FunctionCallArgumentsDeltaEvent>(data) {
                                Ok(ev) => {
                                    let Some(&idx) = tool_index_by_output.get(&ev.output_index)
                                    else {
                                        // Args delta for an item we never
                                        // saw an `output_item.added` for.
                                        // Could happen if the upstream
                                        // reordered events; log + skip.
                                        tracing::warn!(
                                            output_index = ev.output_index,
                                            "openai_responses: function_call_arguments.delta for unknown output_index"
                                        );
                                        continue;
                                    };
                                    if !ev.delta.is_empty() {
                                        yield Ok(CompletionEvent::ToolCallArgsDelta {
                                            index: idx,
                                            args_delta: ev.delta,
                                        });
                                    }
                                }
                                Err(e) => {
                                    tracing::warn!(
                                        error = %e,
                                        raw = %data,
                                        "openai_responses: failed to parse function_call_arguments.delta; skipping"
                                    );
                                }
                            }
                        }
                        "response.completed" | "response.incomplete" => {
                            // Terminal success events. A truncated
                            // response is signalled either by the
                            // dedicated `response.incomplete` event or
                            // by `status: "incomplete"` on the shell;
                            // both map to a "length" stop. Anything
                            // else is a normal "stop".
                            let incomplete_event = event_name == "response.incomplete";
                            let (reason, usage) =
                                match serde_json::from_str::<ResponseCompletedEvent>(data) {
                                    Ok(ev) => {
                                        let truncated = incomplete_event
                                            || ev.response.status.as_deref() == Some("incomplete");
                                        let reason = if truncated { "length" } else { "stop" };
                                        let usage = ev.response.usage.map(|u| UsageStats {
                                            prompt_tokens: u.input_tokens,
                                            completion_tokens: u.output_tokens,
                                            total_tokens: u.total_tokens,
                                        });
                                        (Some(reason.to_string()), usage)
                                    }
                                    Err(e) => {
                                        // The event name alone still tells
                                        // us the turn is over, so finish
                                        // cleanly without usage.
                                        tracing::warn!(
                                            event = event_name,
                                            error = %e,
                                            raw = %data,
                                            "openai_responses: failed to parse terminal event; ending stream"
                                        );
                                        let reason = if incomplete_event { "length" } else { "stop" };
                                        (Some(reason.to_string()), None)
                                    }
                                };
                            if let Some(u) = usage {
                                yield Ok(CompletionEvent::Usage(u));
                            }
                            yield Ok(CompletionEvent::Finish { reason });
                            break;
                        }
                        "response.failed" | "error" => {
                            // Terminal failure events. Surface them to
                            // the caller instead of treating them as
                            // bookkeeping, otherwise an upstream error
                            // looks like a silently truncated turn. The
                            // raw payload carries the upstream's error
                            // details.
                            tracing::warn!(
                                event = event_name,
                                raw = %data,
                                "openai_responses: upstream reported failure"
                            );
                            yield Err(anyhow::anyhow!(
                                "openai_responses: upstream {event_name}: {data}"
                            ));
                            break;
                        }
                        // Bookkeeping events we don't need to surface:
                        // response.created, response.in_progress,
                        // response.content_part.added/.done,
                        // response.output_text.done,
                        // response.output_item.done,
                        // response.function_call_arguments.done,
                        // response.reasoning_*. Logged at trace for
                        // wire-tracing.
                        other => {
                            tracing::trace!(
                                event = other,
                                "openai_responses: bookkeeping event"
                            );
                        }
                    }
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::provider::ToolCall;
use crate::provider::{ImageData, MessagePart};
use futures::stream;
use url::Url;
fn ep() -> EndpointConfig {
EndpointConfig {
name: "test".into(),
base_url: Url::parse("http://localhost:9999/v1").unwrap(),
wire_api: crate::config::WireApi::OpenAiResponses,
default_model: None,
api_key: None,
api_key_env: None,
max_tokens: None,
context_window: None,
}
}
// ── encode_request ──────────────────────────────────────────────
#[test]
fn system_messages_collapse_to_instructions() {
let req = CompletionRequest {
model: "m".into(),
messages: vec![
Message {
role: Role::System,
content: MessageContent::Text {
text: "you are helpful".into(),
},
},
Message {
role: Role::User,
content: MessageContent::Text { text: "hi".into() },
},
],
tools: vec![],
temperature: Some(0.7),
top_p: None,
max_tokens: Some(256),
};
let body = encode_request(&req);
assert_eq!(body["model"], "m");
assert_eq!(body["instructions"], "you are helpful");
assert_eq!(body["stream"], true);
assert_eq!(body["max_output_tokens"], 256);
assert_eq!(body["temperature"], 0.7);
let input = body["input"].as_array().unwrap();
// System message NOT echoed in input — it's only in
// instructions.
assert_eq!(input.len(), 1);
assert_eq!(input[0]["type"], "message");
assert_eq!(input[0]["role"], "user");
assert_eq!(input[0]["content"], "hi");
}
#[test]
fn multiple_system_messages_concatenate() {
let req = CompletionRequest {
model: "m".into(),
messages: vec![
Message {
role: Role::System,
content: MessageContent::Text {
text: "first".into(),
},
},
Message {
role: Role::System,
content: MessageContent::Text {
text: "second".into(),
},
},
Message {
role: Role::User,
content: MessageContent::Text { text: "hi".into() },
},
],
tools: vec![],
temperature: None,
top_p: None,
max_tokens: None,
};
let body = encode_request(&req);
assert_eq!(body["instructions"], "first\n\nsecond");
}
#[test]
fn user_multipart_becomes_input_parts_array() {
let req = CompletionRequest {
model: "vl".into(),
messages: vec![Message {
role: Role::User,
content: MessageContent::MultiPart {
parts: vec![
MessagePart::Text {
text: "what's in this?".into(),
},
MessagePart::Image(ImageData {
mime_type: "image/png".into(),
data: "AAA=".into(),
uri: None,
}),
],
},
}],
tools: vec![],
temperature: None,
top_p: None,
max_tokens: None,
};
let body = encode_request(&req);
let content = &body["input"][0]["content"].as_array().unwrap().clone();
assert_eq!(content.len(), 2);
assert_eq!(content[0]["type"], "input_text");
assert_eq!(content[0]["text"], "what's in this?");
assert_eq!(content[1]["type"], "input_image");
assert_eq!(content[1]["image_url"], "data:image/png;base64,AAA=");
}
#[test]
fn assistant_text_becomes_output_text_content_part() {
let req = CompletionRequest {
model: "m".into(),
messages: vec![
Message {
role: Role::User,
content: MessageContent::Text { text: "hi".into() },
},
Message {
role: Role::Assistant,
content: MessageContent::Text {
text: "hello there".into(),
},
},
Message {
role: Role::User,
content: MessageContent::Text {
text: "more".into(),
},
},
],
tools: vec![],
temperature: None,
top_p: None,
max_tokens: None,
};
let body = encode_request(&req);
let input = body["input"].as_array().unwrap();
assert_eq!(input.len(), 3);
assert_eq!(input[1]["type"], "message");
assert_eq!(input[1]["role"], "assistant");
assert_eq!(input[1]["content"][0]["type"], "output_text");
assert_eq!(input[1]["content"][0]["text"], "hello there");
}
#[test]
fn tool_calls_and_results_round_trip_via_function_call_items() {
let req = CompletionRequest {
model: "m".into(),
messages: vec![
Message {
role: Role::Assistant,
content: MessageContent::ToolCalls {
text: None,
calls: vec![ToolCall {
id: "call_42".into(),
name: "read_file".into(),
arguments: r#"{"path":"/etc/hostname"}"#.into(),
}],
},
},
Message {
role: Role::Tool,
content: MessageContent::ToolResult {
tool_call_id: "call_42".into(),
content: "host".into(),
},
},
],
tools: vec![],
temperature: None,
top_p: None,
max_tokens: None,
};
let body = encode_request(&req);
let input = body["input"].as_array().unwrap();
assert_eq!(input.len(), 2);
assert_eq!(input[0]["type"], "function_call");
assert_eq!(input[0]["call_id"], "call_42");
assert_eq!(input[0]["name"], "read_file");
assert_eq!(input[0]["arguments"], r#"{"path":"/etc/hostname"}"#);
assert_eq!(input[1]["type"], "function_call_output");
assert_eq!(input[1]["call_id"], "call_42");
assert_eq!(input[1]["output"], "host");
}
// ── decode_stream ───────────────────────────────────────────────
fn sse_event(name: &str, data: &str) -> eventsource_stream::Event {
eventsource_stream::Event {
id: String::new(),
retry: None,
event: name.into(),
data: data.into(),
}
}
async fn collect_events(
items: Vec<eventsource_stream::Event>,
) -> Vec<anyhow::Result<CompletionEvent>> {
let sse = stream::iter(
items
.into_iter()
.map(Ok::<_, eventsource_stream::EventStreamError<reqwest::Error>>),
);
let decoded = decode_stream(sse, CancellationToken::new());
decoded.collect().await
}
#[tokio::test]
async fn decodes_text_then_finish() {
let events = collect_events(vec![
sse_event("response.created", "{}"),
sse_event(
"response.output_text.delta",
r#"{"item_id":"msg_1","output_index":0,"delta":"hel"}"#,
),
sse_event(
"response.output_text.delta",
r#"{"item_id":"msg_1","output_index":0,"delta":"lo"}"#,
),
sse_event(
"response.completed",
r#"{"response":{"status":"completed","usage":{"input_tokens":3,"output_tokens":2,"total_tokens":5}}}"#,
),
])
.await;
let events: Vec<CompletionEvent> = events.into_iter().map(|r| r.unwrap()).collect();
let mut iter = events.into_iter();
assert!(matches!(iter.next(), Some(CompletionEvent::TextDelta(t)) if t == "hel"));
assert!(matches!(iter.next(), Some(CompletionEvent::TextDelta(t)) if t == "lo"));
assert!(matches!(iter.next(), Some(CompletionEvent::Usage(u)) if u.total_tokens == 5));
assert!(matches!(
iter.next(),
Some(CompletionEvent::Finish { reason: Some(r) }) if r == "stop"
));
assert!(iter.next().is_none());
}
#[tokio::test]
async fn empty_delta_is_dropped() {
let events = collect_events(vec![
sse_event(
"response.output_text.delta",
r#"{"item_id":"m","output_index":0,"delta":""}"#,
),
sse_event(
"response.completed",
r#"{"response":{"status":"completed"}}"#,
),
])
.await;
let mut completion_events = events.into_iter().map(|r| r.unwrap());
// First event MUST be the Finish — the empty delta dropped.
assert!(matches!(
completion_events.next(),
Some(CompletionEvent::Finish { .. })
));
}
#[tokio::test]
async fn incomplete_status_maps_to_length_finish_reason() {
let events = collect_events(vec![sse_event(
"response.completed",
r#"{"response":{"status":"incomplete"}}"#,
)])
.await;
let events: Vec<CompletionEvent> = events.into_iter().map(|r| r.unwrap()).collect();
assert!(matches!(
events.last(),
Some(CompletionEvent::Finish { reason: Some(r) }) if r == "length"
));
}
#[tokio::test]
async fn function_call_items_emit_toolcall_events() {
let events = collect_events(vec![
sse_event(
"response.output_item.added",
r#"{"output_index":0,"item":{"type":"function_call","id":"item_1","call_id":"call_xyz","name":"read_file"}}"#,
),
sse_event(
"response.function_call_arguments.delta",
r#"{"item_id":"item_1","output_index":0,"delta":"{\"path"}"#,
),
sse_event(
"response.function_call_arguments.delta",
r#"{"item_id":"item_1","output_index":0,"delta":"\":\"/etc/hostname\"}"}"#,
),
sse_event("response.completed", r#"{"response":{"status":"completed"}}"#),
])
.await;
let events: Vec<CompletionEvent> = events.into_iter().map(|r| r.unwrap()).collect();
let mut iter = events.into_iter();
assert!(matches!(
iter.next(),
Some(CompletionEvent::ToolCallStart { index: 0, ref id, ref name })
if id == "call_xyz" && name == "read_file"
));
assert!(matches!(
iter.next(),
Some(CompletionEvent::ToolCallArgsDelta { index: 0, ref args_delta })
if args_delta == r#"{"path"#
));
assert!(matches!(
iter.next(),
Some(CompletionEvent::ToolCallArgsDelta { index: 0, ref args_delta })
if args_delta == r#"":"/etc/hostname"}"#
));
assert!(matches!(iter.next(), Some(CompletionEvent::Finish { .. })));
}
#[tokio::test]
async fn function_call_added_with_inline_arguments_emits_single_args_delta() {
// Some upstreams (rare) include the fully-buffered arguments
// on the `output_item.added` event when the model finalised
// the call before SSE flush. Verify both ToolCallStart and a
// single args delta fire.
let events = collect_events(vec![
sse_event(
"response.output_item.added",
r#"{"output_index":0,"item":{"type":"function_call","call_id":"call_a","name":"f","arguments":"{\"x\":1}"}}"#,
),
sse_event("response.completed", r#"{"response":{"status":"completed"}}"#),
])
.await;
let events: Vec<CompletionEvent> = events.into_iter().map(|r| r.unwrap()).collect();
let mut iter = events.into_iter();
assert!(matches!(
iter.next(),
Some(CompletionEvent::ToolCallStart { .. })
));
assert!(matches!(
iter.next(),
Some(CompletionEvent::ToolCallArgsDelta { index: 0, ref args_delta })
if args_delta == r#"{"x":1}"#
));
assert!(matches!(iter.next(), Some(CompletionEvent::Finish { .. })));
}
#[tokio::test]
async fn cancellation_ends_stream_promptly() {
// Hand the decoder an empty stream + a triggered cancellation
// token; it should terminate without yielding anything.
let sse = stream::iter(Vec::<
Result<eventsource_stream::Event, eventsource_stream::EventStreamError<reqwest::Error>>,
>::new());
let cancel = CancellationToken::new();
cancel.cancel();
let decoded = decode_stream(sse, cancel);
let events: Vec<_> = decoded.collect().await;
assert!(events.is_empty());
}
#[tokio::test]
async fn malformed_event_payload_is_skipped() {
    let broken = sse_event("response.output_text.delta", "{not valid json");
    let valid = sse_event(
        "response.output_text.delta",
        r#"{"item_id":"m","output_index":0,"delta":"ok"}"#,
    );
    let done = sse_event(
        "response.completed",
        r#"{"response":{"status":"completed"}}"#,
    );

    // Every item is unwrapped here, so any yielded error fails the test.
    let decoded: Vec<CompletionEvent> = collect_events(vec![broken, valid, done])
        .await
        .into_iter()
        .map(|item| item.unwrap())
        .collect();

    // The unparseable delta is dropped; the well-formed one still arrives.
    let saw_ok = decoded
        .iter()
        .any(|event| matches!(event, CompletionEvent::TextDelta(text) if text == "ok"));
    assert!(saw_ok);

    // Parse failures are warn-and-skip: no Finish event without a reason
    // may appear in the output.
    let finish_without_reason = decoded
        .iter()
        .any(|event| matches!(event, CompletionEvent::Finish { reason: None }));
    assert!(!finish_without_reason);
}
#[test]
fn provider_construction_is_cheap() {
    // Construction against the test endpoint must succeed; the provider
    // itself is discarded immediately.
    let built = OpenAIResponsesProvider::new(ep());
    let _ = built.unwrap();
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,188 +0,0 @@
//! Per-session state for the ACP agent loop.
//!
//! Concurrency:
//!
//! - [`SessionStore`] is an `Arc<RwLock<HashMap<SessionId, …>>>`. The map
//! itself is read-mostly: it changes only on `session/new` and never
//! shrinks during Stage 2, so an `RwLock` keeps concurrent reads
//! contention-free.
//! - Each session is wrapped in its own `Arc<Mutex<SessionState>>`. Holding
//! one session's lock doesn't block requests against any other session,
//! which matters once a client opens multiple sessions in parallel.
//!
//! All operations hold a lock only long enough to copy out (or mutate) the
//! state they need — never across an `await` that drives the upstream
//! provider stream.
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use agent_client_protocol::schema::{SessionId, SessionModeId};
use tokio::sync::{Mutex, RwLock};
use tokio_util::sync::CancellationToken;
use crate::provider::Message;
/// Mode id advertised as the gated default: writes and `bash` commands
/// prompt the user for permission via `session/request_permission`.
pub const MODE_DEFAULT: &str = "default";
/// Mode id advertised as "auto-allow everything". The value matches the
/// name (`bypassPermissions`) that Zed clients tend to reference.
pub const MODE_BYPASS: &str = "bypassPermissions";
/// Mode id for read-and-plan-only operation. The model may read files
/// and list directories freely, may write *only* into the per-project
/// plan directory under `$XDG_DATA_HOME/helexa-acp/plans/<project-id>/`,
/// and cannot run shell commands. Designed for "draft the
/// implementation plan, then I'll review and let you execute" flows.
pub const MODE_PLAN: &str = "plan";
/// State carried for a single ACP session.
///
/// Mutated under `Mutex<SessionState>`; never share a clone across
/// tasks expecting to see the same `cancel` token — clone the token
/// explicitly when handing it to the streaming task.
#[derive(Debug)]
pub struct SessionState {
/// Conversation history in chronological order (user / assistant
/// turns). The system prompt is *not* stored here — it's built
/// fresh per request so any cwd / config changes take effect.
pub history: Vec<Message>,
/// Working directory the client opened the session against. Used
/// by [`crate::prompt::build_system_prompt`] and (Stage 3) by
/// filesystem tools.
pub cwd: PathBuf,
/// Currently-selected model id. Format is either a bare model id
/// (resolved against the default endpoint) or `endpoint:model`.
/// Mutated by `session/set_model` in Stage 4; Stage 2 sets it
/// once at session creation and never changes it.
pub model_id: String,
/// Cancellation handle for the in-flight prompt, if any. A fresh
/// token is installed at the start of every `session/prompt`
/// request; `session/cancel` fires this one. Between prompts the
/// token is "spent" — firing it does nothing — which is fine,
/// `session/cancel` is a no-op when there's nothing to cancel.
pub cancel: CancellationToken,
/// Permission gating mode. This module defines three mode ids:
/// [`MODE_DEFAULT`] (writes / bash prompt the user), [`MODE_BYPASS`]
/// (auto-allow) and [`MODE_PLAN`] (read-and-plan only). New sessions
/// start in [`MODE_DEFAULT`]; mutated by `session/set_mode`.
pub mode_id: SessionModeId,
}
impl SessionState {
pub fn new(cwd: PathBuf, model_id: String) -> Self {
Self {
history: Vec::new(),
cwd,
model_id,
cancel: CancellationToken::new(),
mode_id: SessionModeId::new(MODE_DEFAULT),
}
}
}
/// Concurrent map of live sessions.
///
/// The outer `RwLock` guards the id → session map; each session sits
/// behind its own `Arc<Mutex<_>>`, so locking one session does not
/// block access to any other.
///
/// Cloning is cheap (`Arc` bump). Pass clones into every handler that
/// needs session access; never hold a clone across an `.await` that
/// could outlive the request.
pub type SessionStore = Arc<RwLock<HashMap<SessionId, Arc<Mutex<SessionState>>>>>;
/// Build a session store that contains no sessions yet.
pub fn new_store() -> SessionStore {
    let sessions = HashMap::new();
    Arc::new(RwLock::new(sessions))
}
/// Fetch the shared handle for session `id`.
///
/// Takes the store's read lock only for the duration of the lookup and
/// returns a clone of the session's `Arc`, or `None` when no session is
/// registered under that id.
pub async fn get(store: &SessionStore, id: &SessionId) -> Option<Arc<Mutex<SessionState>>> {
    let sessions = store.read().await;
    sessions.get(id).map(Arc::clone)
}
/// Add a new session to the store under `id`.
///
/// An existing entry with the same id is replaced. That should never
/// happen in practice because ids are uniquely generated by the agent.
pub async fn insert(store: &SessionStore, id: SessionId, state: SessionState) {
    let mut sessions = store.write().await;
    let handle = Arc::new(Mutex::new(state));
    sessions.insert(id, handle);
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::{MessageContent, Role};

    /// Shorthand for building a [`SessionId`] from a literal.
    fn id(s: &str) -> SessionId {
        SessionId::new(s)
    }

    /// Number of history entries currently stored for session `name`.
    async fn history_len(store: &SessionStore, name: &str) -> usize {
        let handle = get(store, &id(name)).await.unwrap();
        let session = handle.lock().await;
        session.history.len()
    }

    #[tokio::test]
    async fn insert_then_get_round_trip() {
        let store = new_store();
        insert(
            &store,
            id("s1"),
            SessionState::new(PathBuf::from("/tmp"), "m".into()),
        )
        .await;

        let handle = get(&store, &id("s1")).await.expect("session present");
        let session = handle.lock().await;
        assert_eq!(session.cwd, PathBuf::from("/tmp"));
        assert_eq!(session.model_id, "m");
        assert!(session.history.is_empty());
    }

    #[tokio::test]
    async fn missing_session_is_none() {
        let store = new_store();
        let found = get(&store, &id("nope")).await;
        assert!(found.is_none());
    }

    #[tokio::test]
    async fn history_is_per_session() {
        let store = new_store();
        for (name, cwd) in [("a", "/a"), ("b", "/b")] {
            let state = SessionState::new(PathBuf::from(cwd), "m".into());
            insert(&store, id(name), state).await;
        }

        // Appending to a's history must leave b's untouched.
        {
            let handle = get(&store, &id("a")).await.unwrap();
            let mut session = handle.lock().await;
            session.history.push(Message {
                role: Role::User,
                content: MessageContent::Text {
                    text: "hello".into(),
                },
            });
        }

        assert_eq!(history_len(&store, "a").await, 1);
        assert_eq!(history_len(&store, "b").await, 0);
    }
}

View File

@@ -1,462 +0,0 @@
//! On-disk session persistence for `session/load` support.
//!
//! Storage layout:
//!
//! ```text
//! $XDG_DATA_HOME/helexa-acp/sessions/{session_id}.json
//! ```
//!
//! (Fallback to `~/.local/share/helexa-acp/sessions/` when
//! `$XDG_DATA_HOME` is unset.) One JSON file per session. Writes
//! happen at the end of every `session/prompt` round through
//! [`save`], using tempfile-plus-rename so a crash mid-write can't
//! corrupt the store. Reads happen on `session/load` via [`load`].
//!
//! No compaction, no rotation: files accumulate until the user
//! cleans them up. That's deliberate — disk is cheap, and the
//! resume-on-restart workflow matters more than tidiness. The
//! [`SESSIONS_DIRNAME`] subdirectory is created lazily on first
//! save so an unprivileged install path never errors at startup.
use std::path::PathBuf;
use std::time::SystemTime;
use agent_client_protocol::schema::SessionId;
use serde::{Deserialize, Serialize};
use crate::provider::Message;
/// Application directory created under the resolved data home.
const APP_DIRNAME: &str = "helexa-acp";
/// Subdirectory of [`APP_DIRNAME`] that holds one JSON file per session.
const SESSIONS_DIRNAME: &str = "sessions";
/// Subdirectory of [`APP_DIRNAME`] that holds plan-mode artefacts.
const PLANS_DIRNAME: &str = "plans";
/// The shape persisted to disk for one session. Only what we can't
/// rebuild from the running config goes in here: the conversation
/// history, the mode toggle, the model id, and the cwd-at-creation.
///
/// `created_at` / `updated_at` are seconds-since-epoch — cheap to
/// compare, no third-party time crate, and stable across runs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersistedSession {
/// Session id as a string. [`save_to_dir`] sanitises it to derive the
/// on-disk file name.
pub session_id: String,
/// Working directory recorded for the session. [`list_in_dir`] filters
/// on exact equality with this path.
pub cwd: PathBuf,
/// Model id selected for the session.
pub model_id: String,
/// Permission mode id, stored as a plain string.
pub mode_id: String,
/// Conversation history.
pub history: Vec<Message>,
/// Creation time, seconds since the Unix epoch.
pub created_at: u64,
/// Last-update time, seconds since the Unix epoch. [`list_in_dir`]
/// sorts by this field, most recent first.
pub updated_at: u64,
}
/// Resolve the directory that holds session JSON files. Honors
/// `$XDG_DATA_HOME`; falls back to `~/.local/share/helexa-acp/sessions/`.
///
/// A variable that is unset *or empty* is treated as absent, for both
/// `XDG_DATA_HOME` and `HOME`. An empty `HOME` would otherwise resolve
/// to a path relative to the current working directory. Values are
/// read as OS strings, so a non-UTF-8 path is used as-is instead of
/// being silently discarded.
///
/// Returns `None` if neither variable is usable (no `HOME` set —
/// possible in stripped-down container environments).
pub fn sessions_dir() -> Option<PathBuf> {
    // Read one environment variable as a path, ignoring unset/empty.
    let non_empty_dir = |key: &str| {
        std::env::var_os(key)
            .filter(|value| !value.is_empty())
            .map(PathBuf::from)
    };
    let base = non_empty_dir("XDG_DATA_HOME")
        .or_else(|| non_empty_dir("HOME").map(|home| home.join(".local").join("share")))?;
    Some(base.join(APP_DIRNAME).join(SESSIONS_DIRNAME))
}
/// Atomic save into the default sessions directory.
pub fn save(session: &PersistedSession) -> anyhow::Result<()> {
let dir = sessions_dir()
.ok_or_else(|| anyhow::anyhow!("can't resolve XDG_DATA_HOME or HOME for session store"))?;
save_to_dir(&dir, session)
}
/// Read the session stored under `session_id` from the default
/// sessions directory.
///
/// Fails when the sessions directory cannot be resolved, or with
/// whatever error [`load_from_dir`] reports.
pub fn load(session_id: &SessionId) -> anyhow::Result<PersistedSession> {
    match sessions_dir() {
        Some(dir) => load_from_dir(&dir, session_id),
        None => Err(anyhow::anyhow!(
            "can't resolve XDG_DATA_HOME or HOME for session store"
        )),
    }
}
/// Atomic save into an explicit directory. Writes to
/// `{id}.json.tmp`, flushes it to stable storage, then renames it over
/// `{id}.json`. Creates the target directory if it doesn't exist.
/// Split from [`save`] so unit tests can target a per-test scratch dir
/// without mutating process-global env vars.
///
/// The temp file is `sync_all`ed before the rename, so a system crash
/// shortly after the rename cannot leave a truncated `{id}.json`. If
/// the write or the rename fails, the temp file is removed
/// (best-effort) rather than left behind.
///
/// # Errors
///
/// Returns an error if the directory cannot be created, the session
/// cannot be serialised, or the temp file cannot be written or renamed.
pub fn save_to_dir(dir: &std::path::Path, session: &PersistedSession) -> anyhow::Result<()> {
    use std::io::Write as _;

    std::fs::create_dir_all(dir).map_err(|e| anyhow::anyhow!("create {}: {e}", dir.display()))?;
    let safe = sanitize_id(&session.session_id);
    let final_path = dir.join(format!("{safe}.json"));
    let tmp_path = dir.join(format!("{safe}.json.tmp"));
    let json = serde_json::to_string_pretty(session)?;

    // Write and fsync the temp file; the rename below only publishes
    // content that has already reached the disk.
    let write_tmp = || -> std::io::Result<()> {
        let mut file = std::fs::File::create(&tmp_path)?;
        file.write_all(json.as_bytes())?;
        file.sync_all()
    };
    if let Err(e) = write_tmp() {
        // Best-effort cleanup; the original error is what matters.
        let _ = std::fs::remove_file(&tmp_path);
        return Err(anyhow::anyhow!("write {}: {e}", tmp_path.display()));
    }
    if let Err(e) = std::fs::rename(&tmp_path, &final_path) {
        let _ = std::fs::remove_file(&tmp_path);
        return Err(anyhow::anyhow!("rename → {}: {e}", final_path.display()));
    }
    Ok(())
}
/// Read and parse the session file for `session_id` inside `dir`.
///
/// The id is sanitised the same way as on save to derive the file
/// name. A missing file produces a dedicated "no persisted session"
/// message so the caller can map it to a clean ACP error response;
/// other read failures and parse failures name the offending path.
pub fn load_from_dir(
    dir: &std::path::Path,
    session_id: &SessionId,
) -> anyhow::Result<PersistedSession> {
    let file_name = format!("{}.json", sanitize_id(session_id.0.as_ref()));
    let path = dir.join(file_name);

    let bytes = match std::fs::read(&path) {
        Ok(bytes) => bytes,
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
            anyhow::bail!("no persisted session at {}", path.display())
        }
        Err(e) => anyhow::bail!("read {}: {e}", path.display()),
    };

    serde_json::from_slice::<PersistedSession>(&bytes)
        .map_err(|e| anyhow::anyhow!("parse {}: {e}", path.display()))
}
/// List all persisted sessions, optionally filtered by `cwd`. Used
/// by the `session/list` handler so a client (Zed) can find the
/// session that belongs to the workspace it's reopening.
///
/// `filter_cwd = None` returns every session on disk. `Some(path)`
/// returns only sessions whose persisted `cwd` is exactly equal.
///
/// Files that fail to parse are skipped with a warning rather than
/// aborting the whole list — one corrupt session shouldn't make
/// the resume picker unusable.
pub fn list(filter_cwd: Option<&std::path::Path>) -> anyhow::Result<Vec<PersistedSession>> {
let dir = sessions_dir()
.ok_or_else(|| anyhow::anyhow!("can't resolve XDG_DATA_HOME or HOME for session store"))?;
list_in_dir(&dir, filter_cwd)
}
/// Explicit-directory counterpart of [`list`], mirroring
/// [`save_to_dir`] / [`load_from_dir`] so tests can use a scratch dir.
///
/// A missing directory yields an empty list. Only `*.json` entries
/// are considered; files that cannot be read or parsed are logged and
/// skipped. The result is ordered by `updated_at`, most recent first.
pub fn list_in_dir(
    dir: &std::path::Path,
    filter_cwd: Option<&std::path::Path>,
) -> anyhow::Result<Vec<PersistedSession>> {
    let entries = match std::fs::read_dir(dir) {
        Ok(entries) => entries,
        Err(e) => {
            if e.kind() == std::io::ErrorKind::NotFound {
                return Ok(Vec::new());
            }
            return Err(anyhow::anyhow!("read_dir {}: {e}", dir.display()));
        }
    };

    let mut sessions = Vec::new();
    // Directory entries that fail to enumerate are silently ignored.
    for path in entries.flatten().map(|entry| entry.path()) {
        let is_json = path.extension().and_then(|ext| ext.to_str()) == Some("json");
        if !is_json {
            continue;
        }

        let parsed = std::fs::read(&path).and_then(|bytes| {
            serde_json::from_slice::<PersistedSession>(&bytes).map_err(std::io::Error::other)
        });
        let session = match parsed {
            Ok(session) => session,
            Err(err) => {
                tracing::warn!(
                    path = %path.display(),
                    error = %err,
                    "store: skipping unparseable session file"
                );
                continue;
            }
        };

        if filter_cwd.is_some_and(|want| session.cwd != want) {
            continue;
        }
        sessions.push(session);
    }

    // Newest first; the sort is stable, so ties keep directory order.
    sessions.sort_by(|a, b| b.updated_at.cmp(&a.updated_at));
    Ok(sessions)
}
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Reading the clock relative to the epoch is fallible; if the system
/// clock reports a time before the epoch this returns 0.
pub fn now_secs() -> u64 {
    match SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
/// Root directory for plan-mode artefacts: the `plans` sibling of the
/// sessions directory, i.e. `…/helexa-acp/plans/`. Plans and
/// conversation transcripts therefore live side by side, not nested.
///
/// Returns `None` when [`sessions_dir`] cannot be resolved.
pub fn plans_root() -> Option<PathBuf> {
    let sessions = sessions_dir()?;
    let app_root = sessions.parent()?;
    Some(app_root.join(PLANS_DIRNAME))
}
/// Plan directory for the project at `cwd`:
/// `$XDG_DATA_HOME/helexa-acp/plans/<project-id>/`.
///
/// The project id is derived from `cwd` via [`project_id_for`], so the
/// same path always maps to the same directory. Two different paths
/// to the same checkout (e.g. `/home/foo/git/bar` and a symlinked
/// `/srv/checkout/bar`) get different ids — an accepted won't-fix
/// corner case.
///
/// Returns `None` when [`plans_root`] cannot be resolved.
pub fn plan_dir_for(cwd: &std::path::Path) -> Option<PathBuf> {
    let root = plans_root()?;
    Some(root.join(project_id_for(cwd)))
}
/// Deterministic, human-readable project identifier of the form
/// `<basename>-<8-hex>`.
///
/// The basename is the final component of `cwd` (or `unknown` when
/// there is none or it is not valid UTF-8), with every character
/// other than ASCII alphanumerics, `-` and `_` replaced by `_`. It
/// keeps the directory skim-readable when poking around
/// `$XDG_DATA_HOME` by hand. The suffix is the FNV-1a hash of the full
/// (lossily decoded) path, which disambiguates checkouts that share a
/// final path component.
///
/// FNV-1a is used rather than `std::collections::hash::DefaultHasher`
/// because the latter (SipHash) reseeds per process and would give a
/// different id on every run.
pub fn project_id_for(cwd: &std::path::Path) -> String {
    let basename = match cwd.file_name().and_then(|name| name.to_str()) {
        Some(name) => name,
        None => "unknown",
    };

    let mut id = String::new();
    for ch in basename.chars() {
        let keep = ch.is_ascii_alphanumeric() || matches!(ch, '-' | '_');
        id.push(if keep { ch } else { '_' });
    }

    let digest = fnv1a_32(cwd.to_string_lossy().as_bytes());
    id.push_str(&format!("-{digest:08x}"));
    id
}
/// 32-bit FNV-1a hash of `bytes`.
///
/// Deterministic across runs and needs no third-party crate. Used
/// only to derive project ids — it is not a cryptographic hash.
fn fnv1a_32(bytes: &[u8]) -> u32 {
    const OFFSET_BASIS: u32 = 0x811c_9dc5;
    const PRIME: u32 = 0x0100_0193;

    // FNV-1a: xor the byte in first, then multiply by the prime.
    bytes.iter().fold(OFFSET_BASIS, |acc, &byte| {
        (acc ^ u32::from(byte)).wrapping_mul(PRIME)
    })
}
/// Format seconds-since-epoch as an ISO 8601 / RFC 3339 string
/// (`YYYY-MM-DDTHH:MM:SSZ`) for `SessionInfo.updated_at`. Returns
/// `None` for values outside the representable range, in which
/// case the caller should omit the field.
///
/// "Outside the representable range" includes values above
/// `i64::MAX`: they are rejected up front instead of being cast,
/// because `secs as i64` would wrap them to a negative timestamp and
/// produce a bogus pre-1970 date.
pub fn unix_to_iso8601(secs: u64) -> Option<String> {
    use chrono::TimeZone;
    let secs = i64::try_from(secs).ok()?;
    let dt = chrono::Utc.timestamp_opt(secs, 0).single()?;
    Some(dt.to_rfc3339_opts(chrono::SecondsFormat::Secs, true))
}
/// Turn a session id into a safe file-name stem.
///
/// ASCII letters, digits, `-` and `_` pass through unchanged; every
/// other character (path separators, dots, non-ASCII, …) becomes `_`,
/// so a mischievous or merely unconventional id cannot escape the
/// sessions directory.
fn sanitize_id(id: &str) -> String {
    let mut safe = String::new();
    for ch in id.chars() {
        let keep = matches!(ch, 'a'..='z' | 'A'..='Z' | '0'..='9' | '-' | '_');
        safe.push(if keep { ch } else { '_' });
    }
    safe
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::{MessageContent, Role};

    /// Unique scratch dir per call. We use this dir directly with the
    /// `*_to_dir` / `*_from_dir` functions so the tests never mutate
    /// `$XDG_DATA_HOME` — that env var would race across the parallel
    /// test harness.
    ///
    /// The name combines pid, sub-second nanos and a process-wide
    /// counter. The counter guarantees distinct names even when two
    /// parallel tests read the same clock value on a coarse timer.
    fn unique_dir() -> PathBuf {
        use std::sync::atomic::{AtomicU64, Ordering};
        static NEXT: AtomicU64 = AtomicU64::new(0);

        let base = std::env::var("CARGO_TARGET_TMPDIR")
            .ok()
            .map(PathBuf::from)
            .unwrap_or_else(std::env::temp_dir);
        let pid = std::process::id();
        let nanos = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .map(|d| d.subsec_nanos())
            .unwrap_or(0);
        let seq = NEXT.fetch_add(1, Ordering::Relaxed);
        let dir = base.join(format!("helexa-acp-store-test-{pid}-{nanos}-{seq}"));
        std::fs::create_dir_all(&dir).expect("create test dir");
        dir
    }

    /// Two-turn sample session with fixed timestamps.
    fn sample(id: &str) -> PersistedSession {
        PersistedSession {
            session_id: id.into(),
            cwd: PathBuf::from("/home/me/proj"),
            model_id: "Qwen/Qwen3.6-27B".into(),
            mode_id: "default".into(),
            history: vec![
                Message {
                    role: Role::User,
                    content: MessageContent::Text {
                        text: "hello".into(),
                    },
                },
                Message {
                    role: Role::Assistant,
                    content: MessageContent::Text { text: "hi".into() },
                },
            ],
            created_at: 1_700_000_000,
            updated_at: 1_700_000_001,
        }
    }

    #[test]
    fn round_trip_save_then_load() {
        let dir = unique_dir();
        save_to_dir(&dir, &sample("hxa-1")).expect("save");
        let loaded = load_from_dir(&dir, &SessionId::new("hxa-1")).expect("load");
        assert_eq!(loaded.session_id, "hxa-1");
        assert_eq!(loaded.cwd, PathBuf::from("/home/me/proj"));
        assert_eq!(loaded.history.len(), 2);
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn load_missing_session_errors_with_not_found_message() {
        let dir = unique_dir();
        let err = load_from_dir(&dir, &SessionId::new("nope")).unwrap_err();
        let msg = format!("{err}");
        assert!(
            msg.contains("no persisted session"),
            "want NotFound, got: {msg}"
        );
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn save_overwrites_existing_atomically() {
        let dir = unique_dir();
        save_to_dir(&dir, &sample("hxa-1")).expect("save");
        let mut updated = sample("hxa-1");
        updated.history.push(Message {
            role: Role::User,
            content: MessageContent::Text {
                text: "third turn".into(),
            },
        });
        updated.updated_at = 1_700_000_500;
        save_to_dir(&dir, &updated).expect("re-save");
        let loaded = load_from_dir(&dir, &SessionId::new("hxa-1")).expect("load");
        assert_eq!(loaded.history.len(), 3);
        assert_eq!(loaded.updated_at, 1_700_000_500);
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn save_then_load_preserves_tool_calls_and_results() {
        use crate::provider::ToolCall;
        let dir = unique_dir();
        let mut session = sample("hxa-2");
        session.history.push(Message {
            role: Role::Assistant,
            content: MessageContent::ToolCalls {
                text: Some("calling".into()),
                calls: vec![ToolCall {
                    id: "call_0".into(),
                    name: "read_file".into(),
                    arguments: r#"{"path":"/etc/hostname"}"#.into(),
                }],
            },
        });
        session.history.push(Message {
            role: Role::Tool,
            content: MessageContent::ToolResult {
                tool_call_id: "call_0".into(),
                content: "host".into(),
            },
        });
        save_to_dir(&dir, &session).expect("save");
        let loaded = load_from_dir(&dir, &SessionId::new("hxa-2")).expect("load");
        assert_eq!(loaded.history.len(), 4);
        match &loaded.history[2].content {
            MessageContent::ToolCalls { calls, .. } => {
                assert_eq!(calls[0].name, "read_file");
            }
            other => panic!("expected ToolCalls, got {other:?}"),
        }
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn list_filters_by_cwd_and_sorts_recent_first() {
        let dir = unique_dir();
        let mut a = sample("a");
        a.cwd = PathBuf::from("/home/me/proj-x");
        a.updated_at = 1_700_000_010;
        let mut b = sample("b");
        b.cwd = PathBuf::from("/home/me/proj-x");
        b.updated_at = 1_700_000_020;
        let mut c = sample("c");
        c.cwd = PathBuf::from("/home/me/elsewhere");
        c.updated_at = 1_700_000_030;
        save_to_dir(&dir, &a).unwrap();
        save_to_dir(&dir, &b).unwrap();
        save_to_dir(&dir, &c).unwrap();
        let proj_x = PathBuf::from("/home/me/proj-x");
        let list = list_in_dir(&dir, Some(&proj_x)).unwrap();
        let ids: Vec<&str> = list.iter().map(|s| s.session_id.as_str()).collect();
        // Filtered to proj-x; b before a because b is more recent.
        assert_eq!(ids, vec!["b", "a"]);
        let all = list_in_dir(&dir, None).unwrap();
        assert_eq!(all.len(), 3);
        // Global list still sorted recent-first across all cwds.
        assert_eq!(all[0].session_id, "c");
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn list_returns_empty_for_missing_dir() {
        // Keep a handle on the scratch root so it can be removed
        // afterwards; only the child path is expected to be absent.
        let root = unique_dir();
        let dir = root.join("does-not-exist");
        let list = list_in_dir(&dir, None).unwrap();
        assert!(list.is_empty());
        let _ = std::fs::remove_dir_all(&root);
    }

    #[test]
    fn list_skips_unparseable_files() {
        let dir = unique_dir();
        save_to_dir(&dir, &sample("good")).unwrap();
        std::fs::write(dir.join("garbage.json"), b"{not valid json").unwrap();
        let list = list_in_dir(&dir, None).unwrap();
        // Garbage skipped; good survives.
        assert_eq!(list.len(), 1);
        assert_eq!(list[0].session_id, "good");
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn iso8601_formats_unix_seconds() {
        // 2024-01-01T00:00:00Z is 1704067200 unix seconds.
        assert_eq!(
            unix_to_iso8601(1_704_067_200),
            Some("2024-01-01T00:00:00Z".into())
        );
        assert_eq!(unix_to_iso8601(0), Some("1970-01-01T00:00:00Z".into()));
    }

    #[test]
    fn sanitize_id_rejects_path_traversal() {
        // `../../etc/passwd` has six disallowed characters before
        // `etc` (`.`, `.`, `/`, `.`, `.`, `/`) and one `/` between
        // `etc` and `passwd`. Each disallowed character becomes `_`.
        assert_eq!(sanitize_id("../../etc/passwd"), "______etc_passwd");
        assert_eq!(sanitize_id("ok-name_42"), "ok-name_42");
    }
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,300 +0,0 @@
//! Tool schemas sent to the upstream model on every completion.
//!
//! These are the OpenAI-function-style declarations the LLM sees in
//! `CompletionRequest.tools`; the runtime dispatch happens in
//! [`crate::tool_runner`]. Keeping declarations and execution in
//! separate modules makes it easy to add a tool without touching the
//! runner, and vice versa.
//!
//! Stage 3 ships five: filesystem read / write / edit, directory
//! listing, and `bash`. Image generation, web fetch, MCP-derived
//! tools, etc. are out of scope here.
use serde_json::json;
use crate::provider::ToolSpec;
/// Name of the tool that reads a text file.
pub const READ_FILE: &str = "read_file";
/// Name of the tool that writes (replaces) a file's contents.
pub const WRITE_FILE: &str = "write_file";
/// Name of the tool that replaces one exact substring in a file.
pub const EDIT_FILE: &str = "edit_file";
/// Name of the tool that lists a directory's entries.
pub const LIST_DIR: &str = "list_dir";
/// Name of the tool that runs a shell command via `sh -c`.
pub const BASH: &str = "bash";
/// Build the static tool list passed to the model on every prompt.
/// Cheap — the JSON Schema fragments are constructed each call but
/// the bodies are small constants. If this ever shows up in a
/// profile we can `OnceLock` the Vec.
pub fn all_tools() -> Vec<ToolSpec> {
vec![
ToolSpec {
name: READ_FILE.to_string(),
description: "Read the contents of a text file. Returns the file's text.".to_string(),
parameters: json!({
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Absolute path to the file."
},
"line": {
"type": "integer",
"description": "Optional 1-based line number to start reading from.",
"minimum": 1
},
"limit": {
"type": "integer",
"description": "Optional maximum number of lines to read.",
"minimum": 1
}
},
"required": ["path"],
"additionalProperties": false
}),
},
ToolSpec {
name: WRITE_FILE.to_string(),
description: "Write text content to a file, replacing any existing contents. \
Creates the file (and parent directories) if needed."
.to_string(),
parameters: json!({
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Absolute path to the file."
},
"content": {
"type": "string",
"description": "Full new contents of the file."
}
},
"required": ["path", "content"],
"additionalProperties": false
}),
},
ToolSpec {
name: EDIT_FILE.to_string(),
description: "Replace one exact substring in a file with another. \
Fails if `old_text` does not appear in the file, or appears more than once. \
Use multiple edit_file calls for multiple edits."
.to_string(),
parameters: json!({
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Absolute path to the file."
},
"old_text": {
"type": "string",
"description": "Exact text fragment to replace. Must be unique within the file."
},
"new_text": {
"type": "string",
"description": "Replacement text."
}
},
"required": ["path", "old_text", "new_text"],
"additionalProperties": false
}),
},
ToolSpec {
name: LIST_DIR.to_string(),
description:
"List the entries of a directory. Returns names and a (f|d|l) kind per entry."
.to_string(),
parameters: json!({
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Absolute path to the directory."
}
},
"required": ["path"],
"additionalProperties": false
}),
},
ToolSpec {
name: BASH.to_string(),
description: "Run a shell command via `sh -c`. \
Returns combined stdout+stderr and the exit status. \
The command runs in the session's working directory unless `cwd` is given."
.to_string(),
parameters: json!({
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "Shell command line, evaluated by `sh -c`."
},
"cwd": {
"type": "string",
"description": "Optional absolute path to run the command from."
}
},
"required": ["command"],
"additionalProperties": false
}),
},
]
}
/// Guess the intended tool from the key set of an `arguments` object.
///
/// Used by the agent when the model emits a `<tool_call>` whose JSON
/// carries the right arguments but a missing or invalid top-level
/// `name` — a recurring Qwen3.6-27B failure mode.
///
/// Returns `Some(name)` only when the keys match exactly one tool in
/// the catalogue. Ambiguous shapes (`{path}` alone could be either
/// [`READ_FILE`] or [`LIST_DIR`]) and anything unrecognised return
/// `None`, so the caller surfaces a Failed-card and lets the model
/// retry instead of guessing wrong. Non-object values also return
/// `None`.
///
/// Inference table (key set → tool):
///
/// | Keys                                  | Tool         |
/// |---------------------------------------|--------------|
/// | `{command}` or `{command, cwd}`       | `bash`       |
/// | `{path, content}`                     | `write_file` |
/// | `{path, old_text, new_text}`          | `edit_file`  |
/// | `{path}` / `{path, line}` / `{path, line, limit}` | *ambiguous* — None |
/// | (anything else)                       | None         |
pub fn infer_tool_name(arguments: &serde_json::Value) -> Option<&'static str> {
    let obj = arguments.as_object()?;

    // Object keys are unique, so the sorted key list identifies the key
    // set exactly and can be matched as a slice pattern.
    let mut keys: Vec<&str> = obj.keys().map(String::as_str).collect();
    keys.sort_unstable();

    match keys.as_slice() {
        // `command` is unique to bash; only the optional `cwd` may join it.
        ["command"] | ["command", "cwd"] => Some(BASH),
        // `content` is unique to write_file.
        ["content", "path"] => Some(WRITE_FILE),
        // `old_text` + `new_text` are unique to edit_file.
        ["new_text", "old_text", "path"] => Some(EDIT_FILE),
        // `{path}` with or without `line` / `limit` overlaps read_file
        // and list_dir; every other shape is unknown. Refuse both.
        _ => None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn all_tools_has_five_named_entries() {
        let expected = vec![READ_FILE, WRITE_FILE, EDIT_FILE, LIST_DIR, BASH];
        let catalogue = all_tools();
        let actual: Vec<&str> = catalogue.iter().map(|tool| tool.name.as_str()).collect();
        assert_eq!(actual, expected);
    }

    #[test]
    fn infer_bash_from_command_only() {
        let inferred = infer_tool_name(&serde_json::json!({"command": "ls /tmp"}));
        assert_eq!(inferred, Some(BASH));
    }

    #[test]
    fn infer_bash_from_command_and_cwd() {
        let inferred = infer_tool_name(&serde_json::json!({"command": "ls", "cwd": "/tmp"}));
        assert_eq!(inferred, Some(BASH));
    }

    #[test]
    fn infer_bash_from_mkdir_like_real_failure() {
        // Taken verbatim from the agent failure that motivated this
        // helper (helexa-acp.log @ 10:03:11).
        let payload = serde_json::json!({
            "command": "mkdir -p /home/grenade/git/beat/beat/doc/plan/{01-discovery,02-segmentation,03-description,04-summary,05-output}"
        });
        assert_eq!(infer_tool_name(&payload), Some(BASH));
    }

    #[test]
    fn infer_write_file() {
        let inferred = infer_tool_name(&serde_json::json!({"path": "/tmp/x", "content": "hi"}));
        assert_eq!(inferred, Some(WRITE_FILE));
    }

    #[test]
    fn infer_edit_file() {
        let payload = serde_json::json!({
            "path": "/tmp/x", "old_text": "a", "new_text": "b"
        });
        assert_eq!(infer_tool_name(&payload), Some(EDIT_FILE));
    }

    #[test]
    fn refuse_ambiguous_path_only() {
        let inferred = infer_tool_name(&serde_json::json!({"path": "/tmp/x"}));
        assert_eq!(inferred, None);
    }

    #[test]
    fn refuse_ambiguous_path_with_optionals() {
        // read_file accepts these optionals and list_dir does not, but
        // Qwen would not reliably emit them either, so their presence
        // cannot disambiguate. Refuse.
        let payload = serde_json::json!({"path": "/tmp/x", "line": 1, "limit": 50});
        assert_eq!(infer_tool_name(&payload), None);
    }

    #[test]
    fn refuse_command_with_extra_unknown_keys() {
        // Defence in depth: an unrecognised key next to `command` means
        // the intended tool is unknown; refuse rather than guess.
        let payload = serde_json::json!({"command": "ls", "extra": "?"});
        assert_eq!(infer_tool_name(&payload), None);
    }

    #[test]
    fn refuse_empty_args() {
        let inferred = infer_tool_name(&serde_json::json!({}));
        assert_eq!(inferred, None);
    }

    #[test]
    fn refuse_non_object_args() {
        let inferred = infer_tool_name(&serde_json::json!("not an object"));
        assert_eq!(inferred, None);
    }

    #[test]
    fn every_tool_has_an_object_parameter_schema() {
        for tool in all_tools() {
            let schema = &tool.parameters;
            let declared_type = schema.get("type").and_then(|v| v.as_str());
            assert_eq!(
                declared_type,
                Some("object"),
                "tool {} parameters.type must be \"object\"",
                tool.name
            );
            let has_properties = schema.get("properties").is_some();
            assert!(has_properties, "tool {} missing properties", tool.name);
            let has_required = schema.get("required").is_some();
            assert!(has_required, "tool {} missing required list", tool.name);
        }
    }
}

View File

@@ -1,38 +0,0 @@
# Package metadata. Version, edition, license and repository are
# inherited from the workspace.
[package]
name = "helexa-bench"
version.workspace = true
edition.workspace = true
license.workspace = true
repository.workspace = true
# Single binary target built from src/main.rs.
[[bin]]
name = "helexa-bench"
path = "src/main.rs"
# Runtime dependencies. Everything except rusqlite is pinned at the
# workspace level.
[dependencies]
cortex-core = { workspace = true }
tokio = { workspace = true }
reqwest = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
figment = { workspace = true }
anyhow = { workspace = true }
async-trait = { workspace = true }
clap = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
chrono = { workspace = true }
futures = { workspace = true }
tokio-stream = { workspace = true }
eventsource-stream = { workspace = true }
# SQLite system-of-record. `bundled` compiles SQLite from source so the
# binary has no libsqlite3 runtime dependency — matches the project's
# single-static-binary packaging.
rusqlite = { version = "0.32", features = ["bundled"] }
# Test-only dependencies.
[dev-dependencies]
axum = { workspace = true }
# Jail (isolated cwd + env) for config tests.
figment = { workspace = true, features = ["test"] }

View File

@@ -1,159 +0,0 @@
//! Outbound calls to a benchmark target: build identity, host discovery,
//! and warm-model enumeration. Neuron targets use the native neuron API;
//! `openai` targets use the OpenAI-compatible surface (preliminary).
use crate::config::{TargetConfig, TargetKind};
use anyhow::{Context, Result};
use cortex_core::build_info::BuildInfo;
use cortex_core::discovery::DiscoveryResponse;
use cortex_core::harness::ModelInfo;
use cortex_core::openai::ModelsResponse;
use std::time::Duration;
/// How long to wait on the cheap metadata polls (version/discovery/models).
/// Applied per request via `.timeout(META_TIMEOUT)` in the
/// `TargetClient` methods.
const META_TIMEOUT: Duration = Duration::from_secs(10);
/// HTTP client for outbound calls to a benchmark target.
pub struct TargetClient {
/// Underlying `reqwest` client, built in `TargetClient::new` with the
/// caller-supplied request timeout.
http: reqwest::Client,
}
impl TargetClient {
pub fn new(request_timeout: Duration) -> Result<Self> {
let http = reqwest::Client::builder()
.timeout(request_timeout)
.build()
.context("building HTTP client")?;
Ok(TargetClient { http })
}
pub fn http(&self) -> &reqwest::Client {
&self.http
}
/// Chat-completions URL for the target.
pub fn chat_url(&self, target: &TargetConfig) -> String {
let base = target.endpoint.trim_end_matches('/');
match target.kind {
// neuron exposes OpenAI routes under /v1.
TargetKind::Neuron => format!("{base}/v1/chat/completions"),
// openai endpoint is the /v1 base already (bench.py convention).
TargetKind::Openai => format!("{base}/chat/completions"),
}
}
/// Build identity. Neuron: `GET /version`. Openai: a synthetic
/// placeholder keyed by `"external"` so the version-aware skip logic
/// treats it as one stable build (comparison runs are manual anyway).
pub async fn fetch_version(&self, target: &TargetConfig) -> Result<BuildInfo> {
match target.kind {
TargetKind::Neuron => {
let base = target.endpoint.trim_end_matches('/');
let info = self
.http
.get(format!("{base}/version"))
.timeout(META_TIMEOUT)
.send()
.await
.context("GET /version")?
.error_for_status()
.context("GET /version status")?
.json::<BuildInfo>()
.await
.context("decoding /version")?;
Ok(info)
}
TargetKind::Openai => {
let mut info = BuildInfo::unknown();
info.git_sha = "external".to_string();
Ok(info)
}
}
}
/// Host discovery (neuron only).
pub async fn fetch_discovery(
&self,
target: &TargetConfig,
) -> Result<Option<DiscoveryResponse>> {
if target.kind != TargetKind::Neuron {
return Ok(None);
}
let base = target.endpoint.trim_end_matches('/');
let disco = self
.http
.get(format!("{base}/discovery"))
.timeout(META_TIMEOUT)
.send()
.await
.context("GET /discovery")?
.error_for_status()
.context("GET /discovery status")?
.json::<DiscoveryResponse>()
.await
.context("decoding /discovery")?;
Ok(Some(disco))
}
/// Warm models — those ready to serve without a cold load.
///
/// Neuron: `GET /models` filtered to `status == "loaded"` (skips
/// `recovering`/`poisoned`). Openai: `GET /models`, honouring the
/// helexa `loaded` extension when present, else treating all listed
/// models as warm.
pub async fn warm_models(&self, target: &TargetConfig) -> Result<Vec<ModelInfo>> {
let base = target.endpoint.trim_end_matches('/');
match target.kind {
TargetKind::Neuron => {
let models = self
.http
.get(format!("{base}/models"))
.timeout(META_TIMEOUT)
.send()
.await
.context("GET /models")?
.error_for_status()
.context("GET /models status")?
.json::<Vec<ModelInfo>>()
.await
.context("decoding /models")?;
Ok(models
.into_iter()
.filter(|m| m.status == "loaded")
.collect())
}
TargetKind::Openai => {
let resp = self
.http
.get(format!("{base}/models"))
.timeout(META_TIMEOUT)
.send()
.await
.context("GET /models")?
.error_for_status()
.context("GET /models status")?
.json::<ModelsResponse>()
.await
.context("decoding /models")?;
Ok(resp
.data
.into_iter()
.filter(|m| {
// honour the helexa `loaded` extension if present
m.extra
.get("loaded")
.and_then(|v| v.as_bool())
.unwrap_or(true)
})
.map(|m| ModelInfo {
id: m.id,
harness: "openai".to_string(),
status: "loaded".to_string(),
devices: Vec::new(),
vram_used_mb: None,
capabilities: Vec::new(),
})
.collect())
}
}
}
}

View File

@@ -1,210 +0,0 @@
//! Bench configuration: loaded from `helexa-bench.toml` with figment,
//! `BENCH_`-prefixed env overrides (mirrors `NeuronConfig::load`).
use figment::{
Figment,
providers::{Env, Format, Toml},
};
use serde::{Deserialize, Serialize};
use std::path::Path;
use std::time::Duration;
/// Top-level bench config.
///
/// Every section is optional in the file: a missing `[bench]` or
/// `[scenarios]` table falls back to that type's `Default`, and a missing
/// `targets` list deserialises as empty.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchConfig {
    /// Loop/timing knobs (the `[bench]` table).
    #[serde(default)]
    pub bench: BenchSettings,
    /// Scenario selection and shared parameters (the `[scenarios]` table).
    #[serde(default)]
    pub scenarios: ScenarioConfig,
    /// Endpoints to benchmark. At least one is required for `run`/`once`.
    #[serde(default)]
    pub targets: Vec<TargetConfig>,
}
/// Loop/timing knobs.
///
/// All durations are stored as whole seconds; each field falls back to its
/// own `default_*` function when absent from the config.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchSettings {
    /// Pause between full sweeps of all targets.
    #[serde(default = "default_sweep_interval")]
    pub sweep_interval_secs: u64,
    /// Target number of measured samples to record for a given
    /// (target, build SHA, model, scenario). Once met, later sweeps skip
    /// that cell — so a fully-sampled build costs only cheap version
    /// polls until a new SHA ships.
    #[serde(default = "default_samples")]
    pub samples_per_version: u32,
    /// Pause between successive measured iterations against one model.
    #[serde(default = "default_iter_pause")]
    pub iteration_pause_secs: u64,
    /// Per-request timeout (cold lazy-loads can be slow; generous like
    /// bench.py's 600s default).
    #[serde(default = "default_timeout")]
    pub request_timeout_secs: u64,
    /// SQLite system-of-record path.
    #[serde(default = "default_db_path")]
    pub db_path: String,
}
/// Defaults mirror the per-field serde defaults, so an absent `[bench]`
/// table and an empty one produce the same settings.
impl Default for BenchSettings {
    fn default() -> Self {
        let db_path = default_db_path();
        Self {
            db_path,
            request_timeout_secs: default_timeout(),
            iteration_pause_secs: default_iter_pause(),
            samples_per_version: default_samples(),
            sweep_interval_secs: default_sweep_interval(),
        }
    }
}
/// `Duration` views over the whole-second config fields.
impl BenchSettings {
    /// Pause between successive measured iterations against one model.
    pub fn iteration_pause(&self) -> Duration {
        let secs = self.iteration_pause_secs;
        Duration::new(secs, 0)
    }

    /// Per-request timeout.
    pub fn request_timeout(&self) -> Duration {
        let secs = self.request_timeout_secs;
        Duration::new(secs, 0)
    }

    /// Pause between full sweeps of all targets.
    pub fn sweep_interval(&self) -> Duration {
        let secs = self.sweep_interval_secs;
        Duration::new(secs, 0)
    }
}
/// Which scenarios to run and their shared parameters.
///
/// Both fields fall back to their `default_*` function when omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScenarioConfig {
    /// Approximate prompt sizes (in tokens) — one chat-latency scenario
    /// is generated per size, e.g. `chat:128`, `chat:4096`. This is the
    /// per-cell dimension that the version-aware skip logic keys on.
    #[serde(default = "default_prompt_sizes")]
    pub prompt_sizes: Vec<u32>,
    /// Max generated tokens per request.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u64,
}
/// Defaults mirror the per-field serde defaults, so an absent
/// `[scenarios]` table and an empty one behave the same.
impl Default for ScenarioConfig {
    fn default() -> Self {
        let max_tokens = default_max_tokens();
        let prompt_sizes = default_prompt_sizes();
        Self {
            prompt_sizes,
            max_tokens,
        }
    }
}
/// One endpoint to benchmark.
///
/// `name` and `endpoint` are required; `kind` defaults to
/// [`TargetKind::Neuron`] and `label` to `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TargetConfig {
    /// Stable label used as the engine column and in the DB.
    pub name: String,
    /// Which protocol/metadata surface the target exposes.
    #[serde(default)]
    pub kind: TargetKind,
    /// Base URL. For `neuron`: the daemon root (e.g.
    /// `http://beast.internal:13131`). For `openai`: the OpenAI `/v1`
    /// base (e.g. `http://host:8080/v1`).
    pub endpoint: String,
    /// Optional display label override for reports (defaults to `name`).
    #[serde(default)]
    pub label: Option<String>,
}
impl TargetConfig {
    /// Human-facing label: the explicit `label` override when one is set,
    /// otherwise the target's `name`.
    pub fn display_label(&self) -> &str {
        match &self.label {
            Some(custom) => custom.as_str(),
            None => self.name.as_str(),
        }
    }
}
/// The two target surfaces. `neuron` gets rich build metadata and warm
/// model discovery via the native neuron API; `openai` is the seam for
/// later comparison against mistral.rs / llama.cpp / vLLM (phase 1
/// implements `neuron` fully; `openai` is preliminary plumbing).
///
/// Serialised in snake_case, i.e. `kind = "neuron"` / `kind = "openai"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum TargetKind {
    /// Native neuron daemon; the default when `kind` is omitted.
    #[default]
    Neuron,
    /// Generic OpenAI-compatible server.
    Openai,
}
impl BenchConfig {
    /// Load the config from the TOML file at `path`, then layer
    /// `BENCH_`-prefixed environment variables on top (`__` separates
    /// nested keys, e.g. `BENCH_BENCH__SAMPLES_PER_VERSION`).
    ///
    /// # Errors
    /// Returns the boxed figment error when extraction fails.
    pub fn load(path: impl AsRef<Path>) -> Result<Self, Box<figment::Error>> {
        let file_layer = Toml::file(path);
        let env_layer = Env::prefixed("BENCH_").split("__");
        // Later merges win, so the environment overrides the file.
        let layered = Figment::new().merge(file_layer).merge(env_layer);
        match layered.extract() {
            Ok(cfg) => Ok(cfg),
            Err(err) => Err(Box::new(err)),
        }
    }
}
/// Default pause between sweeps: 30 minutes, in seconds.
fn default_sweep_interval() -> u64 {
    30 * 60
}

/// Default number of measured samples per (target, SHA, model, scenario).
fn default_samples() -> u32 {
    5
}

/// Default pause between measured iterations, in seconds.
fn default_iter_pause() -> u64 {
    2
}

/// Default per-request timeout: 10 minutes, in seconds.
fn default_timeout() -> u64 {
    10 * 60
}

/// Default location of the SQLite results database.
fn default_db_path() -> String {
    String::from("/var/lib/helexa-bench/bench.sqlite")
}

/// Default approximate prompt sizes (tokens), one scenario per entry.
fn default_prompt_sizes() -> Vec<u32> {
    let sizes: [u32; 2] = [128, 4096];
    sizes.to_vec()
}

/// Default cap on generated tokens per request.
fn default_max_tokens() -> u64 {
    256
}
#[cfg(test)]
// Jail's closure must return figment::Result; the large-Err type is
// figment's, not ours, so suppress the lint here.
#[allow(clippy::result_large_err)]
mod tests {
    use super::*;
    use figment::Jail;

    // A file with only one target (name + endpoint) must load, with the
    // target kind and every `[bench]`/`[scenarios]` value coming from the
    // defaults.
    #[test]
    fn loads_minimal_with_defaults() {
        Jail::expect_with(|jail| {
            jail.create_file(
                "helexa-bench.toml",
                r#"
[[targets]]
name = "beast"
endpoint = "http://beast.internal:13131"
"#,
            )?;
            let cfg = BenchConfig::load("helexa-bench.toml").unwrap();
            assert_eq!(cfg.targets.len(), 1);
            assert_eq!(cfg.targets[0].kind, TargetKind::Neuron);
            assert_eq!(cfg.bench.samples_per_version, 5);
            assert_eq!(cfg.scenarios.prompt_sizes, vec![128, 4096]);
            Ok(())
        });
    }

    // A `BENCH_`-prefixed env var must win over the value in the file
    // (file says 3, environment says 9), and an explicit `kind` is honoured.
    #[test]
    fn env_overrides_apply() {
        Jail::expect_with(|jail| {
            jail.create_file(
                "helexa-bench.toml",
                r#"
[bench]
samples_per_version = 3
[[targets]]
name = "benjy"
kind = "openai"
endpoint = "http://benjy:8080/v1"
"#,
            )?;
            jail.set_env("BENCH_BENCH__SAMPLES_PER_VERSION", "9");
            let cfg = BenchConfig::load("helexa-bench.toml").unwrap();
            assert_eq!(cfg.bench.samples_per_version, 9);
            assert_eq!(cfg.targets[0].kind, TargetKind::Openai);
            Ok(())
        });
    }
}

View File

@@ -1,12 +0,0 @@
//! helexa-bench — a continuous, version-aware benchmark harness for the
//! neuron fleet. It hits each neuron directly, exercises an extensible
//! scenario suite against every warm model, and records each run with
//! full build/version provenance into SQLite so improvements can be
//! tracked automatically across neuron implementation updates.
pub mod client;
pub mod config;
pub mod report;
pub mod scenario;
pub mod store;
pub mod sweep;

View File

@@ -1,126 +0,0 @@
//! helexa-bench CLI.
//!
//! - `run` — continuous daemon (systemd default): sweep, sleep, repeat.
//! - `once` — a single sweep, then exit (manual / CI).
//! - `report` — render the SQLite store as a results table.
//!
//! Runs on a single-threaded runtime: the workload is batch-1 sequential
//! (one request at a time, the regime we measure), and it lets the
//! SQLite connection live across awaits without `Sync` gymnastics.
use anyhow::{Context, Result};
use clap::{Parser, Subcommand};
use helexa_bench::config::BenchConfig;
use helexa_bench::report;
use helexa_bench::store::Store;
use helexa_bench::sweep::Sweeper;
use tracing_subscriber::EnvFilter;
// Top-level command-line arguments: just the chosen subcommand.
// NOTE: plain `//` comments are used deliberately here — clap's derive
// turns `///` doc comments into user-visible help text.
#[derive(Parser)]
#[command(name = "helexa-bench")]
#[command(about = "Continuous version-aware benchmark harness for the neuron fleet")]
#[command(version)]
struct Cli {
    // The subcommand to execute (`run`, `once` or `report`).
    #[command(subcommand)]
    command: Command,
}
// The subcommands. Existing `///` comments double as clap help text, so
// extra maintainer notes below use `//` to leave the CLI output unchanged.
#[derive(Subcommand)]
enum Command {
    /// Run sweeps continuously, pausing `sweep_interval_secs` between them.
    Run {
        // Path to the TOML config file.
        #[arg(short, long, default_value = "helexa-bench.toml")]
        config: String,
    },
    /// Run a single sweep over all targets, then exit.
    Once {
        // Path to the TOML config file.
        #[arg(short, long, default_value = "helexa-bench.toml")]
        config: String,
    },
    /// Render recorded results. Uses `--db` if given, else the db_path
    /// from `--config`.
    Report {
        // Path to the TOML config file; only read when `--db` is absent.
        #[arg(short, long, default_value = "helexa-bench.toml")]
        config: String,
        /// Override the SQLite path (skips reading the config file).
        #[arg(long)]
        db: Option<String>,
        /// Output format.
        #[arg(long, default_value = "md")]
        format: Format,
    },
}
// Output formats accepted by `report --format`.
#[derive(Clone, Copy, clap::ValueEnum)]
enum Format {
    // Markdown table (rendered by `report::render_markdown`).
    Md,
    // Pretty-printed JSON array (rendered by `report::render_json`).
    Json,
}
/// Process entry point: set up logging, parse the CLI, then drive the
/// async `run` on a single-threaded tokio runtime.
fn main() -> Result<()> {
    // Use the filter from the environment when it parses; otherwise `info`.
    let filter = match EnvFilter::try_from_default_env() {
        Ok(from_env) => from_env,
        Err(_) => EnvFilter::new("info"),
    };
    tracing_subscriber::fmt().with_env_filter(filter).init();

    let cli = Cli::parse();

    let mut builder = tokio::runtime::Builder::new_current_thread();
    let runtime = builder
        .enable_all()
        .build()
        .context("building tokio runtime")?;
    runtime.block_on(run(cli))
}
/// Dispatch the parsed subcommand.
///
/// `run` and `once` both load the config, insist on at least one target
/// and build a `Sweeper`; `report` only needs the SQLite path, taken from
/// `--db` when given and from the config file otherwise.
async fn run(cli: Cli) -> Result<()> {
    match cli.command {
        Command::Report { config, db, format } => {
            let db_path = if let Some(explicit) = db {
                explicit
            } else {
                load_config(&config)?.bench.db_path
            };
            let store = Store::open(&db_path)?;
            let rows = store.report_rows()?;
            let rendered = match format {
                Format::Json => report::render_json(&rows)?,
                Format::Md => report::render_markdown(&rows),
            };
            println!("{rendered}");
            Ok(())
        }
        Command::Once { config } => {
            let cfg = load_config(&config)?;
            require_targets(&cfg)?;
            let summary = Sweeper::new(cfg)?.run_once().await?;
            tracing::info!(
                measured = summary.measured,
                skipped = summary.skipped,
                failed = summary.failed,
                unreachable = summary.targets_unreachable,
                "single sweep complete"
            );
            Ok(())
        }
        Command::Run { config } => {
            let cfg = load_config(&config)?;
            require_targets(&cfg)?;
            let sweeper = Sweeper::new(cfg)?;
            tracing::info!("helexa-bench started; entering continuous sweep loop");
            // Never returns; the `!` result coerces to this arm's type.
            sweeper.run_forever().await
        }
    }
}
/// Load the bench config from `path`, flattening figment's error into its
/// display text and tagging it with the path that failed.
fn load_config(path: &str) -> Result<BenchConfig> {
    match BenchConfig::load(path) {
        Ok(cfg) => Ok(cfg),
        Err(e) => {
            Err(anyhow::anyhow!("{e}")).with_context(|| format!("loading config {path}"))
        }
    }
}
/// Fail early when the config lists no targets; a sweep over nothing
/// would succeed while measuring nothing.
fn require_targets(cfg: &BenchConfig) -> Result<()> {
    anyhow::ensure!(
        !cfg.targets.is_empty(),
        "no targets configured — add at least one [[targets]] entry"
    );
    Ok(())
}

View File

@@ -1,106 +0,0 @@
//! Render the SQLite store as a results table — the automated
//! replacement for hand-editing `doc/benchmarks.md`. Columns match that
//! doc: engine, model, prompt tok, TTFT (s), decode tok/s, total (s),
//! plus the build SHA each cell was measured against.
use crate::store::ReportRow;
use anyhow::Result;
/// Render report rows as a GitHub-flavoured markdown table.
///
/// Always emits the header and alignment rows, then one line per entry of
/// `rows` in the order given. The prompt-token column shows the measured
/// count when one was recorded, else the approximate size prefixed by `~`.
pub fn render_markdown(rows: &[ReportRow]) -> String {
    const HEADER: &str =
        "| engine | model | prompt tok | TTFT (s) | decode tok/s | total (s) | build | n |\n";
    const ALIGNMENT: &str = "|---|---|---:|---:|---:|---:|---|---:|\n";

    let mut table = String::from(HEADER);
    table += ALIGNMENT;
    for row in rows {
        let prompt_tok = match row.prompt_tokens {
            Some(measured) => measured.to_string(),
            None => format!("~{}", row.prompt_size_approx),
        };
        let ttft = fmt_opt(row.ttft_s_median, 3);
        let decode = fmt_opt(row.decode_tps_median, 1);
        let total = fmt_opt(row.total_s_median, 3);
        let line = format!(
            "| {} | {} | {} | {} | {} | {} | `{}` | {} |\n",
            row.target_name, row.model_id, prompt_tok, ttft, decode, total, row.git_sha, row.samples,
        );
        table += &line;
    }
    table
}
/// Render report rows as a pretty-printed JSON array, one object per row.
///
/// Unlike the markdown table this includes the scenario id and both the
/// approximate and the measured prompt sizes.
///
/// # Errors
/// Propagates any serialisation error from `serde_json`.
pub fn render_json(rows: &[ReportRow]) -> Result<String> {
    let mut entries: Vec<serde_json::Value> = Vec::with_capacity(rows.len());
    for row in rows {
        entries.push(serde_json::json!({
            "engine": row.target_name,
            "model": row.model_id,
            "scenario": row.scenario_id,
            "prompt_size_approx": row.prompt_size_approx,
            "prompt_tokens": row.prompt_tokens,
            "ttft_s_median": row.ttft_s_median,
            "decode_tps_median": row.decode_tps_median,
            "total_s_median": row.total_s_median,
            "git_sha": row.git_sha,
            "samples": row.samples,
        }));
    }
    let text = serde_json::to_string_pretty(&entries)?;
    Ok(text)
}
/// Format an optional metric with `places` decimal places.
///
/// A missing value renders as an em dash so the table cell is visibly
/// "no data" rather than blank — the behaviour the
/// `missing_decode_renders_dash` test is named for.
fn fmt_opt(v: Option<f64>, places: usize) -> String {
    match v {
        Some(x) => format!("{x:.places$}"),
        None => "—".to_string(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // A fully-populated row must produce the table header plus a line
    // carrying the target, the back-ticked build SHA and the TTFT rounded
    // to three places.
    #[test]
    fn markdown_has_header_and_row() {
        let rows = vec![ReportRow {
            target_name: "beast".into(),
            model_id: "Qwen/Qwen3.6-27B".into(),
            scenario_id: "chat:128".into(),
            prompt_size_approx: 128,
            git_sha: "30d50d6".into(),
            prompt_tokens: Some(130),
            ttft_s_median: Some(0.123),
            decode_tps_median: Some(45.6),
            total_s_median: Some(1.234),
            samples: 5,
        }];
        let md = render_markdown(&rows);
        assert!(md.contains("| engine |"));
        assert!(md.contains("beast"));
        assert!(md.contains("`30d50d6`"));
        assert!(md.contains("0.123"));
    }

    // With no measured prompt tokens the approximate size is shown with a
    // `~` prefix, and a missing decode rate is meant to show a placeholder.
    #[test]
    fn missing_decode_renders_dash() {
        let rows = vec![ReportRow {
            target_name: "benjy".into(),
            model_id: "m".into(),
            scenario_id: "chat:128".into(),
            prompt_size_approx: 128,
            git_sha: "abc".into(),
            prompt_tokens: None,
            ttft_s_median: Some(0.1),
            decode_tps_median: None,
            total_s_median: Some(0.5),
            samples: 1,
        }];
        let md = render_markdown(&rows);
        assert!(md.contains("~128"));
        // NOTE(review): `contains("")` is true for every string, so this
        // assertion checks nothing. Presumably the dash literal the test
        // name refers to was lost — confirm and restore it.
        assert!(md.contains(""));
    }
}

View File

@@ -1,238 +0,0 @@
//! The extensible test suite.
//!
//! A [`Scenario`] puts one warm model through one shaped request and
//! reports operator-felt metrics (TTFT, decode tok/s, total). Phase 1
//! ships the chat-latency family ported faithfully from `script/bench.py`;
//! the trait is the seam for future families (vision, concurrency,
//! long-generation, cold-start) selected per model via [`Scenario::applies_to`].
use crate::config::ScenarioConfig;
use anyhow::{Context, Result, anyhow};
use async_trait::async_trait;
use cortex_core::harness::ModelInfo;
use cortex_core::openai::ChatCompletionChunk;
use eventsource_stream::Eventsource;
use futures::StreamExt;
use serde_json::json;
use std::time::{Duration, Instant};
/// Filler paragraph, repeated to synthesise a prompt of roughly the
/// requested token count (~4 chars/token heuristic — good enough for
/// bucketing; real token counts are read back from the usage object).
/// Mirrors `script/bench.py::FILLER`.
const FILLER: &str = "The quick brown fox jumps over the lazy dog while the band plays \
    a slow waltz in the background and somebody counts the beats. ";

/// Instruction appended after the filler. The trailing `/no_think` is the
/// Qwen3-family soft switch that stops thinking models spending the token
/// budget invisibly; it is harmless for non-thinking models.
const QUESTION: &str = "\n\nRetell the scene above as a vivid story of about 300 words. /no_think";

/// Build a synthetic prompt of approximately `approx_tokens` tokens.
/// Ported from `bench.py::build_prompt`.
///
/// The filler body is exactly `max(approx_tokens, 16) * 4` characters
/// long, followed by [`QUESTION`].
pub fn build_prompt(approx_tokens: u32) -> String {
    // Floor of 16 tokens so even a zero request yields a usable prompt.
    let body_chars = 4 * std::cmp::max(approx_tokens, 16) as usize;
    let mut prompt = String::with_capacity(body_chars + QUESTION.len());
    // FILLER is pure ASCII, so cycling its chars and taking `body_chars`
    // of them is the same as repeating the text and cutting at that byte.
    prompt.extend(FILLER.chars().cycle().take(body_chars));
    prompt.push_str(QUESTION);
    prompt
}
/// Per-request inputs shared by every scenario.
pub struct RunCtx<'a> {
    /// HTTP client used to send the request.
    pub client: &'a reqwest::Client,
    /// Fully-qualified chat-completions URL for the target.
    pub chat_url: String,
    /// Model id placed in the request's `model` field.
    pub model_id: String,
    /// Cap on generated tokens, sent as `max_tokens`.
    pub max_tokens: u64,
    /// Overall deadline a scenario applies to one measured request.
    pub timeout: Duration,
}
/// Operator-felt metrics for a single measured request.
///
/// All timings are wall-clock seconds measured client-side.
#[derive(Debug, Clone)]
pub struct ScenarioMetrics {
    /// Time to first content chunk (seconds).
    pub ttft_s: f64,
    /// Completion tokens / decode window. `None` when the window is too
    /// short to be honest (≤ 200 ms), matching bench.py.
    pub decode_tps: Option<f64>,
    /// Wall-clock for the whole request (seconds).
    pub total_s: f64,
    /// Prompt tokens from the final `usage` object, if the server sent one.
    pub prompt_tokens: Option<u64>,
    /// Completion tokens: from `usage` when present, else content-chunk count.
    pub completion_tokens: u64,
}
/// One benchmark shape that can be run against a model.
///
/// Implementations must be `Send + Sync`; the active set is held as
/// `Box<dyn Scenario>` (see `build_scenarios`).
#[async_trait]
pub trait Scenario: Send + Sync {
    /// Stable id, e.g. `chat:128`. Used as the version-aware skip key
    /// dimension and recorded against every run.
    fn id(&self) -> &str;
    /// Approximate prompt size in tokens (the cell dimension), recorded
    /// for reporting.
    fn prompt_size(&self) -> u32;
    /// Whether this scenario should run against the given model. Default
    /// runs against everything; vision/audio scenarios will gate on
    /// [`ModelInfo::capabilities`].
    fn applies_to(&self, _model: &ModelInfo) -> bool {
        true
    }
    /// Issue one shaped request and measure it.
    ///
    /// Returns an error when the request fails or cannot be measured.
    async fn run(&self, ctx: &RunCtx) -> Result<ScenarioMetrics>;
}
/// Build the active scenario set from config: one chat-latency scenario
/// per configured prompt size, in config order, with id `chat:<size>`.
pub fn build_scenarios(cfg: &ScenarioConfig) -> Vec<Box<dyn Scenario>> {
    let mut suite: Vec<Box<dyn Scenario>> = Vec::with_capacity(cfg.prompt_sizes.len());
    for &size in &cfg.prompt_sizes {
        let scenario = ChatLatencyScenario {
            id: format!("chat:{size}"),
            approx_prompt_tokens: size,
        };
        suite.push(Box::new(scenario));
    }
    suite
}
/// Streamed single-request chat-completions latency probe — the batch-1
/// regime bench.py measures.
pub struct ChatLatencyScenario {
    /// Stable scenario id, e.g. `chat:128`.
    id: String,
    /// Approximate prompt size (tokens) to synthesise.
    approx_prompt_tokens: u32,
}

#[async_trait]
impl Scenario for ChatLatencyScenario {
    fn id(&self) -> &str {
        self.id.as_str()
    }

    fn prompt_size(&self) -> u32 {
        self.approx_prompt_tokens
    }

    /// Send one streamed, temperature-0 chat completion carrying a
    /// synthetic prompt and measure it, failing if the whole exchange
    /// exceeds `ctx.timeout`.
    async fn run(&self, ctx: &RunCtx) -> Result<ScenarioMetrics> {
        let body = json!({
            "model": ctx.model_id,
            "messages": [{
                "role": "user",
                "content": build_prompt(self.approx_prompt_tokens),
            }],
            "max_tokens": ctx.max_tokens,
            "temperature": 0,
            "stream": true,
            // Ask for a final usage frame so real token counts are reported.
            "stream_options": {"include_usage": true},
        });
        match tokio::time::timeout(ctx.timeout, stream_and_measure(ctx, &body)).await {
            Ok(measured) => measured,
            Err(_elapsed) => Err(anyhow!("request timed out after {:?}", ctx.timeout)),
        }
    }
}
/// The SSE-timing core, ported from `bench.py::one_run`. Kept free of the
/// `Scenario` trait so it's unit-testable against a mock byte stream.
///
/// POSTs `payload` to `ctx.chat_url`, consumes the response as a
/// server-sent-event stream, and timestamps every event that carries
/// non-empty `content`.
///
/// # Errors
/// - the request cannot be sent;
/// - the server answers with a non-success status (body included in the
///   message);
/// - the SSE stream yields an error;
/// - the stream ends without a single content chunk.
async fn stream_and_measure(
    ctx: &RunCtx<'_>,
    payload: &serde_json::Value,
) -> Result<ScenarioMetrics> {
    // Clock starts before the request is sent, so TTFT and total include
    // connection setup and server-side queueing.
    let start = Instant::now();
    let resp = ctx
        .client
        .post(&ctx.chat_url)
        .json(payload)
        .send()
        .await
        .context("sending chat request")?;
    if !resp.status().is_success() {
        let status = resp.status();
        // Best-effort: an unreadable body degrades to an empty string.
        let body = resp.text().await.unwrap_or_default();
        return Err(anyhow!("upstream returned {status}: {}", body.trim()));
    }
    let mut stream = resp.bytes_stream().eventsource();
    // Arrival times of the first and most recent content chunks.
    let mut first: Option<Instant> = None;
    let mut last: Option<Instant> = None;
    let mut chunk_count: u64 = 0;
    let mut prompt_tokens: Option<u64> = None;
    let mut completion_tokens: Option<u64> = None;
    while let Some(event) = stream.next().await {
        let event = event.context("reading SSE stream")?;
        // Timestamp on arrival, before any parsing cost is paid.
        let now = Instant::now();
        let data = event.data.trim();
        if data.is_empty() || data == "[DONE]" {
            continue;
        }
        let chunk: ChatCompletionChunk = match serde_json::from_str(data) {
            Ok(c) => c,
            Err(_) => continue, // tolerate non-JSON keepalive frames
        };
        // Only the first choice is inspected, and only a non-empty string
        // `content` delta counts as a content chunk.
        if let Some(choice) = chunk.choices.first()
            && choice
                .delta
                .get("content")
                .and_then(|c| c.as_str())
                .is_some_and(|s| !s.is_empty())
        {
            if first.is_none() {
                first = Some(now);
            }
            last = Some(now);
            chunk_count += 1;
        }
        // A later usage frame overwrites an earlier one.
        if let Some(usage) = chunk.usage {
            prompt_tokens = Some(usage.prompt_tokens);
            completion_tokens = Some(usage.completion_tokens);
        }
    }
    let end = Instant::now();
    let first = first.ok_or_else(|| anyhow!("no content chunks received"))?;
    // neuron emits one SSE chunk per visible token, so chunk_count is an
    // engine-truth count when no usage frame is sent.
    // A reported count of zero is treated as missing as well.
    let tokens = completion_tokens.filter(|&t| t > 0).unwrap_or(chunk_count);
    // decode rate is only meaningful over a real inter-chunk window.
    let window = last
        .filter(|&l| l > first)
        .map(|l| (l - first).as_secs_f64())
        .unwrap_or(0.0);
    Ok(ScenarioMetrics {
        ttft_s: (first - start).as_secs_f64(),
        // Windows of 200 ms or less give no rate at all.
        decode_tps: if window > 0.2 {
            Some(tokens as f64 / window)
        } else {
            None
        },
        total_s: (end - start).as_secs_f64(),
        prompt_tokens,
        completion_tokens: tokens,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    // A larger token target must give a longer prompt; the body holds at
    // least four characters per requested token, and the prompt always
    // ends with the `/no_think` switch.
    #[test]
    fn prompt_grows_with_token_target() {
        let small = build_prompt(128);
        let big = build_prompt(4096);
        assert!(big.len() > small.len());
        // ~4 chars/token + the trailing question.
        assert!(small.len() >= 128 * 4);
        assert!(small.ends_with("/no_think"));
    }

    // Requests below the 16-token floor are raised to it.
    #[test]
    fn prompt_floor_for_tiny_targets() {
        // max(approx,16) floor means even 0 yields a non-trivial prompt.
        let p = build_prompt(0);
        assert!(p.len() >= 16 * 4);
    }
}

View File

@@ -1,400 +0,0 @@
//! SQLite system-of-record. One row per measured iteration, keyed so a
//! benchmark can be attributed to the exact neuron build that produced
//! it. Replaces hand edits to `doc/benchmarks.md`.
//!
//! Calls are synchronous (SQLite is local and the sweep is batch-1
//! sequential), so the connection is used inline between `await` points,
//! never held across one.
use anyhow::{Context, Result};
use rusqlite::{Connection, params};
use std::path::Path;
/// A single measured (or failed) iteration, with full provenance.
///
/// Maps one-to-one onto a row of the `runs` table (minus the
/// auto-assigned `id`). `*_json` fields hold pre-serialised JSON text.
#[derive(Debug, Clone)]
pub struct RunRecord {
    pub ts: String, // RFC3339
    // target
    pub target_name: String,
    pub target_kind: String,
    pub endpoint: String,
    // host (from /discovery)
    pub hostname: Option<String>,
    pub driver_version: Option<String>,
    pub cuda_version: Option<String>,
    pub gpus_json: Option<String>,
    // neuron build (from /version)
    /// Build SHA; part of the version-aware skip key.
    pub git_sha: String,
    pub git_sha_long: Option<String>,
    pub package_version: String,
    pub git_dirty: bool,
    pub build_timestamp: Option<String>,
    pub rustc_version: Option<String>,
    pub profile: Option<String>,
    pub features_json: String,
    pub candle_version: Option<String>,
    // bench's own build
    pub bench_version: String,
    pub bench_sha: String,
    // model
    pub model_id: String,
    pub harness: String,
    pub capabilities_json: String,
    pub devices_json: String,
    // scenario
    pub scenario_id: String,
    pub prompt_size_approx: u32,
    pub prompt_tokens_actual: Option<u64>,
    pub max_tokens: u64,
    // metrics
    pub ttft_s: Option<f64>,
    pub decode_tps: Option<f64>,
    pub total_s: Option<f64>,
    pub completion_tokens: Option<u64>,
    // outcome
    /// `true` for a successful measurement; only such rows count as samples.
    pub ok: bool,
    /// Failure description; the store's tests set it only when `ok` is false.
    pub error: Option<String>,
}
/// Handle on the SQLite results database (the `runs` table).
pub struct Store {
    conn: Connection,
}

impl Store {
    /// Open (creating parent dirs + schema as needed).
    ///
    /// # Errors
    /// Fails if the parent directory cannot be created, the database
    /// cannot be opened, or the schema cannot be initialised.
    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
        let path = path.as_ref();
        // A bare file name has an empty parent; nothing to create then.
        if let Some(parent) = path.parent()
            && !parent.as_os_str().is_empty()
        {
            std::fs::create_dir_all(parent)
                .with_context(|| format!("creating db dir {}", parent.display()))?;
        }
        let conn = Connection::open(path)
            .with_context(|| format!("opening sqlite db {}", path.display()))?;
        Self::init(&conn)?;
        Ok(Store { conn })
    }

    /// In-memory store for tests.
    #[cfg(test)]
    pub fn open_in_memory() -> Result<Self> {
        let conn = Connection::open_in_memory()?;
        Self::init(&conn)?;
        Ok(Store { conn })
    }

    /// Create the `runs` table and its lookup index if they do not exist
    /// yet; safe to run on every open.
    fn init(conn: &Connection) -> Result<()> {
        conn.execute_batch(
            r#"
CREATE TABLE IF NOT EXISTS runs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
ts TEXT NOT NULL,
target_name TEXT NOT NULL,
target_kind TEXT NOT NULL,
endpoint TEXT NOT NULL,
hostname TEXT,
driver_version TEXT,
cuda_version TEXT,
gpus_json TEXT,
git_sha TEXT NOT NULL,
git_sha_long TEXT,
package_version TEXT NOT NULL,
git_dirty INTEGER NOT NULL,
build_timestamp TEXT,
rustc_version TEXT,
profile TEXT,
features_json TEXT NOT NULL,
candle_version TEXT,
bench_version TEXT NOT NULL,
bench_sha TEXT NOT NULL,
model_id TEXT NOT NULL,
harness TEXT NOT NULL,
capabilities_json TEXT NOT NULL,
devices_json TEXT NOT NULL,
scenario_id TEXT NOT NULL,
prompt_size_approx INTEGER NOT NULL,
prompt_tokens_actual INTEGER,
max_tokens INTEGER NOT NULL,
ttft_s REAL,
decode_tps REAL,
total_s REAL,
completion_tokens INTEGER,
ok INTEGER NOT NULL,
error TEXT
);
-- The version-aware skip query keys on this tuple. scenario_id
-- encodes the prompt size (chat:<n>), so it subsumes the cell.
CREATE INDEX IF NOT EXISTS idx_runs_cell
ON runs (target_name, git_sha, model_id, scenario_id, ok);
"#,
        )
        .context("initialising sqlite schema")?;
        Ok(())
    }

    /// Count successful samples already recorded for a cell. Only `ok`
    /// rows count toward the per-version target so transient failures
    /// don't permanently starve a cell.
    ///
    /// The count saturates at `u32::MAX` instead of wrapping, so an
    /// oversized table can never look under-sampled.
    pub fn count_samples(
        &self,
        target_name: &str,
        git_sha: &str,
        model_id: &str,
        scenario_id: &str,
    ) -> Result<u32> {
        let n: i64 = self.conn.query_row(
            "SELECT COUNT(*) FROM runs WHERE target_name=?1 AND git_sha=?2 \
             AND model_id=?3 AND scenario_id=?4 AND ok=1",
            params![target_name, git_sha, model_id, scenario_id],
            |row| row.get(0),
        )?;
        // Clamp into u32's range first: a bare `as u32` would silently
        // wrap a large count to a small one.
        Ok(n.clamp(0, i64::from(u32::MAX)) as u32)
    }

    /// Append one run to the `runs` table.
    ///
    /// Parameters are bound positionally, so the order of the `params!`
    /// list must match the column list. Booleans are stored as 0/1.
    pub fn insert_run(&self, r: &RunRecord) -> Result<()> {
        self.conn.execute(
            "INSERT INTO runs (
ts, target_name, target_kind, endpoint,
hostname, driver_version, cuda_version, gpus_json,
git_sha, git_sha_long, package_version, git_dirty,
build_timestamp, rustc_version, profile, features_json, candle_version,
bench_version, bench_sha,
model_id, harness, capabilities_json, devices_json,
scenario_id, prompt_size_approx, prompt_tokens_actual, max_tokens,
ttft_s, decode_tps, total_s, completion_tokens,
ok, error
) VALUES (
?1, ?2, ?3, ?4,
?5, ?6, ?7, ?8,
?9, ?10, ?11, ?12,
?13, ?14, ?15, ?16, ?17,
?18, ?19,
?20, ?21, ?22, ?23,
?24, ?25, ?26, ?27,
?28, ?29, ?30, ?31,
?32, ?33
)",
            params![
                r.ts,
                r.target_name,
                r.target_kind,
                r.endpoint,
                r.hostname,
                r.driver_version,
                r.cuda_version,
                r.gpus_json,
                r.git_sha,
                r.git_sha_long,
                r.package_version,
                r.git_dirty as i64,
                r.build_timestamp,
                r.rustc_version,
                r.profile,
                r.features_json,
                r.candle_version,
                r.bench_version,
                r.bench_sha,
                r.model_id,
                r.harness,
                r.capabilities_json,
                r.devices_json,
                r.scenario_id,
                r.prompt_size_approx,
                r.prompt_tokens_actual,
                r.max_tokens,
                r.ttft_s,
                r.decode_tps,
                r.total_s,
                r.completion_tokens,
                r.ok as i64,
                r.error,
            ],
        )?;
        Ok(())
    }

    /// One reportable cell: the median metrics over the most-recently-seen
    /// build SHA for each (target, model, scenario).
    ///
    /// Only successful rows are read. They are fetched ordered by
    /// (target, model, scenario, id), which `aggregate` relies on to find
    /// each cell's latest SHA.
    pub fn report_rows(&self) -> Result<Vec<ReportRow>> {
        // For each (target, model, scenario), find the SHA of the latest
        // successful run, then median that SHA's samples.
        let mut stmt = self.conn.prepare(
            "SELECT target_name, model_id, scenario_id, prompt_size_approx, git_sha,
ttft_s, decode_tps, total_s, prompt_tokens_actual
FROM runs
WHERE ok=1
ORDER BY target_name, model_id, scenario_id, id",
        )?;
        let rows = stmt.query_map([], |row| {
            Ok(RawRow {
                target_name: row.get(0)?,
                model_id: row.get(1)?,
                scenario_id: row.get(2)?,
                prompt_size_approx: row.get(3)?,
                git_sha: row.get(4)?,
                ttft_s: row.get(5)?,
                decode_tps: row.get(6)?,
                total_s: row.get(7)?,
                prompt_tokens_actual: row.get(8)?,
            })
        })?;
        let raws: Vec<RawRow> = rows.collect::<rusqlite::Result<_>>()?;
        Ok(aggregate(raws))
    }
}
/// One successful run as read back for reporting: the cell key, the build
/// SHA it was measured against, and the per-run metrics. Input to
/// `aggregate`.
struct RawRow {
    target_name: String,
    model_id: String,
    scenario_id: String,
    prompt_size_approx: u32,
    git_sha: String,
    ttft_s: Option<f64>,
    decode_tps: Option<f64>,
    total_s: Option<f64>,
    prompt_tokens_actual: Option<u64>,
}
/// An aggregated cell ready for the report table.
///
/// One row per (target, model, scenario), summarising the samples of the
/// build SHA most recently recorded for that cell.
#[derive(Debug, Clone, PartialEq)]
pub struct ReportRow {
    pub target_name: String,
    pub model_id: String,
    pub scenario_id: String,
    pub prompt_size_approx: u32,
    /// The build SHA whose samples were aggregated.
    pub git_sha: String,
    /// First recorded prompt-token count among the samples, if any.
    pub prompt_tokens: Option<u64>,
    /// Median of each metric; `None` when no sample recorded that metric.
    pub ttft_s_median: Option<f64>,
    pub decode_tps_median: Option<f64>,
    pub total_s_median: Option<f64>,
    /// Number of samples aggregated into this row.
    pub samples: usize,
}
/// Collapse raw runs into one report row per (target, model, scenario).
///
/// Within each cell only the rows of the latest build SHA are kept —
/// "latest" meaning the SHA on the cell's final row, which is correct
/// because the input is id-ordered — and each metric is reduced to its
/// median. Output is ordered by the cell key.
fn aggregate(raws: Vec<RawRow>) -> Vec<ReportRow> {
    use std::collections::BTreeMap;

    // cell key -> every row recorded for that cell, in input order
    let mut by_cell: BTreeMap<(String, String, String), Vec<RawRow>> = BTreeMap::new();
    for raw in raws {
        let key = (
            raw.target_name.clone(),
            raw.model_id.clone(),
            raw.scenario_id.clone(),
        );
        by_cell.entry(key).or_default().push(raw);
    }

    by_cell
        .into_iter()
        .map(|((target_name, model_id, scenario_id), history)| {
            let latest_sha = match history.last() {
                Some(newest) => newest.git_sha.clone(),
                None => String::new(),
            };
            let current: Vec<&RawRow> = history
                .iter()
                .filter(|row| row.git_sha == latest_sha)
                .collect();
            ReportRow {
                target_name,
                model_id,
                scenario_id,
                prompt_size_approx: current.first().map_or(0, |row| row.prompt_size_approx),
                prompt_tokens: current.iter().find_map(|row| row.prompt_tokens_actual),
                ttft_s_median: median(current.iter().filter_map(|row| row.ttft_s)),
                decode_tps_median: median(current.iter().filter_map(|row| row.decode_tps)),
                total_s_median: median(current.iter().filter_map(|row| row.total_s)),
                samples: current.len(),
                git_sha: latest_sha,
            }
        })
        .collect()
}
/// Median of the supplied values, or `None` when there are none.
///
/// For an even count this is the mean of the two central values.
/// Incomparable pairs (NaN) are treated as equal while sorting.
fn median(values: impl Iterator<Item = f64>) -> Option<f64> {
    let mut sorted: Vec<f64> = values.collect();
    sorted.sort_by(|lhs, rhs| match lhs.partial_cmp(rhs) {
        Some(ordering) => ordering,
        None => std::cmp::Ordering::Equal,
    });
    // `checked_sub` doubles as the emptiness check: no elements, no median.
    let lower = sorted.len().checked_sub(1)? / 2;
    let upper = sorted.len() / 2;
    // Odd length: lower == upper (the middle element). Even length: they
    // straddle the centre.
    Some((sorted[lower] + sorted[upper]) / 2.0)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Fixture: a fully-populated record whose cell key and outcome are
    // chosen by the caller; everything else is a fixed plausible value.
    fn rec(target: &str, sha: &str, model: &str, scenario: &str, ok: bool) -> RunRecord {
        RunRecord {
            ts: "2026-06-13T00:00:00Z".into(),
            target_name: target.into(),
            target_kind: "neuron".into(),
            endpoint: "http://x:13131".into(),
            hostname: Some("x".into()),
            driver_version: None,
            cuda_version: None,
            gpus_json: None,
            git_sha: sha.into(),
            git_sha_long: None,
            package_version: "0.1.16".into(),
            git_dirty: false,
            build_timestamp: None,
            rustc_version: None,
            profile: None,
            features_json: "[]".into(),
            candle_version: None,
            bench_version: "0.1.16".into(),
            bench_sha: "deadbee".into(),
            model_id: model.into(),
            harness: "candle".into(),
            capabilities_json: "[]".into(),
            devices_json: "[]".into(),
            scenario_id: scenario.into(),
            prompt_size_approx: 128,
            prompt_tokens_actual: Some(130),
            max_tokens: 256,
            ttft_s: Some(0.1),
            decode_tps: Some(50.0),
            total_s: Some(1.0),
            completion_tokens: Some(50),
            ok,
            error: if ok { None } else { Some("boom".into()) },
        }
    }

    // Failed runs are stored but must not count toward a cell's samples.
    #[test]
    fn counts_only_successful_samples() {
        let s = Store::open_in_memory().unwrap();
        s.insert_run(&rec("beast", "abc", "m", "chat:128", true))
            .unwrap();
        s.insert_run(&rec("beast", "abc", "m", "chat:128", true))
            .unwrap();
        s.insert_run(&rec("beast", "abc", "m", "chat:128", false))
            .unwrap();
        assert_eq!(s.count_samples("beast", "abc", "m", "chat:128").unwrap(), 2);
        // Different SHA is a different cell.
        assert_eq!(s.count_samples("beast", "xyz", "m", "chat:128").unwrap(), 0);
    }

    // The report must ignore samples from superseded builds and take the
    // median over the newest SHA only (0.2 and 0.4 -> 0.3).
    #[test]
    fn report_uses_latest_sha_per_cell() {
        let s = Store::open_in_memory().unwrap();
        // old build
        s.insert_run(&rec("beast", "old", "m", "chat:128", true))
            .unwrap();
        // new build, two samples
        let mut r = rec("beast", "new", "m", "chat:128", true);
        r.ttft_s = Some(0.2);
        s.insert_run(&r).unwrap();
        r.ttft_s = Some(0.4);
        s.insert_run(&r).unwrap();
        let rows = s.report_rows().unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].git_sha, "new");
        assert_eq!(rows[0].samples, 2);
        assert!((rows[0].ttft_s_median.unwrap() - 0.3).abs() < 1e-9);
    }
}

View File

@@ -1,250 +0,0 @@
//! The version-aware sweep loop.
//!
//! Each sweep visits every configured target, polls its build identity
//! and warm models, and tops up benchmark samples per
//! (target, build SHA, model, scenario) to `samples_per_version`. Cells
//! already at target are skipped — so once every neuron's current build
//! is fully sampled, sweeps cost only the cheap metadata polls until a
//! new SHA ships. Runs are recorded to SQLite with full provenance.
use crate::client::TargetClient;
use crate::config::{BenchConfig, TargetConfig, TargetKind};
use crate::scenario::{RunCtx, build_scenarios};
use crate::store::{RunRecord, Store};
use anyhow::Result;
use cortex_core::build_info::BuildInfo;
use cortex_core::discovery::DiscoveryResponse;
use cortex_core::harness::ModelInfo;
/// Package version of helexa-bench itself, as baked in by Cargo at
/// compile time (`CARGO_PKG_VERSION`).
fn bench_version() -> String {
    String::from(env!("CARGO_PKG_VERSION"))
}
/// Build SHA of helexa-bench itself. CI injects it through the
/// compile-time `HELEXA_BUILD_SHA` variable; when that is unset or empty
/// (ad-hoc local builds) the result is `"unknown"`.
fn bench_sha() -> String {
    let sha = match option_env!("HELEXA_BUILD_SHA") {
        Some(injected) if !injected.is_empty() => injected,
        _ => "unknown",
    };
    sha.to_owned()
}
/// Counters accumulated over one sweep pass; all start at zero via
/// `Default`.
#[derive(Debug, Default, Clone)]
pub struct SweepSummary {
/// Benchmark runs that succeeded and were written to the store.
pub measured: usize,
/// Cells that already held enough samples and were not run.
pub skipped: usize,
/// Benchmark runs that returned an error (these are still recorded).
pub failed: usize,
/// Targets whose sweep returned an error and was abandoned.
pub targets_unreachable: usize,
}
/// Drives benchmark sweeps over the configured targets and records the
/// results.
pub struct Sweeper {
// Full bench configuration: settings, scenarios and targets.
cfg: BenchConfig,
// Client used for the metadata polls and the benchmark requests.
client: TargetClient,
// Run store opened from `cfg.bench.db_path`.
store: Store,
}
impl Sweeper {
pub fn new(cfg: BenchConfig) -> Result<Self> {
let client = TargetClient::new(cfg.bench.request_timeout())?;
let store = Store::open(&cfg.bench.db_path)?;
Ok(Sweeper { cfg, client, store })
}
/// Run sweeps forever, pausing `sweep_interval` between them.
pub async fn run_forever(&self) -> ! {
loop {
match self.run_once().await {
Ok(s) => tracing::info!(
measured = s.measured,
skipped = s.skipped,
failed = s.failed,
unreachable = s.targets_unreachable,
"sweep complete"
),
Err(e) => tracing::error!(error = %format!("{e:#}"), "sweep errored"),
}
tracing::debug!(
secs = self.cfg.bench.sweep_interval_secs,
"sleeping until next sweep"
);
tokio::time::sleep(self.cfg.bench.sweep_interval()).await;
}
}
/// One full pass over all targets.
pub async fn run_once(&self) -> Result<SweepSummary> {
let mut summary = SweepSummary::default();
for target in &self.cfg.targets {
if let Err(e) = self.sweep_target(target, &mut summary).await {
summary.targets_unreachable += 1;
tracing::warn!(target = %target.name, error = %format!("{e:#}"), "target skipped");
}
}
Ok(summary)
}
/// Sweep a single target: poll its build identity and warm models, then
/// top up every applicable (model, scenario) cell for the reported build
/// SHA towards `samples_per_version` samples.
///
/// `summary` is updated in place: `skipped` per cell that is already
/// satisfied, `measured` per successful run, `failed` per failed run.
/// Successful and failed runs are both written to the store.
///
/// # Errors
///
/// Returns an error when the version poll or the warm-model poll fails,
/// or when a store read or write fails. A failed discovery poll is
/// tolerated; records are then written without host metadata.
async fn sweep_target(&self, target: &TargetConfig, summary: &mut SweepSummary) -> Result<()> {
    let build = self.client.fetch_version(target).await?;
    // Discovery only enriches records, so a failed poll is deliberately
    // treated the same as "no discovery data".
    let discovery = self.client.fetch_discovery(target).await.unwrap_or(None);
    let models = self.client.warm_models(target).await?;
    tracing::info!(
        target = %target.name,
        sha = %build.git_sha,
        warm_models = models.len(),
        "sweeping target"
    );
    let scenarios = build_scenarios(&self.cfg.scenarios);
    for model in &models {
        for scenario in scenarios.iter().filter(|s| s.applies_to(model)) {
            let have = self.store.count_samples(
                &target.name,
                &build.git_sha,
                &model.id,
                scenario.id(),
            )?;
            let need = self.cfg.bench.samples_per_version.saturating_sub(have);
            if need == 0 {
                summary.skipped += 1;
                tracing::debug!(
                    target = %target.name, model = %model.id, scenario = scenario.id(),
                    sha = %build.git_sha, "cell already satisfied, skipping"
                );
                continue;
            }
            let ctx = RunCtx {
                client: self.client.http(),
                chat_url: self.client.chat_url(target),
                model_id: model.id.clone(),
                max_tokens: self.cfg.scenarios.max_tokens,
                timeout: self.cfg.bench.request_timeout(),
            };
            // One unmeasured warmup when the cell is empty (matches
            // bench.py — first run after a load hits cold caches). Its
            // outcome is intentionally discarded.
            if have == 0 {
                tracing::debug!(model = %model.id, scenario = scenario.id(), "warmup run");
                let _ = scenario.run(&ctx).await;
            }
            // Progress shown in the log line. Only successful runs
            // advance it, so a failed iteration no longer inflates the
            // "n/m recorded" figure.
            let mut recorded = have;
            for _ in 0..need {
                match scenario.run(&ctx).await {
                    Ok(m) => {
                        let rec = self.build_record(
                            target,
                            &build,
                            discovery.as_ref(),
                            model,
                            scenario.id(),
                            scenario.prompt_size(),
                            Ok(&m),
                        );
                        self.store.insert_run(&rec)?;
                        summary.measured += 1;
                        recorded += 1;
                        tracing::info!(
                            target = %target.name, model = %model.id, scenario = scenario.id(),
                            ttft_s = m.ttft_s, decode_tps = ?m.decode_tps, total_s = m.total_s,
                            "{}/{} recorded", recorded, self.cfg.bench.samples_per_version
                        );
                    }
                    Err(e) => {
                        // Failures are persisted too, carrying the full
                        // error chain as their message.
                        let msg = format!("{e:#}");
                        let rec = self.build_record(
                            target,
                            &build,
                            discovery.as_ref(),
                            model,
                            scenario.id(),
                            scenario.prompt_size(),
                            Err(&msg),
                        );
                        self.store.insert_run(&rec)?;
                        summary.failed += 1;
                        tracing::warn!(
                            target = %target.name, model = %model.id, scenario = scenario.id(),
                            error = %msg, "iteration failed"
                        );
                    }
                }
                tokio::time::sleep(self.cfg.bench.iteration_pause()).await;
            }
        }
    }
    Ok(())
}
/// Assemble the row persisted for one benchmark run.
///
/// Combines target identity, the target's reported build info, optional
/// discovery metadata, model details and the run outcome. `result` is
/// `Ok(metrics)` for a successful run or `Err(message)` for a failed one;
/// on failure every timing and token field is `None`.
// One argument per provenance source; bundling them would only move the
// same list into a struct built at the two call sites.
#[allow(clippy::too_many_arguments)]
fn build_record(
    &self,
    target: &TargetConfig,
    build: &BuildInfo,
    discovery: Option<&DiscoveryResponse>,
    model: &ModelInfo,
    scenario_id: &str,
    prompt_size: u32,
    result: Result<&crate::scenario::ScenarioMetrics, &str>,
) -> RunRecord {
    // Both halves of the result hold references, so it can be inspected
    // twice without cloning.
    let metrics = result.ok();
    let failure = result.err();

    let gpus_json = discovery
        .map(|d| serde_json::to_string(&d.devices).unwrap_or_else(|_| "[]".to_string()));
    let features_json =
        serde_json::to_string(&build.features).unwrap_or_else(|_| "[]".to_string());
    let capabilities_json =
        serde_json::to_string(&model.capabilities).unwrap_or_else(|_| "[]".to_string());
    let devices_json =
        serde_json::to_string(&model.devices).unwrap_or_else(|_| "[]".to_string());

    RunRecord {
        ts: chrono::Utc::now().to_rfc3339(),
        // Target identity.
        target_name: target.name.clone(),
        target_kind: kind_str(target.kind).to_string(),
        endpoint: target.endpoint.clone(),
        // Host metadata, present only when discovery answered.
        hostname: discovery.map(|d| d.hostname.clone()),
        driver_version: discovery.and_then(|d| d.driver_version.clone()),
        cuda_version: discovery.and_then(|d| d.cuda_version.clone()),
        gpus_json,
        // Build identity reported by the target.
        git_sha: build.git_sha.clone(),
        git_sha_long: build.git_sha_long.clone(),
        package_version: build.package_version.clone(),
        git_dirty: build.git_dirty,
        build_timestamp: build.build_timestamp.clone(),
        rustc_version: build.rustc_version.clone(),
        profile: build.profile.clone(),
        features_json,
        candle_version: build.candle_version.clone(),
        // Identity of the bench binary itself.
        bench_version: bench_version(),
        bench_sha: bench_sha(),
        // Model under test.
        model_id: model.id.clone(),
        harness: model.harness.clone(),
        capabilities_json,
        devices_json,
        // Scenario parameters.
        scenario_id: scenario_id.to_string(),
        prompt_size_approx: prompt_size,
        prompt_tokens_actual: metrics.and_then(|m| m.prompt_tokens),
        max_tokens: self.cfg.scenarios.max_tokens,
        // Outcome.
        ttft_s: metrics.map(|m| m.ttft_s),
        decode_tps: metrics.and_then(|m| m.decode_tps),
        total_s: metrics.map(|m| m.total_s),
        completion_tokens: metrics.map(|m| m.completion_tokens),
        ok: metrics.is_some(),
        error: failure.map(str::to_string),
    }
}
}
fn kind_str(kind: TargetKind) -> &'static str {
match kind {
TargetKind::Neuron => "neuron",
TargetKind::Openai => "openai",
}
}

View File

@@ -1,132 +0,0 @@
//! End-to-end sweep against a mock neuron: a sweep records samples, a
//! second sweep skips the satisfied cell, and bumping the reported build
//! SHA resumes fresh sampling.
use axum::Router;
use axum::extract::State;
use axum::http::header;
use axum::response::{IntoResponse, Json};
use axum::routing::{get, post};
use helexa_bench::config::{BenchConfig, BenchSettings, ScenarioConfig, TargetConfig, TargetKind};
use helexa_bench::sweep::Sweeper;
use serde_json::json;
use std::sync::{Arc, Mutex};
/// State shared with the mock server's handlers.
#[derive(Clone)]
struct MockState {
// Build SHA the mock `/version` route reports; the test mutates it
// through the handle returned by `spawn_mock`.
sha: Arc<Mutex<String>>,
}
/// Mock `GET /version`: reports whatever SHA is currently held in the
/// shared handle, alongside fixed build metadata.
async fn version(State(mock): State<MockState>) -> Json<serde_json::Value> {
    let reported_sha = mock.sha.lock().unwrap().clone();
    let payload = json!({
        "package_version": "0.1.16",
        "git_sha": reported_sha,
        "git_dirty": false,
        "features": ["cuda", "cudnn"],
        "candle_version": "0.10.2",
    });
    Json(payload)
}
/// Mock `GET /discovery`: a fixed single-GPU host description.
async fn discovery() -> Json<serde_json::Value> {
    let payload = json!({
        "hostname": "mock-beast",
        "os": "Linux",
        "kernel": "6.19.0",
        "cuda_version": "13.0",
        "driver_version": "580.159",
        "devices": [{"index": 0, "name": "RTX 5090", "vram_total_mb": 32614, "compute_capability": "12.0"}],
        "harnesses": ["candle"],
    });
    Json(payload)
}
/// Mock `GET /models`: one loaded model plus one that is not warm.
async fn models() -> Json<serde_json::Value> {
    let listing = json!([
        {"id": "Qwen/Qwen3.6-27B", "harness": "candle", "status": "loaded", "devices": [0], "capabilities": ["text"]},
        // A non-warm model the bench must ignore.
        {"id": "Qwen/cold", "harness": "candle", "status": "recovering", "devices": [0]},
    ]);
    Json(listing)
}
/// Mock `POST /v1/chat/completions`: a canned SSE stream of two content
/// deltas, a final chunk carrying usage, and the `[DONE]` terminator.
async fn chat() -> impl IntoResponse {
    const SSE_BODY: &str = concat!(
        "data: {\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n",
        "data: {\"choices\":[{\"index\":0,\"delta\":{\"content\":\" world\"},\"finish_reason\":null}]}\n\n",
        "data: {\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":130,\"completion_tokens\":2,\"total_tokens\":132}}\n\n",
        "data: [DONE]\n\n",
    );
    let headers = [(header::CONTENT_TYPE, "text/event-stream")];
    (headers, SSE_BODY)
}
/// Start the mock neuron on an OS-assigned loopback port.
///
/// Returns the server's base URL and a handle to the SHA it reports, so
/// the test can change the reported build mid-run.
async fn spawn_mock(sha: &str) -> (String, Arc<Mutex<String>>) {
    let sha_handle = Arc::new(Mutex::new(String::from(sha)));
    let router = Router::new()
        .route("/version", get(version))
        .route("/discovery", get(discovery))
        .route("/models", get(models))
        .route("/v1/chat/completions", post(chat))
        .with_state(MockState {
            sha: Arc::clone(&sha_handle),
        });
    // Port 0 lets the OS choose a free port.
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let bound = listener.local_addr().unwrap();
    tokio::spawn(async move {
        axum::serve(listener, router).await.unwrap();
    });
    let base_url = format!("http://{bound}");
    (base_url, sha_handle)
}
/// Bench configuration for the test: a single neuron target at
/// `endpoint`, one 128-token scenario, two samples per version, and no
/// pause between iterations.
fn config_for(endpoint: String, db_path: String) -> BenchConfig {
    let bench = BenchSettings {
        sweep_interval_secs: 1,
        samples_per_version: 2,
        iteration_pause_secs: 0,
        request_timeout_secs: 30,
        db_path,
    };
    let scenarios = ScenarioConfig {
        prompt_sizes: vec![128], // single scenario keeps assertions simple
        max_tokens: 16,
    };
    let mock_target = TargetConfig {
        name: "mock".into(),
        kind: TargetKind::Neuron,
        endpoint,
        label: None,
    };
    BenchConfig {
        bench,
        scenarios,
        targets: vec![mock_target],
    }
}
#[tokio::test]
async fn sweep_records_skips_and_resumes_on_new_sha() {
    let (endpoint, sha_handle) = spawn_mock("aaaaaaa").await;

    // The bound port is unique, so it gives a unique db path per run.
    let port = endpoint.rsplit(':').next().unwrap();
    let db_file = format!("helexa-bench-it-{port}.sqlite");
    let db_path = std::env::temp_dir().join(db_file);
    let _ = std::fs::remove_file(&db_path);

    let cfg = config_for(endpoint, db_path.to_string_lossy().into_owned());
    let sweeper = Sweeper::new(cfg).unwrap();

    // Pass 1: one warm model × one scenario × 2 samples.
    let first = sweeper.run_once().await.unwrap();
    assert_eq!(first.measured, 2, "should record samples_per_version samples");
    assert_eq!(first.skipped, 0);
    assert_eq!(first.failed, 0);

    // Pass 2, same SHA: the cell is satisfied, nothing is measured.
    let second = sweeper.run_once().await.unwrap();
    assert_eq!(second.measured, 0, "satisfied cell must be skipped");
    assert_eq!(second.skipped, 1);

    // Pass 3: a new reported SHA is a new cell, so sampling resumes.
    *sha_handle.lock().unwrap() = String::from("bbbbbbb");
    let third = sweeper.run_once().await.unwrap();
    assert_eq!(third.measured, 2, "new SHA must resume sampling");
    assert_eq!(third.skipped, 0);

    let _ = std::fs::remove_file(&db_path);
}

View File

@@ -12,36 +12,6 @@ path = "src/lib.rs"
name = "neuron"
path = "src/main.rs"
[features]
default = []
# Enables CUDA acceleration in candle and the cudarc/nccl bindings the
# TP worker pool uses. Without this feature, candle compiles for CPU
# only, Device::new_cuda calls fall back to CPU, and TP Init/sanity
# requests return Error{kind="cuda_feature_not_enabled"}.
cuda = [
"candle-core/cuda",
"candle-core/nccl",
"candle-nn/cuda",
"candle-transformers/cuda",
"dep:cudarc",
"dep:half",
"dep:cudaforge",
]
# Use cuDNN for convolution / attention kernels. Requires CUDA.
cudnn = [
"cuda",
"candle-core/cudnn",
"candle-nn/cudnn",
"candle-transformers/cudnn",
]
# FlashAttention kernels. Requires CUDA.
flash-attn = [
"cuda",
"candle-transformers/flash-attn",
]
# Reserved for GPU-only integration tests in later stages.
cuda-integration = ["cuda"]
[dependencies]
cortex-core.workspace = true
tokio.workspace = true
@@ -54,70 +24,9 @@ tracing-subscriber.workspace = true
anyhow.workspace = true
async-trait.workspace = true
clap.workspace = true
thiserror.workspace = true
futures.workspace = true
tokio-stream.workspace = true
figment.workspace = true
toml.workspace = true
# Parallel in-situ quantization (#1): fans candle's per-block k-quant
# math across the CPU pool at model-load time. Already in the tree
# transitively via candle-core.
rayon = "1"
# candle for in-process inference. CUDA support is gated behind the
# crate's `cuda` feature (default off) so the workspace builds on
# non-CUDA hosts and CI runners.
candle-core = "0.10.2"
candle-nn = "0.10.2"
candle-transformers = "0.10.2"
# Direct dep on cudarc (matching candle's transitive version) so the
# TP worker pool can call cudarc::nccl::{Comm, Id} directly. Gated on
# the `cuda` feature; same toolchain requirement as candle's CUDA path.
cudarc = { version = "0.19", optional = true, default-features = false, features = ["nccl", "cuda-version-from-build-system"] }
# Used by the AllReduce CustomOp1 to type-dispatch on bf16/f16 candle
# storages. Matches candle-core's pinned major version to avoid double-
# compiling the `half` crate at conflicting versions.
half = { version = "2.5", optional = true }
tokenizers = { version = "0.22", default-features = false, features = ["onig"] }
hf-hub = { version = "0.4", features = ["tokio"] }
# Jinja-compatible template renderer for the model's chat template
# (standalone `chat_template.jinja` or `tokenizer_config.json::chat_template`).
# Hugging Face's chat templates lean on Python string semantics; we
# bridge them with `minijinja-contrib`'s `pycompat` callback (str
# methods like `startswith`/`split`/`strip`) plus a `raise_exception`
# global. Features: `builtins` for `is defined` / `default`; `json`
# for `tojson`; `serde` so we can hand it a serde_json::Value context.
minijinja = { version = "2", features = ["builtins", "json", "serde"] }
# Python-compatibility shim: the Qwen3-VL / Qwen3.6 template uses
# `content.startswith(...)`, `.endswith(...)`, `.split(...)`,
# `.rstrip(...)`, `.lstrip(...)` — Python str methods minijinja doesn't
# implement natively. `pycompat::unknown_method_callback` supplies them.
minijinja-contrib = { version = "2", features = ["pycompat"] }
# Direct dep on `safetensors` (re-exported by candle but its `TensorView`
# / `slice::IndexOp` types are public-but-not-re-exported). Used by the
# tp `fused_load` module to read per-rank slices of fused QKV tensors
# without materialising the full tensor on device.
safetensors = "0.7"
# Vision capability for Qwen3.6 (Stage A of the vision plan in
# doc/vision-qwen3_6-spec.md). `image` decodes PNG/JPEG/etc from
# the bytes embedded in `data:image/...;base64,...` content parts;
# `base64` does the URI decode. Default-features off on `image` to
# avoid pulling in audio/video formats we don't need.
image = { version = "0.25", default-features = false, features = ["png", "jpeg", "webp", "bmp", "gif"] }
base64 = "0.22"
[dev-dependencies]
tokio = { workspace = true, features = ["test-util"] }
reqwest.workspace = true
tempfile = "3"
[build-dependencies]
# Used by `build.rs` to compile `src/cuda/*.cu` into `libneuroncuda.a`
# under the `cuda` feature. Matches mistralrs's upstream build setup
# (their `mistralrs-core/build.rs` uses the same constructor).
cudaforge = { version = "0.1", optional = true }
[package.metadata.docs.rs]
# Skip the CUDA path on docs.rs (it lacks nvcc).
no-default-features = true

View File

@@ -1,196 +0,0 @@
//! Build script: capture build/version metadata for `GET /version`,
//! and (under the `cuda` feature) compile the CUDA kernels in
//! `src/cuda/*.cu` into a static library and link it.
//!
//! The CUDA portion is patterned on
//! `EricLBuehler/mistral.rs::mistralrs-core/build.rs` — same
//! `cudaforge::KernelBuilder` invocation, same NVCC flag set.
use std::process::Command;
/// Build-script entry point.
///
/// Always emits the build-identity env vars; with the `cuda` feature it
/// additionally compiles `src/cuda/*.cu` into a static `neuroncuda`
/// library and emits the link directives for it.
fn main() {
emit_build_metadata();
#[cfg(feature = "cuda")]
{
use std::path::PathBuf;
println!("cargo:rerun-if-changed=build.rs");
println!("cargo:rerun-if-changed=src/cuda/");
// Compiled kernels and the resulting archive go into Cargo's OUT_DIR.
let build_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap());
let mut builder = cudaforge::KernelBuilder::new()
.source_glob("src/cuda/*.cu")
.out_dir(&build_dir)
.arg("-std=c++17")
.arg("-O3")
.arg("-U__CUDA_NO_HALF_OPERATORS__")
.arg("-U__CUDA_NO_HALF_CONVERSIONS__")
.arg("-U__CUDA_NO_HALF2_OPERATORS__")
.arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
.arg("--expt-relaxed-constexpr")
.arg("--expt-extended-lambda")
.arg("--use_fast_math")
.arg("--compiler-options")
.arg("-fPIC");
// sm_<80 doesn't have bf16 intrinsics for WMMA — gate the
// bf16-only kernels off in that case. (Mirrors upstream.)
if let Some(compute_cap) = builder.get_compute_cap()
&& compute_cap < 80
{
builder = builder.arg("-DNO_BF16_KERNEL");
}
// Static-library file name follows the target's convention:
// `<name>.lib` for MSVC, `lib<name>.a` elsewhere.
let target = std::env::var("TARGET").unwrap();
let out_file = if target.contains("msvc") {
build_dir.join("neuroncuda.lib")
} else {
build_dir.join("libneuroncuda.a")
};
builder
.build_lib(out_file)
.expect("neuron cuda build failed");
// Link the freshly built kernel archive plus the CUDA runtime.
println!("cargo:rustc-link-search={}", build_dir.display());
println!("cargo:rustc-link-lib=neuroncuda");
println!("cargo:rustc-link-lib=dylib=cudart");
// Pick the C++ runtime library by target triple.
if target.contains("msvc") {
// No extra runtime library needed.
} else if target.contains("apple")
|| target.contains("freebsd")
|| target.contains("openbsd")
{
println!("cargo:rustc-link-lib=dylib=c++");
} else if target.contains("android") {
println!("cargo:rustc-link-lib=dylib=c++_shared");
} else {
println!("cargo:rustc-link-lib=dylib=stdc++");
}
}
}
/// Emit `cargo:rustc-env=` vars consumed by `env!()` in `src/version.rs`
/// so the daemon can report its own build identity from `GET /version`.
///
/// We re-run only when HEAD moves or the SHA override changes — not on
/// every compile — so the captured timestamp is stable for a given
/// build input rather than churning on each `cargo build`.
fn emit_build_metadata() {
println!("cargo:rerun-if-env-changed=HELEXA_BUILD_SHA");
println!("cargo:rerun-if-changed=.git/HEAD");
// A detached/normal HEAD points at a ref whose file is what actually
// changes on commit; watch the packed-refs fallback too.
println!("cargo:rerun-if-changed=.git/packed-refs");
// SHA: prefer the CI/RPM-injected override (tarball builds have no
// .git), then fall back to git, then to "unknown".
let (sha_short, sha_long, dirty) = match std::env::var("HELEXA_BUILD_SHA") {
Ok(s) if !s.trim().is_empty() => {
let s = s.trim().to_string();
let short = s.chars().take(7).collect::<String>();
(short, Some(s), false)
}
_ => {
let long = git(&["rev-parse", "HEAD"]);
let short = git(&["rev-parse", "--short", "HEAD"]);
let dirty = git(&["status", "--porcelain"])
.map(|s| !s.trim().is_empty())
.unwrap_or(false);
match short {
Some(short) => (short, long, dirty),
None => ("unknown".to_string(), None, false),
}
}
};
println!("cargo:rustc-env=HELEXA_GIT_SHA={sha_short}");
println!(
"cargo:rustc-env=HELEXA_GIT_SHA_LONG={}",
sha_long.unwrap_or_default()
);
println!("cargo:rustc-env=HELEXA_GIT_DIRTY={dirty}");
// RFC3339 build timestamp. `date` is universally present on the
// Linux hosts neuron targets; empty if it ever isn't.
let ts = Command::new("date")
.args(["-u", "+%Y-%m-%dT%H:%M:%SZ"])
.output()
.ok()
.filter(|o| o.status.success())
.map(|o| String::from_utf8_lossy(&o.stdout).trim().to_string())
.unwrap_or_default();
println!("cargo:rustc-env=HELEXA_BUILD_TIMESTAMP={ts}");
// Compiler version: cargo sets $RUSTC to the rustc it invokes.
let rustc = std::env::var("RUSTC").unwrap_or_else(|_| "rustc".to_string());
let rustc_version = Command::new(rustc)
.arg("--version")
.output()
.ok()
.filter(|o| o.status.success())
.map(|o| String::from_utf8_lossy(&o.stdout).trim().to_string())
.unwrap_or_default();
println!("cargo:rustc-env=HELEXA_RUSTC_VERSION={rustc_version}");
println!(
"cargo:rustc-env=HELEXA_BUILD_PROFILE={}",
std::env::var("PROFILE").unwrap_or_default()
);
println!(
"cargo:rustc-env=HELEXA_TARGET={}",
std::env::var("TARGET").unwrap_or_default()
);
// Enabled features: cargo exports CARGO_FEATURE_<NAME> for each.
// Reverse the mangling (uppercase, '-'→'_') best-effort for display.
let mut features: Vec<String> = std::env::vars()
.filter_map(|(k, _)| k.strip_prefix("CARGO_FEATURE_").map(|f| f.to_string()))
.map(|f| f.to_lowercase().replace('_', "-"))
// `default` is the meta-feature, not a perf-relevant flag.
.filter(|f| f != "default")
.collect();
features.sort();
println!("cargo:rustc-env=HELEXA_FEATURES={}", features.join(","));
println!(
"cargo:rustc-env=HELEXA_CANDLE_VERSION={}",
candle_version().unwrap_or_default()
);
}
/// Run `git` with `args` and return its trimmed stdout.
///
/// Returns `None` when git cannot be spawned, exits unsuccessfully, or
/// prints nothing but whitespace.
fn git(args: &[&str]) -> Option<String> {
    let output = Command::new("git").args(args).output().ok()?;
    if !output.status.success() {
        return None;
    }
    let stdout = String::from_utf8_lossy(&output.stdout);
    match stdout.trim() {
        "" => None,
        text => Some(text.to_string()),
    }
}
/// Best-effort lookup of the locked `candle-core` version in the
/// workspace `Cargo.lock`, which sits two directories above this crate.
///
/// Returns `None` when `CARGO_MANIFEST_DIR` is unset, the lockfile cannot
/// be read (e.g. some packaging flows), or no `candle-core` entry with a
/// version line is found. Also tells Cargo to re-run the script when the
/// lockfile changes.
fn candle_version() -> Option<String> {
    let crate_dir = std::env::var("CARGO_MANIFEST_DIR").ok()?;
    let lock = std::path::Path::new(&crate_dir)
        .join("..")
        .join("..")
        .join("Cargo.lock");
    println!("cargo:rerun-if-changed={}", lock.display());
    let text = std::fs::read_to_string(lock).ok()?;
    // Cargo.lock entries are `[[package]]\nname = "x"\nversion = "y"`.
    // Track whether the scan is inside the candle-core entry and take
    // the first version line seen there.
    let mut inside_candle_entry = false;
    for raw in text.lines() {
        match raw.trim() {
            "[[package]]" => inside_candle_entry = false,
            "name = \"candle-core\"" => inside_candle_entry = true,
            entry if inside_candle_entry => {
                if let Some(rest) = entry.strip_prefix("version = \"") {
                    return Some(rest.trim_end_matches('"').to_string());
                }
            }
            _ => {}
        }
    }
    None
}

View File

@@ -1,93 +0,0 @@
//! Activation-time pre-warm progress tracking.
//!
//! Wraps the [`ActivationStatus`] snapshot in an async RwLock so the
//! background pre-warm task can update it per-model while the
//! `/health` handler reads coherent snapshots. The tracker exists
//! because `default_models` loading moved from synchronous-before-bind
//! to background-after-bind on 2026-05-26: the listener is up
//! immediately, but `/health` now needs to tell callers which of the
//! configured defaults are still warming.
use cortex_core::discovery::{ActivationState, ActivationStatus, PreWarmFailure};
use cortex_core::harness::ModelSpec;
use tokio::sync::RwLock;
/// Shared, async-safe handle to the daemon's activation progress.
///
/// Construct once in `main` with the configured `default_models` so
/// the initial `pending` list matches the spec; clone the `Arc` into
/// the `NeuronState` for HTTP handlers and into the spawned pre-warm
/// task for updates.
pub struct ActivationTracker {
// Current status behind an async RwLock: the mutating methods take the
// write lock, `snapshot` takes the read lock and clones.
inner: RwLock<ActivationStatus>,
}
impl ActivationTracker {
/// Create a tracker whose `pending` list holds the id of every spec in
/// `default_models`, in order.
///
/// With nothing queued the tracker starts out `Ready`; otherwise it
/// starts in `PreWarming`.
pub fn new(default_models: &[ModelSpec]) -> Self {
    let mut pending: Vec<String> = Vec::with_capacity(default_models.len());
    pending.extend(default_models.iter().map(|spec| spec.model_id.clone()));
    let state = match pending.len() {
        0 => ActivationState::Ready,
        _ => ActivationState::PreWarming,
    };
    let initial = ActivationStatus {
        state,
        pending,
        in_progress: None,
        completed: Vec::new(),
        failed: Vec::new(),
    };
    Self {
        inner: RwLock::new(initial),
    }
}
/// Record that `model_id` has started loading: it leaves `pending` and
/// becomes the `in_progress` entry. Called immediately before
/// `registry.load_model`.
pub async fn start_loading(&self, model_id: &str) {
    let mut status = self.inner.write().await;
    status.pending.retain(|queued| queued.as_str() != model_id);
    status.in_progress = Some(model_id.to_owned());
}
/// Record that `model_id` finished loading: it is appended to
/// `completed`, and `in_progress` is cleared if it names this model.
pub async fn complete_loading(&self, model_id: &str) {
    let mut status = self.inner.write().await;
    let was_current = status
        .in_progress
        .as_deref()
        .is_some_and(|current| current == model_id);
    if was_current {
        status.in_progress = None;
    }
    status.completed.push(model_id.to_owned());
}
/// Record that `model_id` failed to load: a `PreWarmFailure` carrying
/// the rendered error chain is appended to `failed`, and `in_progress`
/// is cleared if it names this model.
pub async fn fail_loading(&self, model_id: &str, error: &str) {
    let failure = PreWarmFailure {
        model_id: model_id.to_owned(),
        error: error.to_owned(),
    };
    let mut status = self.inner.write().await;
    let was_current = status
        .in_progress
        .as_deref()
        .is_some_and(|current| current == model_id);
    if was_current {
        status.in_progress = None;
    }
    status.failed.push(failure);
}
/// Set the high-level `state` to `Ready` and clear `in_progress`; meant
/// to be called once the pre-warm task has finished iterating.
///
/// `pending` is left untouched on purpose. It should be empty by now,
/// but if the task bailed early the leftover entries stay visible next
/// to `state=ready` — a useful diagnostic of a stuck activation rather
/// than an inconsistency to scrub.
pub async fn mark_ready(&self) {
    let mut status = self.inner.write().await;
    status.in_progress = None;
    status.state = ActivationState::Ready;
}
/// Cheap clone of the current state for the `/health` handler.
pub async fn snapshot(&self) -> ActivationStatus {
self.inner.read().await.clone()
}
}

View File

@@ -1,63 +1,34 @@
//! HTTP API handlers for the neuron daemon.
use crate::activation::ActivationTracker;
use crate::harness::HarnessRegistry;
use crate::harness::candle::{CandleHarness, InferenceError};
use crate::harness::preflight::PreflightError;
use crate::health::HealthCache;
use crate::wire::{openai_chat, openai_responses};
use axum::Router;
use axum::extract::{Path, State};
use axum::http::StatusCode;
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::{IntoResponse, Json};
use axum::routing::{get, post};
use cortex_core::discovery::{DiscoveryResponse, HealthResponse};
use cortex_core::harness::ModelSpec;
use cortex_core::openai::{ChatCompletionRequest, MessageContent};
use cortex_core::responses::{ResponsesRequest, ResponsesUsage};
use futures::stream::{self, StreamExt};
use serde_json::{Value, json};
use std::convert::Infallible;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;
use tokio_stream::wrappers::ReceiverStream;
/// Shared state for the neuron HTTP server.
pub struct NeuronState {
/// Discovery data for this neuron; `load_model` consults its
/// `cuda_unavailable_reason` before attempting a load.
pub discovery: DiscoveryResponse,
/// Health cache whose `snapshot()` backs the `/health` response.
pub health_cache: Arc<HealthCache>,
/// Harness registry; the load/unload handlers take its read lock.
pub registry: RwLock<HarnessRegistry>,
/// Typed handle to the candle harness for inference routes. Cached at
/// startup so `/v1/chat/completions` doesn't have to hold the registry
/// read lock or perform dyn-Trait dispatch per request.
pub candle: Option<Arc<CandleHarness>>,
/// Activation-time pre-warm progress. Updated by the background
/// `load_default_models` task, read by the `/health` handler.
pub activation: Arc<ActivationTracker>,
}
/// Build the neuron API router.
///
/// The returned router still needs an `Arc<NeuronState>` supplied as
/// its state.
pub fn neuron_routes() -> Router<Arc<NeuronState>> {
Router::new()
// Build identity, host discovery and health probes.
.route("/version", get(version_handler))
.route("/discovery", get(discovery_handler))
.route("/health", get(health_handler))
// Model lifecycle management.
.route("/models", get(list_models))
.route("/models/load", post(load_model))
.route("/models/unload", post(unload_model))
.route("/models/{model_id}/endpoint", get(model_endpoint))
// OpenAI-compatible inference endpoints.
.route("/v1/chat/completions", post(chat_completions))
.route("/v1/responses", post(responses))
}
/// `GET /version` — reports this daemon's build identity (git SHA,
/// enabled features, rustc/candle versions) as JSON.
///
/// The data is fixed for the lifetime of the process, so the handler
/// takes no state. It is the canonical "which build is live" probe for
/// fleet validation and benchmark attribution.
async fn version_handler() -> Json<cortex_core::build_info::BuildInfo> {
    let info = crate::version::build_info();
    Json(info)
}
async fn discovery_handler(State(state): State<Arc<NeuronState>>) -> Json<DiscoveryResponse> {
@@ -65,13 +36,7 @@ async fn discovery_handler(State(state): State<Arc<NeuronState>>) -> Json<Discov
}
async fn health_handler(State(state): State<Arc<NeuronState>>) -> Json<HealthResponse> {
// HealthCache owns the uptime + per-device readings; the activation
// tracker owns the pre-warm progress. We compose the response here
// so the cache stays a thin runtime-state cache and doesn't need to
// know about activation lifecycle.
let mut snapshot = state.health_cache.snapshot().await;
snapshot.activation = state.activation.snapshot().await;
Json(snapshot)
Json(state.health_cache.snapshot().await)
}
async fn list_models(State(state): State<Arc<NeuronState>>) -> impl IntoResponse {
@@ -80,7 +45,7 @@ async fn list_models(State(state): State<Arc<NeuronState>>) -> impl IntoResponse
Ok(models) => Json(json!(models)).into_response(),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("{e:#}")})),
Json(json!({"error": e.to_string()})),
)
.into_response(),
}
@@ -90,70 +55,14 @@ async fn load_model(
State(state): State<Arc<NeuronState>>,
Json(spec): Json<ModelSpec>,
) -> impl IntoResponse {
// Driver/library mismatch preflight (#19): every CUDA load is
// guaranteed to fail until the host reboots. Reject up front with
// the operator-actionable reason instead of letting the load die
// minutes later inside cuInit/NCCL with a cryptic error.
if let Some(reason) = &state.discovery.cuda_unavailable_reason {
tracing::warn!(model = %spec.model_id, reason = %reason, "load_model rejected: CUDA unavailable");
return (
StatusCode::SERVICE_UNAVAILABLE,
Json(json!({
"error": reason,
"code": "cuda_unavailable",
})),
)
.into_response();
}
let registry = state.registry.read().await;
match registry.load_model(&spec).await {
Ok(()) => Json(json!({"status": "loaded"})).into_response(),
Err(e) => {
// If the underlying failure is a structured preflight
// rejection, surface it as 422 Unprocessable Entity with
// the typed JSON body. The kind/model_id/suggestion/etc.
// fields let cortex (and operators reading the response
// directly) act on the failure without parsing free text.
if let Some(pf) = e.downcast_ref::<PreflightError>() {
tracing::warn!(
model = %spec.model_id,
reason = preflight_kind(pf),
detail = %pf,
"load_model rejected by preflight"
);
return (
StatusCode::UNPROCESSABLE_ENTITY,
Json(json!({ "error": pf })),
)
.into_response();
}
// Log the full anyhow chain server-side so journalctl shows
// the underlying failure (hf-hub timeout, permission denied,
// disk full, etc.) without needing to inspect the HTTP
// response body separately.
tracing::warn!(
model = %spec.model_id,
error = %format!("{e:#}"),
"load_model failed"
);
(
StatusCode::BAD_REQUEST,
Json(json!({"error": format!("{e:#}")})),
)
.into_response()
}
}
}
/// Short kebab-case tag for a preflight failure, used as a structured
/// log field for journalctl-side filtering. Mirrors the same helper in
/// `startup.rs`; duplicated to keep the module surfaces independent.
fn preflight_kind(err: &PreflightError) -> &'static str {
match err {
PreflightError::RepoFetchFailed { .. } => "repo_fetch_failed",
PreflightError::EmptyRepo { .. } => "empty_repo",
PreflightError::TpRequiresSafetensors { .. } => "tp_requires_safetensors",
PreflightError::QuantNotFound { .. } => "quant_not_found",
Err(e) => (
StatusCode::BAD_REQUEST,
Json(json!({"error": e.to_string()})),
)
.into_response(),
}
}
@@ -175,11 +84,7 @@ async fn unload_model(
let registry = state.registry.read().await;
match registry.unload_model(&model_id).await {
Ok(()) => Json(json!({"status": "unloaded"})).into_response(),
Err(e) => (
StatusCode::NOT_FOUND,
Json(json!({"error": format!("{e:#}")})),
)
.into_response(),
Err(e) => (StatusCode::NOT_FOUND, Json(json!({"error": e.to_string()}))).into_response(),
}
}
@@ -197,347 +102,3 @@ async fn model_endpoint(
.into_response(),
}
}
/// OpenAI-compatible chat completions. Dispatches to streaming SSE when
/// `stream: true` is set on the request; otherwise returns a single
/// `ChatCompletionResponse`.
async fn chat_completions(
State(state): State<Arc<NeuronState>>,
headers: axum::http::HeaderMap,
Json(req): Json<ChatCompletionRequest>,
) -> impl IntoResponse {
let Some(candle) = state.candle.as_ref().map(Arc::clone) else {
return (
StatusCode::SERVICE_UNAVAILABLE,
Json(json!({"error": "candle harness not enabled on this neuron"})),
)
.into_response();
};
// Reasoning-content opt-in. Off by default → naïve clients
// (Zed's commit-message generator, vanilla OpenAI clients)
// never see `<think>` blocks. On when the caller sends
// `x-include-thinking: true` (helexa-acp does this so its
// own ThinkParser keeps working unchanged).
let include_thinking = headers
.get("x-include-thinking")
.and_then(|v| v.to_str().ok())
.map(|s| matches!(s.trim().to_ascii_lowercase().as_str(), "1" | "true" | "yes"))
.unwrap_or(false);
let chat_config = openai_chat::ChatProjectionConfig {
include_thinking,
reasoning_markers: None, // filled in from the loaded model inside candle
};
if req.stream.unwrap_or(false) {
match candle.chat_completion_stream_with(req, chat_config).await {
Ok(rx) => {
// Each chunk → one SSE `data: {json}` line. After the
// channel closes, append the OpenAI [DONE] terminator.
let body_stream = ReceiverStream::new(rx).map(|chunk| {
let body = serde_json::to_string(&chunk).unwrap_or_default();
Ok::<_, Infallible>(Event::default().data(body))
});
let done_stream =
stream::once(async { Ok::<_, Infallible>(Event::default().data("[DONE]")) });
Sse::new(body_stream.chain(done_stream))
.keep_alive(KeepAlive::default())
.into_response()
}
Err(InferenceError::ModelNotLoaded(id)) => (
StatusCode::NOT_FOUND,
Json(json!({"error": format!("model '{id}' not loaded on this neuron")})),
)
.into_response(),
Err(InferenceError::PromptTooLong { prompt_len, max }) => (
StatusCode::BAD_REQUEST,
Json(json!({
"error": format!("prompt has {prompt_len} tokens but max is {max}"),
"code": "prompt_too_long",
"prompt_len": prompt_len,
"max": max,
})),
)
.into_response(),
Err(InferenceError::InsufficientVram {
free_mb,
required_mb,
}) => (
StatusCode::SERVICE_UNAVAILABLE,
Json(json!({
"error": format!(
"insufficient free VRAM: {free_mb} MiB free, need at least {required_mb} MiB"
),
"code": "insufficient_vram",
"free_mb": free_mb,
"required_mb": required_mb,
})),
)
.into_response(),
Err(InferenceError::VisionUnsupported { model_id }) => (
StatusCode::BAD_REQUEST,
Json(json!({
"error": format!(
"model '{model_id}' does not support image input"
),
"code": "vision_unsupported",
"model_id": model_id,
"suggestion": "load a vision-capable model or remove image_url content parts",
})),
)
.into_response(),
Err(InferenceError::Other(e)) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("{e:#}")})),
)
.into_response(),
}
} else {
match candle.chat_completion(req).await {
Ok(resp) => Json(resp).into_response(),
Err(InferenceError::ModelNotLoaded(id)) => (
StatusCode::NOT_FOUND,
Json(json!({"error": format!("model '{id}' not loaded on this neuron")})),
)
.into_response(),
Err(InferenceError::PromptTooLong { prompt_len, max }) => (
StatusCode::BAD_REQUEST,
Json(json!({
"error": format!("prompt has {prompt_len} tokens but max is {max}"),
"code": "prompt_too_long",
"prompt_len": prompt_len,
"max": max,
})),
)
.into_response(),
Err(InferenceError::InsufficientVram {
free_mb,
required_mb,
}) => (
StatusCode::SERVICE_UNAVAILABLE,
Json(json!({
"error": format!(
"insufficient free VRAM: {free_mb} MiB free, need at least {required_mb} MiB"
),
"code": "insufficient_vram",
"free_mb": free_mb,
"required_mb": required_mb,
})),
)
.into_response(),
Err(InferenceError::VisionUnsupported { model_id }) => (
StatusCode::BAD_REQUEST,
Json(json!({
"error": format!(
"model '{model_id}' does not support image input"
),
"code": "vision_unsupported",
"model_id": model_id,
"suggestion": "load a vision-capable model or remove image_url content parts",
})),
)
.into_response(),
Err(InferenceError::Other(e)) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("{e:#}")})),
)
.into_response(),
}
}
}
/// OpenAI Responses API (`POST /v1/responses`). Translates the
/// Responses-shaped request into a chat-completions one the candle
/// harness already understands, then re-projects the harness's
/// output into the Responses shape (an SSE event stream when
/// `stream` is set, a single JSON response otherwise).
///
/// Error responses:
/// - `503 Service Unavailable` when the candle harness is not enabled.
/// - `400 Bad Request` with code `chained_conversation_not_supported`
///   when the request cannot be translated.
/// - whatever [`inference_error_response`] maps an [`InferenceError`] to.
///
/// The response id and message item id are minted once per request and
/// used by whichever path (streaming or not) ends up serving it.
async fn responses(
    State(state): State<Arc<NeuronState>>,
    Json(req): Json<ResponsesRequest>,
) -> impl IntoResponse {
    let Some(candle) = state.candle.as_ref().map(Arc::clone) else {
        return (
            StatusCode::SERVICE_UNAVAILABLE,
            Json(json!({"error": "candle harness not enabled on this neuron"})),
        )
            .into_response();
    };

    // Captured up front because `request_to_chat` consumes `req`.
    let stream_requested = req.stream;
    let model_id = req.model.clone();
    let response_id = mint_response_id();
    let message_item_id = mint_message_item_id();

    // Translate Responses → chat completions. The only failure
    // mode today is `previous_response_id` set, which we reject
    // with 400 — stateful conversations need a persistence layer
    // we haven't built.
    let mut chat_req = match openai_responses::request_to_chat(req) {
        Ok(r) => r,
        Err(openai_responses::TranslateError::ChainedConversationNotSupported) => {
            return (
                StatusCode::BAD_REQUEST,
                Json(json!({
                    "error": "previous_response_id is not supported on this neuron",
                    "code": "chained_conversation_not_supported"
                })),
            )
                .into_response();
        }
    };
    chat_req.stream = Some(stream_requested);

    if stream_requested {
        match candle
            .responses_stream(chat_req, response_id, message_item_id)
            .await
        {
            Ok(rx) => {
                // Each ResponseStreamFrame → one SSE event carrying
                // both an event name and JSON data. The Responses
                // API doesn't use a `[DONE]` terminator — clients
                // see the `response.completed` event as the end of
                // the stream.
                let body_stream = ReceiverStream::new(rx).map(|frame| {
                    let body = serde_json::to_string(&frame.data).unwrap_or_else(|_| "{}".into());
                    Ok::<_, Infallible>(Event::default().event(frame.event_name).data(body))
                });
                Sse::new(body_stream)
                    .keep_alive(KeepAlive::default())
                    .into_response()
            }
            Err(e) => inference_error_response(e),
        }
    } else {
        // Non-streaming: drive the existing chat completion path
        // and translate the result. We don't currently re-tokenise
        // to compute usage; the harness returns it via the chat
        // response and we pass it through.
        match candle.chat_completion(chat_req).await {
            Ok(chat_resp) => {
                // Extract the assistant text (chat completions
                // always emits one choice on the candle path).
                let text = chat_resp
                    .choices
                    .first()
                    .map(|c| match &c.message.content {
                        MessageContent::Text(t) => t.clone(),
                        // Candle output is always text today; a Parts
                        // response would be surprising. Empty-string
                        // fallback is safer than a panic.
                        MessageContent::Parts(_) => String::new(),
                    })
                    .unwrap_or_default();
                // A missing or unrecognised finish reason is reported as `Stop`.
                let finish = chat_resp
                    .choices
                    .first()
                    .and_then(|c| c.finish_reason.as_deref())
                    .map(finish_reason_from_str)
                    .unwrap_or(crate::wire::FinishReason::Stop);
                let usage = chat_resp.usage.as_ref().map(|u| ResponsesUsage {
                    input_tokens: u.prompt_tokens,
                    output_tokens: u.completion_tokens,
                    total_tokens: u.prompt_tokens + u.completion_tokens,
                });
                // Reuse the ids minted at the top of the handler: the
                // streaming branch did not run, so they are still owned here.
                let meta = openai_responses::ResponseMeta {
                    response_id,
                    created_at: unix_now_secs(),
                    model_id,
                    message_item_id,
                };
                let resp = openai_responses::build_response(&meta, text, finish, usage);
                Json(resp).into_response()
            }
            Err(e) => inference_error_response(e),
        }
    }
}
fn finish_reason_from_str(s: &str) -> crate::wire::FinishReason {
use crate::wire::FinishReason;
match s {
"length" => FinishReason::Length,
"tool_calls" => FinishReason::ToolCalls,
_ => FinishReason::Stop,
}
}
/// Centralised mapping from [`InferenceError`] to an HTTP response.
/// Lifted out so the chat-completions and responses handlers stay
/// readable and changes to error-code semantics happen in one spot.
fn inference_error_response(err: InferenceError) -> axum::response::Response {
match err {
InferenceError::ModelNotLoaded(id) => (
StatusCode::NOT_FOUND,
Json(json!({"error": format!("model '{id}' not loaded on this neuron")})),
)
.into_response(),
InferenceError::PromptTooLong { prompt_len, max } => (
StatusCode::BAD_REQUEST,
Json(json!({
"error": format!("prompt has {prompt_len} tokens but max is {max}"),
"code": "prompt_too_long",
"prompt_len": prompt_len,
"max": max,
})),
)
.into_response(),
InferenceError::InsufficientVram {
free_mb,
required_mb,
} => (
StatusCode::SERVICE_UNAVAILABLE,
Json(json!({
"error": format!(
"insufficient free VRAM: {free_mb} MiB free, need at least {required_mb} MiB"
),
"code": "insufficient_vram",
"free_mb": free_mb,
"required_mb": required_mb,
})),
)
.into_response(),
InferenceError::VisionUnsupported { model_id } => (
StatusCode::BAD_REQUEST,
Json(json!({
"error": format!(
"model '{model_id}' does not support image input"
),
"code": "vision_unsupported",
"model_id": model_id,
"suggestion": "load a vision-capable model or remove image_url content parts",
})),
)
.into_response(),
InferenceError::Other(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("{e:#}")})),
)
.into_response(),
}
}
/// Mints a response id: the prefix `resp_` followed by the lowercase-hex
/// value of [`unix_subsec_nanos`]. Ids are derived from the clock only.
fn mint_response_id() -> String {
    let stamp = unix_subsec_nanos();
    format!("resp_{stamp:x}")
}
/// Mints a message item id: the prefix `msg_` followed by the lowercase-hex
/// value of [`unix_subsec_nanos`]. Ids are derived from the clock only.
fn mint_message_item_id() -> String {
    let stamp = unix_subsec_nanos();
    format!("msg_{stamp:x}")
}
/// Whole seconds elapsed since the Unix epoch according to the system
/// clock. Returns `0` if the clock reports a time before the epoch.
fn unix_now_secs() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
/// Nanoseconds elapsed since the Unix epoch, narrowed to `u64`.
///
/// Despite the name this is the *total* nanosecond count
/// (`Duration::as_nanos`), not just the sub-second part. Returns `0` if
/// the system clock reports a time before the epoch.
fn unix_subsec_nanos() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        // `as_nanos` yields u128; the cast keeps the low 64 bits.
        Ok(elapsed) => elapsed.as_nanos() as u64,
        Err(_) => 0,
    }
}

View File

@@ -1,22 +1,12 @@
//! Neuron configuration loaded from neuron.toml.
use cortex_core::harness::{HarnessConfig, ModelSpec};
use cortex_core::harness::HarnessConfig;
use figment::{
Figment,
providers::{Env, Format, Toml},
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
/// Default scheme name applied to bare `org/name` model ids when no
/// `[harness.candle.default_source]` is set. Keeps existing operator
/// configs (which know nothing about schemes) working unchanged.
///
/// Also the key under which `CandleHarnessConfig::effective_sources`
/// synthesises its fallback source entry.
pub const DEFAULT_SOURCE_SCHEME: &str = "huggingface";
/// Endpoint URL for the default huggingface source, used when no
/// `[harness.candle.sources.huggingface]` is configured.
pub const DEFAULT_HF_ENDPOINT: &str = "https://huggingface.co";
use std::path::Path;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NeuronConfig {
@@ -24,156 +14,6 @@ pub struct NeuronConfig {
pub port: u16,
#[serde(default)]
pub harnesses: Vec<HarnessConfig>,
/// Per-harness configuration. Currently only `candle` is recognised.
#[serde(default)]
pub harness: HarnessSettings,
/// Models to auto-load when the neuron service activates. Each entry
/// is loaded sequentially before the HTTP listener binds. A failure
/// on any single entry logs a warning and proceeds — broken entries
/// don't prevent the rest of the fleet from starting.
#[serde(default)]
pub default_models: Vec<ModelSpec>,
}
/// Settings for individual harness implementations. Each harness owns
/// its own sub-table so users only configure the harnesses they enable.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct HarnessSettings {
/// Settings for the candle harness. Falls back to
/// `CandleHarnessConfig::default()` when the sub-table is omitted.
#[serde(default)]
pub candle: CandleHarnessConfig,
}
/// Configuration for the candle harness: where model weights are cached,
/// which source registries model ids resolve against, and the prefix KV
/// cache settings. Every field is optional in the config file
/// (`#[serde(default)]`).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CandleHarnessConfig {
/// HuggingFace cache directory for model weights.
/// When unset, defers to hf-hub's default (~/.cache/huggingface).
///
/// Retained for back-compat — operators with existing
/// `hf_cache = "..."` configs continue to work. Treated as the
/// `huggingface` source's cache_dir when a sources table isn't
/// provided.
#[serde(default)]
pub hf_cache: Option<PathBuf>,
/// Default source scheme applied to bare `org/name` model ids
/// (those without an explicit `scheme:` prefix). When unset, falls
/// back to `DEFAULT_SOURCE_SCHEME` ("huggingface").
#[serde(default)]
pub default_source: Option<String>,
/// Per-scheme source endpoints. Each entry maps a scheme name
/// (`huggingface`, `helexa`, an operator's mirror tag, …) to its
/// endpoint URL, optional auth env var, and optional cache
/// directory.
///
/// When absent or missing the `huggingface` key, the loader
/// synthesises a `huggingface` entry pointing at
/// `https://huggingface.co` with `hf_cache` (above) as its
/// cache_dir. This keeps single-source configs ergonomic.
#[serde(default)]
pub sources: HashMap<String, SourceConfig>,
/// Prefix KV cache across requests (#11). Applies per loaded
/// model, on architectures that support cache snapshots (qwen3_5).
#[serde(default)]
pub prefix_cache: PrefixCacheConfig,
}
/// `[harness.candle.prefix_cache]` settings.
///
/// Omitted fields take their values from the `default_prefix_cache_*`
/// functions: enabled, a 1024 MiB budget, and at most 8 entries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrefixCacheConfig {
/// Master switch. On by default — set `false` to restore the
/// clear-every-request behaviour.
#[serde(default = "default_prefix_cache_enabled")]
pub enabled: bool,
/// Snapshot byte budget per loaded model, in MiB. Snapshots live
/// on the model's device, so this comes out of the same VRAM that
/// serves inference — size it against the device's headroom after
/// the model weights.
#[serde(default = "default_prefix_cache_budget_mb")]
pub budget_mb: u64,
/// Maximum live snapshots per loaded model, regardless of budget.
#[serde(default = "default_prefix_cache_max_entries")]
pub max_entries: usize,
}
/// Builds the defaults from the same `default_prefix_cache_*` functions
/// serde uses for omitted fields, so a missing table and an empty table
/// produce the same configuration.
impl Default for PrefixCacheConfig {
    fn default() -> Self {
        let enabled = default_prefix_cache_enabled();
        let budget_mb = default_prefix_cache_budget_mb();
        let max_entries = default_prefix_cache_max_entries();
        Self {
            enabled,
            budget_mb,
            max_entries,
        }
    }
}
/// Serde default for `PrefixCacheConfig::enabled`: the prefix cache is on
/// unless the operator turns it off.
fn default_prefix_cache_enabled() -> bool {
true
}
/// Serde default for `PrefixCacheConfig::budget_mb`, in MiB.
fn default_prefix_cache_budget_mb() -> u64 {
1024
}
/// Serde default for `PrefixCacheConfig::max_entries`.
fn default_prefix_cache_max_entries() -> usize {
8
}
/// Per-scheme source configuration. Mirrors the shape `hf_hub::ApiBuilder`
/// needs: endpoint URL, optional auth token (read from an env var so
/// secrets stay out of the config file), and optional cache directory
/// disambiguated per source to prevent mirror-vs-canonical collisions.
///
/// `endpoint` has no serde default, so it must be present in every
/// configured source table; the `Default` derive leaves it as an empty
/// string.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SourceConfig {
/// Base URL of the registry. Must speak the HF-compatible wire
/// format (siblings listing at
/// `/api/models/{org}/{name}[/revision/{rev}]`, blob fetch at
/// `/{org}/{name}/resolve/{rev}/{path}`).
pub endpoint: String,
/// Environment variable name to read for the bearer token used
/// against this source. `None` = anonymous. Reading from env
/// (vs. literal token in the config) keeps secrets out of TOML.
#[serde(default)]
pub auth_env: Option<String>,
/// Cache directory for this source. The hf-hub
/// `models--{org}--{name}/snapshots/...` tree lives directly
/// under this path, so distinct sources serving the same
/// `org/name` cannot collide on disk.
///
/// `None` means "share the harness `hf_cache` directory" — only
/// safe when the operator has exactly one source configured.
#[serde(default)]
pub cache_dir: Option<PathBuf>,
}
impl CandleHarnessConfig {
    /// Returns the sources map the loader should actually use.
    ///
    /// Starts from a copy of the operator-supplied `sources` and, if it has
    /// no entry for [`DEFAULT_SOURCE_SCHEME`], adds one that points at
    /// [`DEFAULT_HF_ENDPOINT`], authenticates via the `HF_TOKEN` environment
    /// variable, and caches into the legacy `hf_cache` directory. An entry
    /// the operator wrote explicitly is never overwritten.
    ///
    /// `self` is left untouched so the config as typed by the operator can
    /// still be serialized back to TOML for diagnostics; calling this
    /// repeatedly yields the same map.
    pub fn effective_sources(&self) -> HashMap<String, SourceConfig> {
        let mut resolved = self.sources.clone();
        if let std::collections::hash_map::Entry::Vacant(slot) =
            resolved.entry(DEFAULT_SOURCE_SCHEME.to_string())
        {
            slot.insert(SourceConfig {
                endpoint: DEFAULT_HF_ENDPOINT.to_string(),
                auth_env: Some("HF_TOKEN".to_string()),
                cache_dir: self.hf_cache.clone(),
            });
        }
        resolved
    }

    /// Scheme applied to model ids that carry no explicit scheme: the
    /// operator's `default_source` when set, otherwise
    /// [`DEFAULT_SOURCE_SCHEME`].
    pub fn effective_default_source(&self) -> &str {
        match self.default_source.as_deref() {
            Some(scheme) => scheme,
            None => DEFAULT_SOURCE_SCHEME,
        }
    }
}
fn default_port() -> u16 {
@@ -195,114 +35,6 @@ impl Default for NeuronConfig {
Self {
port: 13131,
harnesses: vec![],
harness: HarnessSettings::default(),
default_models: vec![],
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fully-populated [`SourceConfig`] fixture.
    fn source(endpoint: &str, auth_env: &str, cache_dir: &str) -> SourceConfig {
        SourceConfig {
            endpoint: endpoint.into(),
            auth_env: Some(auth_env.into()),
            cache_dir: Some(PathBuf::from(cache_dir)),
        }
    }

    /// Builds a candle config with a legacy `hf_cache` plus one explicit
    /// entry in the sources table.
    fn config_with_source(hf_cache: &str, scheme: &str, entry: SourceConfig) -> CandleHarnessConfig {
        let mut sources = HashMap::new();
        sources.insert(scheme.to_string(), entry);
        CandleHarnessConfig {
            hf_cache: Some(PathBuf::from(hf_cache)),
            sources,
            ..Default::default()
        }
    }

    #[test]
    fn effective_sources_synthesises_huggingface_when_absent() {
        let effective = CandleHarnessConfig::default().effective_sources();
        assert!(effective.contains_key("huggingface"));

        let synth = &effective["huggingface"];
        assert_eq!(synth.endpoint, DEFAULT_HF_ENDPOINT);
        assert_eq!(synth.auth_env.as_deref(), Some("HF_TOKEN"));
        assert!(synth.cache_dir.is_none());
    }

    #[test]
    fn effective_sources_carries_legacy_hf_cache_into_synth_entry() {
        // Existing operator configs only set `hf_cache = "/archive3/..."`
        // — the synthesised entry must pick that up so the loader keeps
        // using the operator's storage.
        let legacy_only = CandleHarnessConfig {
            hf_cache: Some(PathBuf::from("/archive3/llm-cache")),
            ..Default::default()
        };
        let effective = legacy_only.effective_sources();
        assert_eq!(
            effective["huggingface"].cache_dir.as_deref(),
            Some(Path::new("/archive3/llm-cache"))
        );
    }

    #[test]
    fn effective_sources_preserves_explicit_huggingface_entry() {
        // When an operator types out `[harness.candle.sources.huggingface]`
        // explicitly, we must not clobber it with the synth defaults.
        let cfg = config_with_source(
            "/legacy-cache",
            "huggingface",
            source("https://huggingface.example.org", "MY_TOKEN", "/operator-cache"),
        );
        let effective = cfg.effective_sources();
        let hf = &effective["huggingface"];
        assert_eq!(hf.endpoint, "https://huggingface.example.org");
        assert_eq!(hf.auth_env.as_deref(), Some("MY_TOKEN"));
        assert_eq!(hf.cache_dir.as_deref(), Some(Path::new("/operator-cache")));
    }

    #[test]
    fn effective_sources_includes_helexa_alongside_synth_huggingface() {
        let cfg = config_with_source(
            "/archive3/llm-cache/huggingface",
            "helexa",
            source(
                "https://registry.helexa.ai",
                "HELEXA_TOKEN",
                "/archive3/llm-cache/helexa",
            ),
        );
        let effective = cfg.effective_sources();
        assert_eq!(effective.len(), 2);
        assert_eq!(effective["helexa"].endpoint, "https://registry.helexa.ai");
        // huggingface still gets synth-derived from legacy hf_cache.
        assert_eq!(
            effective["huggingface"].cache_dir.as_deref(),
            Some(Path::new("/archive3/llm-cache/huggingface"))
        );
    }

    #[test]
    fn effective_default_source_falls_back() {
        let unpinned = CandleHarnessConfig::default();
        assert_eq!(unpinned.effective_default_source(), DEFAULT_SOURCE_SCHEME);
    }

    #[test]
    fn effective_default_source_honours_explicit() {
        let pinned = CandleHarnessConfig {
            default_source: Some("helexa".into()),
            ..Default::default()
        };
        assert_eq!(pinned.effective_default_source(), "helexa");
    }
}

View File

@@ -1,84 +0,0 @@
//! FFI declarations for the CUDA kernels in `gdn.cu`.
//!
//! Subset of `EricLBuehler/mistral.rs::mistralrs-core/src/cuda/ffi.rs`
//! covering only the Gated DeltaNet kernels we currently use. Other
//! kernels in the upstream file (MoE GEMM, top-k, Mamba selective
//! scan, etc.) would land here too as we absorb them.
//!
//! All function declarations are MIT-licensed from upstream and
//! unchanged apart from this header.
use std::ffi::c_void;
// Declarations for the `extern "C"` launchers defined in `gdn.cu`.
//
// Common conventions, as far as the CUDA side shows:
// - `stream` is a CUDA stream handle passed as an integer; the launchers
//   cast it back to `cudaStream_t`.
// - Pointers are handed straight to kernel launches, so they must be
//   device-accessible and large enough for the documented layouts.
// - The launchers return nothing, so launch errors are not reported here.
#[allow(dead_code)]
unsafe extern "C" {
// GDN (Gated Delta Net) kernels for qwen3_5 / Qwen3-Next.
/// Gated delta-rule recurrence, one timestep at a time.
///
/// Layouts per the kernel banner in `gdn.cu`: `q`, `k`: `[BH, S, K]`;
/// `v`: `[BH, S, V]`; `g`, `beta`: `[BH, S]`; `state`: `[BH, K, V]`
/// (read and written back); `output`: `[BH, S, V]`. `bh` is the number of
/// batch*head slices, `seq_len` is `S`, `k_dim` is `K`, `v_dim` is `V`.
///
/// `k_dim` of 128 or 64 takes a compile-time-specialised kernel; other
/// values use a fallback whose per-thread state array holds 256 entries.
/// NOTE(review): nothing visible rejects `k_dim > 256` — confirm callers
/// never pass it.
pub(crate) fn gated_delta_rule_recurrence(
q: *const f32,
k: *const f32,
v: *const f32,
g: *const f32,
beta: *const f32,
state: *mut f32,
output: *mut f32,
bh: i32,
seq_len: i32,
k_dim: i32,
v_dim: i32,
stream: i64,
);
/// Chunked GDN recurrence for prefill (processes tokens in BT=64 chunks).
///
/// Same arguments and layouts as [`gated_delta_rule_recurrence`]. For
/// `k_dim` other than 128 or 64 the launcher forwards to the sequential
/// recurrence instead.
pub(crate) fn chunked_gated_delta_rule_recurrence(
q: *const f32,
k: *const f32,
v: *const f32,
g: *const f32,
beta: *const f32,
state: *mut f32,
output: *mut f32,
bh: i32,
seq_len: i32,
k_dim: i32,
v_dim: i32,
stream: i64,
);
/// Single-step causal conv1d for the decode path.
///
/// Per the kernel banner in `gdn.cu`: shifts `conv_state` left by one,
/// inserts the new value, takes the dot product with `weight`, and
/// applies SiLU. Layouts: `x`: `[B, conv_dim, 1]`; `weight`:
/// `[conv_dim, kernel_size]`; `conv_state`: `[B, conv_dim, kernel_size]`
/// (read and written back); `output`: `[B, conv_dim, 1]`.
///
/// NOTE(review): `dtype` selects the element type behind the `c_void`
/// pointers; the encoding is defined in `gdn.cu` — confirm before use.
pub(crate) fn causal_conv1d_update(
x: *const c_void,
weight: *const c_void,
conv_state: *mut c_void,
output: *mut c_void,
batch_size: i32,
conv_dim: i32,
kernel_size: i32,
dtype: i32,
stream: i64,
);
/// Multi-token causal conv1d for the prefill path.
///
/// NOTE(review): presumably the `seq_len`-token counterpart of
/// [`causal_conv1d_update`], writing the trailing window to
/// `conv_state_out`; tensor layouts and the `dtype` encoding are defined
/// in `gdn.cu` — confirm before use.
pub(crate) fn causal_conv1d_full(
x: *const c_void,
weight: *const c_void,
conv_state_out: *mut c_void,
output: *mut c_void,
batch_size: i32,
conv_dim: i32,
seq_len: i32,
kernel_size: i32,
dtype: i32,
stream: i64,
);
/// Fused GDN gating. Per the `gdn.cu` banner it computes
/// `beta = sigmoid(b)` and `g = -exp(A_log) * softplus(a + dt_bias)`,
/// writing them to `beta_out` and `g_out`.
///
/// NOTE(review): tensor layouts and the `dtype` encoding for the `c_void`
/// pointers are defined in `gdn.cu` — confirm before use.
pub(crate) fn fused_gdn_gating(
b: *const c_void,
a: *const c_void,
a_log: *const f32,
dt_bias: *const f32,
beta_out: *mut c_void,
g_out: *mut c_void,
total_elements: i32,
num_heads: i32,
dtype: i32,
stream: i64,
);
}

View File

@@ -1,711 +0,0 @@
// Gated DeltaNet CUDA kernels for Qwen3-Next (`model_type = "qwen3_5"`).
//
// Ported verbatim from `EricLBuehler/mistral.rs` under MIT terms.
// Upstream path: mistralrs-core/src/cuda/gdn.cu. Local edits in this
// file are limited to this banner; the kernels are unchanged so a
// diff against upstream stays minimal.
//
// Five kernels exposed via `extern "C"` shims at the bottom:
// - gated_delta_rule_recurrence (per-token decode)
// - chunked_gated_delta_rule_recurrence (BT=64 chunked prefill)
// - causal_conv1d_update (single-token conv decode)
// - causal_conv1d_full (multi-token conv prefill)
// - fused_gdn_gating (beta = sigmoid(b);
// g = -exp(A_log) * softplus(a + dt_bias))
#include "cuda_bf16.h"
#include "cuda_fp16.h"
#include <cmath>
#include <cstdint>
#include <cuda_runtime.h>
// ============================================================================
// Kernel 1: gated_delta_rule_recurrence (optimized)
//
// V-tiled recurrence with compile-time K dimension for register residency.
// Grid: (ceil(V/BV), B*H), Block: (BV,). Each thread owns BK registers of
// state. Shared memory holds k_buf and q_buf (2*BK floats).
//
// Optimizations over naive version:
// - Template BK -> float s[BK] lives in true registers (1 cycle vs ~30)
// - #pragma unroll on all k-loops -> full ILP
// - Fused decay+kv_mem pass and fused state_update+output pass
// - __fmaf_rn intrinsics for guaranteed fused multiply-add
// - BV=64 threads -> 2 warps, 6 blocks/SM on Ampere
//
// q,k: [BH, S, K] v: [BH, S, V] g,beta: [BH, S]
// state: [BH, K, V] (in/out) output: [BH, S, V]
// ============================================================================
// Optimized kernel: BK known at compile time -> registers + full unrolling
//
// One block handles one (V-tile, batch*head) pair and must be launched with
// exactly BV threads: thread `tid` owns column `v_tile * BV + tid` of the
// [K, V] state, and the shared-memory loads of k_t / q_t stride by BV.
//
// Threads whose column falls past `v_dim` (last tile when v_dim is not a
// multiple of BV) are kept alive instead of returning early. They still take
// part in the cooperative loads and every __syncthreads(), and only their
// global reads/writes of V-indexed data are suppressed. Returning early would
// leave part of k_buf / q_buf unloaded and have a partial block hit the
// barriers.
template <int BK, int BV>
__global__ void gated_delta_rule_recurrence_kernel_tiled(
    const float *__restrict__ q,    // [BH, S, K]
    const float *__restrict__ k,    // [BH, S, K]
    const float *__restrict__ v,    // [BH, S, V]
    const float *__restrict__ g,    // [BH, S]
    const float *__restrict__ beta, // [BH, S]
    float *__restrict__ state,      // [BH, K, V]
    float *__restrict__ output,     // [BH, S, V]
    int seq_len, int v_dim) {
  const int v_tile = blockIdx.x;       // which V-tile
  const int bh = blockIdx.y;           // batch*head index
  const int tid = threadIdx.x;         // thread within tile [0, BV)
  const int v_idx = v_tile * BV + tid; // global V index
  const bool active = v_idx < v_dim;   // false only in a partial last tile

  // Pointers for this (batch, head)
  const float *q_bh = q + bh * seq_len * BK;
  const float *k_bh = k + bh * seq_len * BK;
  const float *v_bh = v + bh * seq_len * v_dim;
  const float *g_bh = g + bh * seq_len;
  const float *beta_bh = beta + bh * seq_len;
  float *state_bh = state + bh * BK * v_dim;
  float *out_bh = output + bh * seq_len * v_dim;

  // Shared memory: k_buf[BK] + q_buf[BK]
  __shared__ float k_buf[BK];
  __shared__ float q_buf[BK];

  // Load state column into a per-thread array; BK is compile-time so the
  // loops below can be fully unrolled. Inactive threads work on zeros.
  float s[BK];
#pragma unroll
  for (int j = 0; j < BK; j++) {
    s[j] = active ? state_bh[j * v_dim + v_idx] : 0.0f;
  }

  for (int t = 0; t < seq_len; t++) {
    // Collaboratively load k_t into shared memory
    // BK / BV loads per thread (e.g. 128/64 = 2)
#pragma unroll
    for (int j = tid; j < BK; j += BV) {
      k_buf[j] = k_bh[t * BK + j];
    }
    __syncthreads();

    // Load scalars for this timestep
    float decay = expf(g_bh[t]);
    float beta_t = beta_bh[t];
    float v_t = active ? v_bh[t * v_dim + v_idx] : 0.0f;

    // Fused pass 1: decay state + compute kv_mem
    float kv_mem = 0.0f;
#pragma unroll
    for (int j = 0; j < BK; j++) {
      s[j] *= decay;
      kv_mem = __fmaf_rn(s[j], k_buf[j], kv_mem);
    }

    // Delta rule
    float delta = (v_t - kv_mem) * beta_t;

    // Collaboratively load q_t into shared memory
#pragma unroll
    for (int j = tid; j < BK; j += BV) {
      q_buf[j] = q_bh[t * BK + j];
    }
    __syncthreads();

    // Fused pass 2: update state + compute output
    float y_t = 0.0f;
#pragma unroll
    for (int j = 0; j < BK; j++) {
      s[j] = __fmaf_rn(k_buf[j], delta, s[j]);
      y_t = __fmaf_rn(s[j], q_buf[j], y_t);
    }

    if (active) {
      out_bh[t * v_dim + v_idx] = y_t;
    }
    // Keep k_buf / q_buf intact until every thread is done with this step.
    __syncthreads();
  }

  // Write state back
  if (active) {
#pragma unroll
    for (int j = 0; j < BK; j++) {
      state_bh[j * v_dim + v_idx] = s[j];
    }
  }
}
// Fallback kernel: runtime k_dim, still V-tiled for occupancy
//
// Same algorithm as the tiled kernel, but K is a runtime value: the k_t / q_t
// staging buffers live in dynamic shared memory (2 * k_dim floats, supplied
// by the launcher) and the per-thread state array is sized by MAX_K.
// NOTE(review): nothing here rejects k_dim > MAX_K, which would index past
// `s` — confirm callers never pass it.
//
// Must be launched with exactly BV threads per block. Threads whose column
// falls past `v_dim` stay alive so that they still take part in the
// cooperative shared-memory loads and every __syncthreads(); only their
// V-indexed global reads/writes are suppressed.
template <int BV, int MAX_K>
__global__ void gated_delta_rule_recurrence_kernel_fallback(
    const float *__restrict__ q, const float *__restrict__ k,
    const float *__restrict__ v, const float *__restrict__ g,
    const float *__restrict__ beta, float *__restrict__ state,
    float *__restrict__ output, int seq_len, int k_dim, int v_dim) {
  const int v_tile = blockIdx.x;
  const int bh = blockIdx.y;
  const int tid = threadIdx.x;
  const int v_idx = v_tile * BV + tid;
  const bool active = v_idx < v_dim; // false only in a partial last tile

  const float *q_bh = q + bh * seq_len * k_dim;
  const float *k_bh = k + bh * seq_len * k_dim;
  const float *v_bh = v + bh * seq_len * v_dim;
  const float *g_bh = g + bh * seq_len;
  const float *beta_bh = beta + bh * seq_len;
  float *state_bh = state + bh * k_dim * v_dim;
  float *out_bh = output + bh * seq_len * v_dim;

  // Dynamic shared memory: k_buf[k_dim] followed by q_buf[k_dim].
  extern __shared__ float shared[];
  float *k_buf = shared;
  float *q_buf = shared + k_dim;

  // Per-thread copy of this column of the state; inactive threads use zeros.
  float s[MAX_K];
  for (int j = 0; j < k_dim; j++) {
    s[j] = active ? state_bh[j * v_dim + v_idx] : 0.0f;
  }

  for (int t = 0; t < seq_len; t++) {
    // Cooperative load of k_t.
    for (int j = tid; j < k_dim; j += BV) {
      k_buf[j] = k_bh[t * k_dim + j];
    }
    __syncthreads();

    float decay = expf(g_bh[t]);
    float beta_t = beta_bh[t];
    float v_t = active ? v_bh[t * v_dim + v_idx] : 0.0f;

    // Pass 1: decay the state and accumulate state . k_t.
    float kv_mem = 0.0f;
    for (int j = 0; j < k_dim; j++) {
      s[j] *= decay;
      kv_mem = __fmaf_rn(s[j], k_buf[j], kv_mem);
    }

    float delta = (v_t - kv_mem) * beta_t;

    // Cooperative load of q_t.
    for (int j = tid; j < k_dim; j += BV) {
      q_buf[j] = q_bh[t * k_dim + j];
    }
    __syncthreads();

    // Pass 2: apply the delta update and accumulate state . q_t.
    float y_t = 0.0f;
    for (int j = 0; j < k_dim; j++) {
      s[j] = __fmaf_rn(k_buf[j], delta, s[j]);
      y_t = __fmaf_rn(s[j], q_buf[j], y_t);
    }

    if (active) {
      out_bh[t * v_dim + v_idx] = y_t;
    }
    // Keep k_buf / q_buf intact until every thread is done with this step.
    __syncthreads();
  }

  if (active) {
    for (int j = 0; j < k_dim; j++) {
      state_bh[j * v_dim + v_idx] = s[j];
    }
  }
}
// Host-side launcher for the sequential gated delta-rule recurrence.
//
// Picks a kernel by `k_dim`: 128 and 64 use the compile-time-specialised
// tiled kernel, anything else uses the runtime-K fallback with 2 * k_dim
// floats of dynamic shared memory. Every variant uses 64 threads per block
// and a grid of (ceil(v_dim / 64), bh). `stream` is a cudaStream_t passed as
// an integer. The launch is asynchronous and no error status is returned.
// NOTE(review): the fallback's per-thread state array holds 256 entries and
// k_dim is not checked against that here.
extern "C" void gated_delta_rule_recurrence(const float *q, const float *k,
                                            const float *v, const float *g,
                                            const float *beta, float *state,
                                            float *output, int bh, int seq_len,
                                            int k_dim, int v_dim,
                                            int64_t stream) {
  constexpr int kTileV = 64; // threads per block == V columns per tile
  const cudaStream_t cu_stream = (cudaStream_t)stream;
  const dim3 blocks((v_dim + kTileV - 1) / kTileV, bh);
  const dim3 threads(kTileV);

  switch (k_dim) {
  case 128:
    // Fast path for Qwen3-Next (k_dim=128)
    gated_delta_rule_recurrence_kernel_tiled<128, kTileV>
        <<<blocks, threads, 0, cu_stream>>>(q, k, v, g, beta, state, output,
                                            seq_len, v_dim);
    break;
  case 64:
    // Fast path for models with k_dim=64
    gated_delta_rule_recurrence_kernel_tiled<64, kTileV>
        <<<blocks, threads, 0, cu_stream>>>(q, k, v, g, beta, state, output,
                                            seq_len, v_dim);
    break;
  default: {
    // Fallback for other k_dim values (runtime loop, still V-tiled)
    constexpr int kMaxK = 256;
    const size_t shared_bytes = 2 * k_dim * sizeof(float);
    gated_delta_rule_recurrence_kernel_fallback<kTileV, kMaxK>
        <<<blocks, threads, shared_bytes, cu_stream>>>(
            q, k, v, g, beta, state, output, seq_len, k_dim, v_dim);
    break;
  }
  }
}
// ============================================================================
// Kernel 1b: chunked_gated_delta_rule_recurrence (prefill optimization)
//
// Processes prefill tokens in BT-token chunks instead of one at a time.
// Within each chunk: parallel prefix sum of g, cooperative kk_dot computation,
// forward substitution (triangular solve), output computation, and state
// update.
//
// Same thread model as Kernel 1: one block per (v_tile, batch*head),
// one thread per V-column. Each thread owns BK registers of state.
//
// Shared memory holds:
// k_chunk[BT * BK] -- key vectors for current chunk
// kk_dot[BT * BT] -- dot(k[i], k[j]) lower-triangular matrix
// gcum[BT] -- cumulative sum of g within chunk
// beta_s[BT] -- beta values for chunk
// q_buf[BK] -- q vector (loaded one row at a time)
//
// q,k: [BH, S, K] v: [BH, S, V] g,beta: [BH, S]
// state: [BH, K, V] (in/out) output: [BH, S, V]
// ============================================================================
// Chunked gated delta-rule kernel: walks the sequence in chunks of up to BT
// tokens, keeping this thread's column of the [K, V] state in `s`.
//
// Launch contract: exactly BV threads per block, and dynamic shared memory of
// (BT*BK + BT*BT + 2*BT + BK) floats laid out as described below.
//
// Threads whose column falls past `v_dim` (last tile when v_dim is not a
// multiple of BV) are kept alive instead of returning early. The loads of
// k/beta/g, the prefix scan of g, the kk_dot computation and the q loads are
// all cooperative across the block and separated by __syncthreads(), so every
// thread has to take part. Only V-indexed global reads/writes are suppressed
// for those threads.
template <int BT, int BK, int BV>
__global__ void
chunked_gated_delta_rule_kernel(const float *__restrict__ q,    // [BH, S, K]
                                const float *__restrict__ k,    // [BH, S, K]
                                const float *__restrict__ v,    // [BH, S, V]
                                const float *__restrict__ g,    // [BH, S]
                                const float *__restrict__ beta, // [BH, S]
                                float *__restrict__ state,      // [BH, K, V]
                                float *__restrict__ output,     // [BH, S, V]
                                int seq_len, int v_dim) {
  // Phase 1 loads beta/g and Phase 1b scans gcum with one thread per chunk
  // row, so a block of BV threads must cover all BT rows.
  static_assert(BV >= BT, "chunked GDN kernel needs one thread per chunk row");

  const int v_tile = blockIdx.x;
  const int bh = blockIdx.y;
  const int tid = threadIdx.x;
  const int v_idx = v_tile * BV + tid;
  const bool active = v_idx < v_dim; // false only in a partial last tile

  const int num_chunks = (seq_len + BT - 1) / BT;

  // Pointers for this (batch, head)
  const float *q_bh = q + bh * seq_len * BK;
  const float *k_bh = k + bh * seq_len * BK;
  const float *v_bh = v + bh * seq_len * v_dim;
  const float *g_bh = g + bh * seq_len;
  const float *beta_bh = beta + bh * seq_len;
  float *state_bh = state + bh * BK * v_dim;
  float *out_bh = output + bh * seq_len * v_dim;

  // Dynamic shared memory layout
  extern __shared__ float smem[];
  float *k_chunk = smem;                  // [BT * BK]
  float *kk_dot = smem + BT * BK;         // [BT * BT]
  float *gcum = smem + BT * BK + BT * BT; // [BT]
  float *beta_s = gcum + BT;              // [BT]
  float *q_buf = beta_s + BT;             // [BK]

  // Load state column into a per-thread array (zeros for inactive threads)
  float s[BK];
#pragma unroll
  for (int j = 0; j < BK; j++) {
    s[j] = active ? state_bh[j * v_dim + v_idx] : 0.0f;
  }

  // Per-thread array for corrected deltas
  float delta[BT];

  for (int c = 0; c < num_chunks; c++) {
    const int chunk_start = c * BT;
    const int chunk_len = min(BT, seq_len - chunk_start);

    // === Phase 1: Cooperative load of k, beta, g into shared memory ===
    for (int t = 0; t < chunk_len; t++) {
      for (int j = tid; j < BK; j += BV) {
        k_chunk[t * BK + j] = k_bh[(chunk_start + t) * BK + j];
      }
    }
    if (tid < chunk_len) {
      beta_s[tid] = beta_bh[chunk_start + tid];
      gcum[tid] = g_bh[chunk_start + tid];
    }
    __syncthreads();

    // === Phase 1b: Parallel prefix sum of g (Hillis-Steele) ===
    // Read and write are split by a barrier so no thread sees a partially
    // updated gcum within one stride.
    for (int stride = 1; stride < BT; stride <<= 1) {
      float prev = 0.0f;
      if (tid < chunk_len && (int)tid >= stride)
        prev = gcum[tid - stride];
      __syncthreads();
      if (tid < chunk_len && (int)tid >= stride)
        gcum[tid] += prev;
      __syncthreads();
    }

    // === Phase 2: Compute kk_dot[i][j] = dot(k[i], k[j]) for j < i ===
    // Only lower-triangular entries needed (strictly lower)
    for (int idx = tid; idx < chunk_len * chunk_len; idx += BV) {
      int i = idx / chunk_len;
      int j = idx % chunk_len;
      if (j < i) {
        float dot = 0.0f;
        for (int d = 0; d < BK; d++) {
          dot = __fmaf_rn(k_chunk[i * BK + d], k_chunk[j * BK + d], dot);
        }
        kk_dot[i * BT + j] = dot;
      }
    }
    __syncthreads();

    // === Phase 3: Forward substitution (per V-column, per thread) ===
    // Computes corrected delta values via triangular solve
    for (int i = 0; i < chunk_len; i++) {
      float v_i = active ? v_bh[(chunk_start + i) * v_dim + v_idx] : 0.0f;
      float decay_i = expf(gcum[i]);
      float beta_i = beta_s[i];

      // Inter-chunk contribution: state @ k[i] with decay
      float kv_mem = 0.0f;
#pragma unroll
      for (int d = 0; d < BK; d++) {
        kv_mem = __fmaf_rn(s[d] * decay_i, k_chunk[i * BK + d], kv_mem);
      }

      float rhs = beta_i * (v_i - kv_mem);

      // Subtract lower-triangular contributions (intra-chunk)
      for (int j = 0; j < i; j++) {
        float a_ij = beta_i * kk_dot[i * BT + j] * expf(gcum[i] - gcum[j]);
        rhs -= a_ij * delta[j];
      }
      delta[i] = rhs;
    }

    // === Phase 4: Output computation (per V-column) ===
    for (int i = 0; i < chunk_len; i++) {
      // Cooperatively load q[i] into shared
      for (int j = tid; j < BK; j += BV) {
        q_buf[j] = q_bh[(chunk_start + i) * BK + j];
      }
      __syncthreads();

      float decay_i = expf(gcum[i]);

      // Inter-chunk contribution: q[i] @ (state * decay)
      float o_val = 0.0f;
#pragma unroll
      for (int d = 0; d < BK; d++) {
        o_val = __fmaf_rn(q_buf[d], s[d] * decay_i, o_val);
      }

      // Intra-chunk contribution: sum_{j<=i} dot(q[i], k[j]) * delta[j] *
      // exp(gcum[i] - gcum[j])
      for (int j = 0; j <= i; j++) {
        float qk_dot = 0.0f;
        for (int d = 0; d < BK; d++) {
          qk_dot = __fmaf_rn(q_buf[d], k_chunk[j * BK + d], qk_dot);
        }
        o_val += qk_dot * delta[j] * expf(gcum[i] - gcum[j]);
      }

      if (active) {
        out_bh[(chunk_start + i) * v_dim + v_idx] = o_val;
      }
      // q_buf is overwritten by the next row; wait for all readers first.
      __syncthreads();
    }

    // === Phase 5: State update for next chunk ===
    float g_total = gcum[chunk_len - 1];
#pragma unroll
    for (int d = 0; d < BK; d++) {
      float s_new = s[d] * expf(g_total);
      for (int t = 0; t < chunk_len; t++) {
        s_new += k_chunk[t * BK + d] * delta[t] * expf(g_total - gcum[t]);
      }
      s[d] = s_new;
    }
    // The next chunk's Phase 1 overwrites k_chunk / gcum / beta_s.
    __syncthreads();
  }

  // Write final state back
  if (active) {
#pragma unroll
    for (int j = 0; j < BK; j++) {
      state_bh[j * v_dim + v_idx] = s[j];
    }
  }
}
// Launch helper for chunked_gated_delta_rule_recurrence: instantiates the
// chunked kernel for one compile-time key dimension BK. The chunk length (BT)
// and the number of V-columns handled per block (BV) are both fixed at 64.
template <int BK>
static void launch_chunked_gated_delta_rule(
    const float *q, const float *k, const float *v, const float *g,
    const float *beta, float *state, float *output, int bh, int seq_len,
    int v_dim, cudaStream_t custream) {
  constexpr int BT = 64;
  constexpr int BV = 64;
  // Dynamic shared memory: BT*BK + BT*BT + BT + BT + BK floats.
  // For BK = 128 this is 50176 bytes, which is above the 48 KiB default
  // dynamic shared memory limit, so the larger size is requested explicitly.
  const size_t smem = (BT * BK + BT * BT + 2 * BT + BK) * sizeof(float);
  auto kernel = chunked_gated_delta_rule_kernel<BT, BK, BV>;
  cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                       smem);
  // One block per (tile of BV value columns, batch*head); one thread per
  // value column inside the tile.
  dim3 grid((v_dim + BV - 1) / BV, bh);
  dim3 block(BV);
  kernel<<<grid, block, smem, custream>>>(q, k, v, g, beta, state, output,
                                          seq_len, v_dim);
}

// Host entry point for the chunked gated delta rule recurrence.
//
// q, k, v, g, beta, state and output are device pointers to f32 data; `state`
// is updated in place and `output` is written by the kernel. `stream` is a
// cudaStream_t handle passed as a 64-bit integer.
//
// k_dim 128 and 64 use the chunked kernel; every other k_dim is forwarded
// unchanged to the sequential gated_delta_rule_recurrence.
extern "C" void chunked_gated_delta_rule_recurrence(
    const float *q, const float *k, const float *v, const float *g,
    const float *beta, float *state, float *output, int bh, int seq_len,
    int k_dim, int v_dim, int64_t stream) {
  if (k_dim != 128 && k_dim != 64) {
    // Fallback: use the sequential kernel for unsupported k_dim
    gated_delta_rule_recurrence(q, k, v, g, beta, state, output, bh, seq_len,
                                k_dim, v_dim, stream);
    return;
  }
  // A grid with a zero-sized dimension is an invalid launch configuration;
  // with no (batch*head, value column) pairs there is nothing to compute.
  if (bh <= 0 || v_dim <= 0) {
    return;
  }
  const cudaStream_t custream = (cudaStream_t)stream;
  if (k_dim == 128) {
    launch_chunked_gated_delta_rule<128>(q, k, v, g, beta, state, output, bh,
                                         seq_len, v_dim, custream);
  } else {
    launch_chunked_gated_delta_rule<64>(q, k, v, g, beta, state, output, bh,
                                        seq_len, v_dim, custream);
  }
}
// ============================================================================
// Kernel 2a: causal_conv1d_update (decode path, single step)
//
// Each thread handles one channel: shift conv_state left by 1,
// insert new value, dot product with weight, apply SiLU.
//
// x: [B, conv_dim, 1] weight: [conv_dim, kernel_size]
// conv_state: [B, conv_dim, kernel_size] (in/out)
// output: [B, conv_dim, 1]
// ============================================================================
// Single-step causal conv1d: one thread owns one (batch, channel) pair.
// The thread slides that channel's state window one slot to the left, appends
// the new sample from x, takes the dot product of the window with the
// channel's weights in f32, and writes SiLU(dot) to output.
template <typename T>
__global__ void causal_conv1d_update_kernel(
    const T *__restrict__ x,      // [B, conv_dim, 1]
    const T *__restrict__ weight, // [conv_dim, kernel_size]
    T *__restrict__ conv_state,   // [B, conv_dim, kernel_size]
    T *__restrict__ output,       // [B, conv_dim, 1]
    int batch_size, int conv_dim, int kernel_size) {
  const int batch = blockIdx.y;
  const int channel = blockIdx.x * blockDim.x + threadIdx.x;
  if (batch >= batch_size || channel >= conv_dim) {
    return;
  }

  // Flat (batch, channel) index shared by x, output and the state rows.
  const int row = batch * conv_dim + channel;
  T *window = conv_state + row * kernel_size;
  const T *taps = weight + channel * kernel_size;
  const int newest = kernel_size - 1;

  // Drop the oldest sample, then append the incoming one.
  for (int t = 0; t < newest; ++t) {
    window[t] = window[t + 1];
  }
  window[newest] = x[row];

  // Accumulate in f32 regardless of the storage type T.
  float sum = 0.0f;
  for (int t = 0; t < kernel_size; ++t) {
    sum += (float)window[t] * (float)taps[t];
  }

  // SiLU: sum * sigmoid(sum)
  const float gate = 1.0f / (1.0f + expf(-sum));
  const float activated = sum * gate;
  output[row] = (T)activated;
}
// Host entry point for the single-step causal conv1d.
// dtype == 0 selects the __half kernel; any other value selects __nv_bfloat16.
// `stream` is a cudaStream_t handle passed as a 64-bit integer.
extern "C" void causal_conv1d_update(const void *x, const void *weight,
                                     void *conv_state, void *output,
                                     int batch_size, int conv_dim,
                                     int kernel_size, int dtype,
                                     int64_t stream) {
  constexpr int kThreads = 256;
  const cudaStream_t custream = (cudaStream_t)stream;
  // x: blocks of 256 channels, y: batch index.
  const dim3 block(kThreads);
  const dim3 grid((conv_dim + (kThreads - 1)) / kThreads, batch_size);

  switch (dtype) {
  case 0: {
    // f16
    const __half *x_h = (const __half *)x;
    const __half *w_h = (const __half *)weight;
    __half *state_h = (__half *)conv_state;
    __half *out_h = (__half *)output;
    causal_conv1d_update_kernel<__half><<<grid, block, 0, custream>>>(
        x_h, w_h, state_h, out_h, batch_size, conv_dim, kernel_size);
    break;
  }
  default: {
    // bf16
    const __nv_bfloat16 *x_b = (const __nv_bfloat16 *)x;
    const __nv_bfloat16 *w_b = (const __nv_bfloat16 *)weight;
    __nv_bfloat16 *state_b = (__nv_bfloat16 *)conv_state;
    __nv_bfloat16 *out_b = (__nv_bfloat16 *)output;
    causal_conv1d_update_kernel<__nv_bfloat16><<<grid, block, 0, custream>>>(
        x_b, w_b, state_b, out_b, batch_size, conv_dim, kernel_size);
    break;
  }
  }
}
// ============================================================================
// Kernel 2b: causal_conv1d_full (prefill path)
//
// Each thread handles one (channel, position): causal window with
// zero-padding, dot product with weight, SiLU.
// A second pass writes the conv_state from the last kernel_size positions.
//
// x: [B, conv_dim, S] weight: [conv_dim, kernel_size]
// conv_state_out: [B, conv_dim, kernel_size] output: [B, conv_dim, S]
// ============================================================================
// Prefill causal conv1d. Each thread owns one (batch, channel) pair and walks
// the sequence positions assigned to its blockIdx.y with a stride of
// gridDim.y. With gridDim.y == seq_len this is exactly one position per
// thread; striding lets the launcher clamp gridDim.y when seq_len is larger
// than the grid's y-dimension limit.
//
// For each position: dot product of the kernel_size-wide window ending at that
// position (positions before 0 contribute zero) with the channel's weights,
// accumulated in f32, followed by SiLU.
template <typename T>
__global__ void causal_conv1d_full_kernel(
    const T *__restrict__ x,      // [B, conv_dim, S]
    const T *__restrict__ weight, // [conv_dim, kernel_size]
    T *__restrict__ output,       // [B, conv_dim, S]
    int batch_size, int conv_dim, int seq_len, int kernel_size) {
  const int ch = blockIdx.x * blockDim.x + threadIdx.x;
  const int b = blockIdx.z;
  if (ch >= conv_dim || b >= batch_size)
    return;

  // 64-bit row offset: (b * conv_dim + ch) * seq_len can exceed the int range
  // for long sequences.
  const long long row = (long long)b * conv_dim + ch;
  const T *x_bch = x + row * seq_len;
  T *out_bch = output + row * seq_len;
  const T *w = weight + ch * kernel_size;
  const int pos_stride = (int)gridDim.y;

  for (int pos = blockIdx.y; pos < seq_len; pos += pos_stride) {
    // Causal convolution: sum over kernel_size window ending at pos
    float acc = 0.0f;
    for (int i = 0; i < kernel_size; i++) {
      int src_pos = pos - (kernel_size - 1) + i;
      float x_val = (src_pos >= 0) ? (float)x_bch[src_pos] : 0.0f;
      acc += x_val * (float)w[i];
    }
    // SiLU
    float sig = 1.0f / (1.0f + expf(-acc));
    float result = acc * sig;
    out_bch[pos] = (T)result;
  }
}

// Copies the last kernel_size input samples of every (batch, channel) row
// into conv_state_out. When seq_len < kernel_size the leading slots are
// filled with zeros.
template <typename T>
__global__ void save_conv_state_kernel(
    const T *__restrict__ x,        // [B, conv_dim, S]
    T *__restrict__ conv_state_out, // [B, conv_dim, kernel_size]
    int batch_size, int conv_dim, int seq_len, int kernel_size) {
  const int ch = blockIdx.x * blockDim.x + threadIdx.x;
  const int b = blockIdx.y;
  if (ch >= conv_dim || b >= batch_size)
    return;

  const long long row = (long long)b * conv_dim + ch;
  const T *x_bch = x + row * seq_len;
  T *cs = conv_state_out + row * kernel_size;

  // Save last kernel_size positions (zero-pad if seq_len < kernel_size)
  int pad = kernel_size - seq_len;
  for (int i = 0; i < kernel_size; i++) {
    if (i < pad) {
      cs[i] = (T)0.0f;
    } else {
      cs[i] = x_bch[seq_len - kernel_size + i];
    }
  }
}

// Launches both prefill kernels for one element type T.
template <typename T>
static void launch_causal_conv1d_full(const void *x, const void *weight,
                                      void *conv_state_out, void *output,
                                      int batch_size, int conv_dim,
                                      int seq_len, int kernel_size,
                                      cudaStream_t custream) {
  // No (batch, channel) rows: a zero-sized grid would be an invalid launch.
  if (conv_dim <= 0 || batch_size <= 0)
    return;

  constexpr int kThreads = 256;
  // Maximum y-dimension of a grid of thread blocks. batch_size (grid z/y) is
  // still assumed to be within this limit.
  constexpr int kMaxGridY = 65535;
  const int ch_blocks = (conv_dim + (kThreads - 1)) / kThreads;
  dim3 block(kThreads);

  // Main convolution kernel. Skipped for an empty sequence, where there is no
  // output to produce and a zero-sized grid could not be launched anyway.
  if (seq_len > 0) {
    const int pos_blocks = seq_len < kMaxGridY ? seq_len : kMaxGridY;
    dim3 grid(ch_blocks, pos_blocks, batch_size);
    causal_conv1d_full_kernel<T><<<grid, block, 0, custream>>>(
        (const T *)x, (const T *)weight, (T *)output, batch_size, conv_dim,
        seq_len, kernel_size);
  }

  // Save conv state
  dim3 grid2(ch_blocks, batch_size);
  save_conv_state_kernel<T><<<grid2, block, 0, custream>>>(
      (const T *)x, (T *)conv_state_out, batch_size, conv_dim, seq_len,
      kernel_size);
}

// Host entry point for the prefill causal conv1d.
// dtype == 0 selects __half; any other value selects __nv_bfloat16.
// `stream` is a cudaStream_t handle passed as a 64-bit integer.
extern "C" void causal_conv1d_full(const void *x, const void *weight,
                                   void *conv_state_out, void *output,
                                   int batch_size, int conv_dim, int seq_len,
                                   int kernel_size, int dtype, int64_t stream) {
  const cudaStream_t custream = (cudaStream_t)stream;
  if (dtype == 0) {
    launch_causal_conv1d_full<__half>(x, weight, conv_state_out, output,
                                      batch_size, conv_dim, seq_len,
                                      kernel_size, custream);
  } else {
    launch_causal_conv1d_full<__nv_bfloat16>(x, weight, conv_state_out, output,
                                             batch_size, conv_dim, seq_len,
                                             kernel_size, custream);
  }
}
// ============================================================================
// Kernel 3: fused_gdn_gating
//
// Fuses: beta = sigmoid(b), g = -exp(a_log) * softplus(a + dt_bias)
// a_log and dt_bias are per-head (broadcast over batch*seq).
//
// b, a: [total] a_log, dt_bias: [num_heads]
// beta_out, g_out: [total]
// ============================================================================
// Element-wise gating: one thread per element of b / a.
//   beta = sigmoid(b)
//   g    = -exp(a_log[head]) * softplus(a + dt_bias[head])
// where head = idx % num_heads (the innermost dimension is the head index).
// All math is done in f32; results are cast back to T on store.
template <typename T>
__global__ void
fused_gdn_gating_kernel(const T *__restrict__ b,           // [total]
                        const T *__restrict__ a,           // [total]
                        const float *__restrict__ a_log,   // [num_heads]
                        const float *__restrict__ dt_bias, // [num_heads]
                        T *__restrict__ beta_out,          // [total]
                        T *__restrict__ g_out,             // [total]
                        int total_elements, int num_heads) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= total_elements)
    return;

  // Head index: elements are laid out as [..., num_heads]
  int head_idx = idx % num_heads;

  // beta = sigmoid(b)
  float b_val = (float)b[idx];
  float beta = 1.0f / (1.0f + expf(-b_val));

  // g = -exp(a_log) * softplus(a + dt_bias)
  float a_val = (float)a[idx];
  float a_log_val = a_log[head_idx];
  float dt_bias_val = dt_bias[head_idx];
  float sp_input = a_val + dt_bias_val;

  // Numerically stable softplus. expf overflows to +inf for inputs above ~88,
  // which would turn g into -inf; beyond 20 softplus(x) equals x to f32
  // precision, so return x directly there. log1pf keeps precision for very
  // negative inputs, where 1 + exp(x) would round to exactly 1.
  float softplus_val =
      (sp_input > 20.0f) ? sp_input : log1pf(expf(sp_input));
  float g_val = -expf(a_log_val) * softplus_val;

  beta_out[idx] = (T)beta;
  g_out[idx] = (T)g_val;
}
// Host entry point for the fused gating kernel: one thread per element,
// 256 threads per block. dtype == 0 selects __half for b / a / beta_out /
// g_out; any other value selects __nv_bfloat16. a_log and dt_bias are always
// f32. `stream` is a cudaStream_t handle passed as a 64-bit integer.
extern "C" void fused_gdn_gating(const void *b, const void *a,
                                 const float *a_log, const float *dt_bias,
                                 void *beta_out, void *g_out,
                                 int total_elements, int num_heads, int dtype,
                                 int64_t stream) {
  constexpr int kThreads = 256;
  const cudaStream_t custream = (cudaStream_t)stream;
  const dim3 block(kThreads);
  const dim3 grid((total_elements + (kThreads - 1)) / kThreads);

  switch (dtype) {
  case 0: {
    const __half *b_h = (const __half *)b;
    const __half *a_h = (const __half *)a;
    __half *beta_h = (__half *)beta_out;
    __half *g_h = (__half *)g_out;
    fused_gdn_gating_kernel<__half><<<grid, block, 0, custream>>>(
        b_h, a_h, a_log, dt_bias, beta_h, g_h, total_elements, num_heads);
    break;
  }
  default: {
    const __nv_bfloat16 *b_b = (const __nv_bfloat16 *)b;
    const __nv_bfloat16 *a_b = (const __nv_bfloat16 *)a;
    __nv_bfloat16 *beta_b = (__nv_bfloat16 *)beta_out;
    __nv_bfloat16 *g_b = (__nv_bfloat16 *)g_out;
    fused_gdn_gating_kernel<__nv_bfloat16><<<grid, block, 0, custream>>>(
        b_b, a_b, a_log, dt_bias, beta_b, g_b, total_elements, num_heads);
    break;
  }
  }
}

View File

@@ -1,486 +0,0 @@
//! Rust wrappers around the Gated DeltaNet CUDA kernels in `gdn.cu`.
//!
//! Ported verbatim from `EricLBuehler/mistral.rs` under MIT terms.
//! Upstream path: `mistralrs-core/src/cuda/gdn.rs`. The only edits in
//! this file are this header comment — the FFI path module name is
//! `crate::cuda::ffi`, identical to upstream's layout.
#![allow(clippy::cast_possible_truncation)]
use candle_core::{Result, Tensor};
#[cfg(feature = "cuda")]
use candle_core::DType;
/// CUDA-accelerated gated delta rule recurrence.
///
/// Inputs (all contiguous, f32):
///   q, k: [BH, S, K]    v: [BH, S, V]    g, beta: [BH, S]
///   state: [BH, K, V] (mutated in place)
///
/// Returns: output [BH, S, V]
///
/// # Errors
///
/// Fails if any tensor is not a contiguous f32 CUDA tensor, if an element
/// count does not match the sizes implied by `q` and `v`, if a dimension does
/// not fit the kernel's `i32` arguments, or if the output allocation fails.
#[cfg(feature = "cuda")]
pub fn gated_delta_rule_recurrence_cuda(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    g: &Tensor,
    beta: &Tensor,
    state: &mut Tensor,
) -> Result<Tensor> {
    use candle::cuda_backend::cudarc::driver::DevicePtr;
    use candle_core as candle;

    /// Converts a dimension to the `i32` the kernel expects, rejecting values
    /// that would otherwise be silently truncated.
    fn dim_i32(name: &str, value: usize) -> Result<i32> {
        match i32::try_from(value) {
            Ok(v) => Ok(v),
            Err(_) => candle_core::bail!(
                "gated_delta_rule_recurrence_cuda: {} = {} does not fit in i32",
                name,
                value
            ),
        }
    }

    let (bh, seq_len, k_dim) = q.dims3()?;
    let v_dim = v.dim(2)?;

    // The kernel indexes raw device memory with row-major arithmetic, so a
    // strided view or an undersized buffer would be read/written out of
    // bounds. Reject those before any pointer is taken.
    let expected = [
        ("q", q, bh * seq_len * k_dim),
        ("k", k, bh * seq_len * k_dim),
        ("v", v, bh * seq_len * v_dim),
        ("g", g, bh * seq_len),
        ("beta", beta, bh * seq_len),
        ("state", &*state, bh * k_dim * v_dim),
    ];
    for (name, tensor, want) in expected {
        if !tensor.is_contiguous() {
            candle::bail!("gated_delta_rule_recurrence_cuda: {} must be contiguous", name);
        }
        if tensor.elem_count() != want {
            candle::bail!(
                "gated_delta_rule_recurrence_cuda: {} has {} elements, expected {}",
                name,
                tensor.elem_count(),
                want
            );
        }
    }
    let bh_arg = dim_i32("bh", bh)?;
    let seq_len_arg = dim_i32("seq_len", seq_len)?;
    let k_dim_arg = dim_i32("k_dim", k_dim)?;
    let v_dim_arg = dim_i32("v_dim", v_dim)?;

    let dev = q.device().as_cuda_device()?;

    let (q_s, q_l) = q.storage_and_layout();
    let q_s = match &*q_s {
        candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
        _ => candle::bail!("q must be a cuda tensor"),
    };
    let q_offset = q_l.start_offset();

    let (k_s, k_l) = k.storage_and_layout();
    let k_s = match &*k_s {
        candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
        _ => candle::bail!("k must be a cuda tensor"),
    };
    let k_offset = k_l.start_offset();

    let (v_s, v_l) = v.storage_and_layout();
    let v_s = match &*v_s {
        candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
        _ => candle::bail!("v must be a cuda tensor"),
    };
    let v_offset = v_l.start_offset();

    let (g_s, g_l) = g.storage_and_layout();
    let g_s = match &*g_s {
        candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
        _ => candle::bail!("g must be a cuda tensor"),
    };
    let g_offset = g_l.start_offset();

    let (beta_s, beta_l) = beta.storage_and_layout();
    let beta_s = match &*beta_s {
        candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
        _ => candle::bail!("beta must be a cuda tensor"),
    };
    let beta_offset = beta_l.start_offset();

    let (state_s, state_l) = state.storage_and_layout();
    let state_s = match &*state_s {
        candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
        _ => candle::bail!("state must be a cuda tensor"),
    };
    let state_offset = state_l.start_offset();

    // SAFETY: the buffer is handed to the kernel as its output before being
    // wrapped in a tensor. NOTE(review): this relies on the kernel writing
    // all bh * seq_len * v_dim elements — confirm against gdn.cu.
    let output_buf = unsafe { dev.alloc::<f32>(bh * seq_len * v_dim) }?;

    let stream = dev.cuda_stream().cu_stream() as i64;

    // SAFETY: every pointer comes from a live f32 CUDA slice (checked by
    // `as_cuda_slice::<f32>`), offset to the tensor's start, and each tensor
    // was validated above as contiguous with the element count implied by the
    // dimensions passed alongside it. `state` is written through a pointer
    // obtained from a shared storage borrow; the `&mut Tensor` parameter is
    // what expresses the caller's exclusive access.
    unsafe {
        crate::cuda::ffi::gated_delta_rule_recurrence(
            q_s.slice(q_offset..).device_ptr(q_s.stream()).0 as *const f32,
            k_s.slice(k_offset..).device_ptr(k_s.stream()).0 as *const f32,
            v_s.slice(v_offset..).device_ptr(v_s.stream()).0 as *const f32,
            g_s.slice(g_offset..).device_ptr(g_s.stream()).0 as *const f32,
            beta_s.slice(beta_offset..).device_ptr(beta_s.stream()).0 as *const f32,
            state_s.slice(state_offset..).device_ptr(state_s.stream()).0 as *mut f32,
            output_buf.device_ptr(output_buf.stream()).0 as *mut f32,
            bh_arg,
            seq_len_arg,
            k_dim_arg,
            v_dim_arg,
            stream,
        );
    }

    // The kernel wrote state in-place via the raw pointer, so only the output
    // buffer needs wrapping into a new tensor.
    let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone());
    Ok(Tensor::from((
        candle::Storage::Cuda(output_storage),
        (bh, seq_len, v_dim),
    )))
}
/// Fallback compiled when the `cuda` feature is disabled.
///
/// Ignores every argument and always returns an error stating that the
/// `cuda` feature is required. It has the same signature as the CUDA
/// implementation, so call sites compile either way.
#[cfg(not(feature = "cuda"))]
// The parameters are intentionally unused in this build configuration.
#[allow(unused)]
pub fn gated_delta_rule_recurrence_cuda(
_q: &Tensor,
_k: &Tensor,
_v: &Tensor,
_g: &Tensor,
_beta: &Tensor,
_state: &mut Tensor,
) -> Result<Tensor> {
candle_core::bail!("gated_delta_rule_recurrence_cuda requires the cuda feature")
}
/// CUDA-accelerated chunked gated delta rule recurrence (prefill optimization).
///
/// Processes prefill tokens in 64-token chunks instead of one at a time.
/// Same interface as `gated_delta_rule_recurrence_cuda`.
///
/// Inputs (all contiguous, f32):
///   q, k: [BH, S, K]    v: [BH, S, V]    g, beta: [BH, S]
///   state: [BH, K, V] (mutated in place)
///
/// Returns: output [BH, S, V]
///
/// # Errors
///
/// Fails if any tensor is not a contiguous f32 CUDA tensor, if an element
/// count does not match the sizes implied by `q` and `v`, if a dimension does
/// not fit the kernel's `i32` arguments, or if the output allocation fails.
#[cfg(feature = "cuda")]
pub fn chunked_gated_delta_rule_recurrence_cuda(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    g: &Tensor,
    beta: &Tensor,
    state: &mut Tensor,
) -> Result<Tensor> {
    use candle::cuda_backend::cudarc::driver::DevicePtr;
    use candle_core as candle;

    /// Converts a dimension to the `i32` the kernel expects, rejecting values
    /// that would otherwise be silently truncated.
    fn dim_i32(name: &str, value: usize) -> Result<i32> {
        match i32::try_from(value) {
            Ok(v) => Ok(v),
            Err(_) => candle_core::bail!(
                "chunked_gated_delta_rule_recurrence_cuda: {} = {} does not fit in i32",
                name,
                value
            ),
        }
    }

    let (bh, seq_len, k_dim) = q.dims3()?;
    let v_dim = v.dim(2)?;

    // The kernel indexes raw device memory with row-major arithmetic, so a
    // strided view or an undersized buffer would be read/written out of
    // bounds. Reject those before any pointer is taken.
    let expected = [
        ("q", q, bh * seq_len * k_dim),
        ("k", k, bh * seq_len * k_dim),
        ("v", v, bh * seq_len * v_dim),
        ("g", g, bh * seq_len),
        ("beta", beta, bh * seq_len),
        ("state", &*state, bh * k_dim * v_dim),
    ];
    for (name, tensor, want) in expected {
        if !tensor.is_contiguous() {
            candle::bail!(
                "chunked_gated_delta_rule_recurrence_cuda: {} must be contiguous",
                name
            );
        }
        if tensor.elem_count() != want {
            candle::bail!(
                "chunked_gated_delta_rule_recurrence_cuda: {} has {} elements, expected {}",
                name,
                tensor.elem_count(),
                want
            );
        }
    }
    let bh_arg = dim_i32("bh", bh)?;
    let seq_len_arg = dim_i32("seq_len", seq_len)?;
    let k_dim_arg = dim_i32("k_dim", k_dim)?;
    let v_dim_arg = dim_i32("v_dim", v_dim)?;

    let dev = q.device().as_cuda_device()?;

    let (q_s, q_l) = q.storage_and_layout();
    let q_s = match &*q_s {
        candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
        _ => candle::bail!("q must be a cuda tensor"),
    };
    let q_offset = q_l.start_offset();

    let (k_s, k_l) = k.storage_and_layout();
    let k_s = match &*k_s {
        candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
        _ => candle::bail!("k must be a cuda tensor"),
    };
    let k_offset = k_l.start_offset();

    let (v_s, v_l) = v.storage_and_layout();
    let v_s = match &*v_s {
        candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
        _ => candle::bail!("v must be a cuda tensor"),
    };
    let v_offset = v_l.start_offset();

    let (g_s, g_l) = g.storage_and_layout();
    let g_s = match &*g_s {
        candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
        _ => candle::bail!("g must be a cuda tensor"),
    };
    let g_offset = g_l.start_offset();

    let (beta_s, beta_l) = beta.storage_and_layout();
    let beta_s = match &*beta_s {
        candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
        _ => candle::bail!("beta must be a cuda tensor"),
    };
    let beta_offset = beta_l.start_offset();

    let (state_s, state_l) = state.storage_and_layout();
    let state_s = match &*state_s {
        candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
        _ => candle::bail!("state must be a cuda tensor"),
    };
    let state_offset = state_l.start_offset();

    // SAFETY: the buffer is handed to the kernel as its output before being
    // wrapped in a tensor. NOTE(review): this relies on the kernel writing
    // all bh * seq_len * v_dim elements — confirm against gdn.cu.
    let output_buf = unsafe { dev.alloc::<f32>(bh * seq_len * v_dim) }?;

    let stream = dev.cuda_stream().cu_stream() as i64;

    // SAFETY: every pointer comes from a live f32 CUDA slice (checked by
    // `as_cuda_slice::<f32>`), offset to the tensor's start, and each tensor
    // was validated above as contiguous with the element count implied by the
    // dimensions passed alongside it. `state` is written through a pointer
    // obtained from a shared storage borrow; the `&mut Tensor` parameter is
    // what expresses the caller's exclusive access.
    unsafe {
        crate::cuda::ffi::chunked_gated_delta_rule_recurrence(
            q_s.slice(q_offset..).device_ptr(q_s.stream()).0 as *const f32,
            k_s.slice(k_offset..).device_ptr(k_s.stream()).0 as *const f32,
            v_s.slice(v_offset..).device_ptr(v_s.stream()).0 as *const f32,
            g_s.slice(g_offset..).device_ptr(g_s.stream()).0 as *const f32,
            beta_s.slice(beta_offset..).device_ptr(beta_s.stream()).0 as *const f32,
            state_s.slice(state_offset..).device_ptr(state_s.stream()).0 as *mut f32,
            output_buf.device_ptr(output_buf.stream()).0 as *mut f32,
            bh_arg,
            seq_len_arg,
            k_dim_arg,
            v_dim_arg,
            stream,
        );
    }

    // `state` was updated in place by the kernel; only the output buffer
    // needs wrapping into a new tensor.
    let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone());
    Ok(Tensor::from((
        candle::Storage::Cuda(output_storage),
        (bh, seq_len, v_dim),
    )))
}
/// Fallback compiled when the `cuda` feature is disabled.
///
/// Ignores every argument and always returns an error stating that the
/// `cuda` feature is required. It has the same signature as the CUDA
/// implementation, so call sites compile either way.
#[cfg(not(feature = "cuda"))]
// The parameters are intentionally unused in this build configuration.
#[allow(unused)]
pub fn chunked_gated_delta_rule_recurrence_cuda(
_q: &Tensor,
_k: &Tensor,
_v: &Tensor,
_g: &Tensor,
_beta: &Tensor,
_state: &mut Tensor,
) -> Result<Tensor> {
candle_core::bail!("chunked_gated_delta_rule_recurrence_cuda requires the cuda feature")
}
/// CUDA-accelerated causal conv1d (both update and full paths).
///
/// For update (is_update=true):
///   x: [B, conv_dim, 1]    weight: [conv_dim, kernel_size]
///   conv_state: [B, conv_dim, kernel_size] (mutated in place for update)
///   Returns: (output [B, conv_dim, 1], updated conv_state)
///
/// For full (is_update=false):
///   x: [B, conv_dim, S]    weight: [conv_dim, kernel_size]
///   Returns: (output [B, conv_dim, S], new conv_state [B, conv_dim, kernel_size])
///
/// `x`, `weight` and (for update) `conv_state` must be contiguous CUDA
/// tensors of the same dtype, either f16 or bf16.
///
/// # Errors
///
/// Fails on an unsupported dtype, a non-CUDA or non-contiguous tensor, a
/// `weight` / `conv_state` whose element count does not match `kernel_size`,
/// an update call whose `x` has more than one position, a dimension that does
/// not fit the kernel's `i32` arguments, or an allocation failure.
#[cfg(feature = "cuda")]
pub fn causal_conv1d_cuda(
    x: &Tensor,
    weight: &Tensor,
    conv_state: &Tensor,
    kernel_size: usize,
    is_update: bool,
) -> Result<(Tensor, Tensor)> {
    use candle::cuda_backend::cudarc::driver::DevicePtr;
    use candle_core as candle;
    use core::ffi::c_void;

    /// Converts a dimension to the `i32` the kernel expects, rejecting values
    /// that would otherwise be silently truncated.
    fn dim_i32(name: &str, value: usize) -> Result<i32> {
        match i32::try_from(value) {
            Ok(v) => Ok(v),
            Err(_) => candle_core::bail!(
                "causal_conv1d_cuda: {} = {} does not fit in i32",
                name,
                value
            ),
        }
    }

    fn cuda_fwd<
        T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
    >(
        x: &Tensor,
        weight: &Tensor,
        conv_state: &Tensor,
        kernel_size: usize,
        is_update: bool,
        dtype_code: i32,
    ) -> Result<(Tensor, Tensor)> {
        let dev = x.device().as_cuda_device()?;
        let (batch_size, conv_dim, seq_len) = x.dims3()?;

        // The kernels index raw device memory with row-major arithmetic, so
        // validate layouts and sizes before any pointer is taken.
        if !x.is_contiguous() {
            candle::bail!("causal_conv1d_cuda: x must be contiguous");
        }
        if !weight.is_contiguous() {
            candle::bail!("causal_conv1d_cuda: weight must be contiguous");
        }
        if weight.elem_count() != conv_dim * kernel_size {
            candle::bail!(
                "causal_conv1d_cuda: weight has {} elements, expected conv_dim * kernel_size = {}",
                weight.elem_count(),
                conv_dim * kernel_size
            );
        }
        if is_update {
            // The update kernel writes slot kernel_size - 1 unconditionally.
            if kernel_size == 0 {
                candle::bail!("causal_conv1d_cuda: kernel_size must be at least 1 for update");
            }
            // The update kernel reads exactly one sample per (batch, channel).
            if seq_len != 1 {
                candle::bail!(
                    "causal_conv1d_cuda: update expects x of shape [B, conv_dim, 1], got seq_len {}",
                    seq_len
                );
            }
            if !conv_state.is_contiguous() {
                candle::bail!("causal_conv1d_cuda: conv_state must be contiguous");
            }
            if conv_state.elem_count() != batch_size * conv_dim * kernel_size {
                candle::bail!(
                    "causal_conv1d_cuda: conv_state has {} elements, expected {}",
                    conv_state.elem_count(),
                    batch_size * conv_dim * kernel_size
                );
            }
        }
        let batch_arg = dim_i32("batch_size", batch_size)?;
        let conv_dim_arg = dim_i32("conv_dim", conv_dim)?;
        let seq_len_arg = dim_i32("seq_len", seq_len)?;
        let kernel_size_arg = dim_i32("kernel_size", kernel_size)?;

        let (x_s, x_l) = x.storage_and_layout();
        let x_s = match &*x_s {
            candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
            _ => candle::bail!("x must be a cuda tensor"),
        };
        let x_offset = x_l.start_offset();

        let (w_s, w_l) = weight.storage_and_layout();
        let w_s = match &*w_s {
            candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
            _ => candle::bail!("weight must be a cuda tensor"),
        };
        let w_offset = w_l.start_offset();

        let stream = dev.cuda_stream().cu_stream() as i64;

        if is_update {
            // The kernel mutates the state buffer in place; this handle is
            // what gets returned as the updated state.
            // NOTE(review): `Tensor::clone` is believed to be a cheap handle
            // clone that shares storage with `conv_state` (so the caller's
            // tensor is updated too, as the doc comment states) — confirm.
            let conv_state_new = conv_state.clone();
            // SAFETY: the buffer is handed to the kernel as its output before
            // being wrapped in a tensor; the update kernel stores one value
            // per (batch, channel).
            let output_buf = unsafe { dev.alloc::<T>(batch_size * conv_dim) }?;

            // Scope the borrow of conv_state_new so we can move it later
            {
                let (cs_s, cs_l) = conv_state_new.storage_and_layout();
                let cs_s = match &*cs_s {
                    candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
                    _ => candle::bail!("conv_state must be a cuda tensor"),
                };
                let cs_offset = cs_l.start_offset();

                // SAFETY: all pointers come from live CUDA slices of element
                // type T, offset to each tensor's start; sizes and contiguity
                // were validated above against the dimensions passed here.
                unsafe {
                    crate::cuda::ffi::causal_conv1d_update(
                        x_s.slice(x_offset..).device_ptr(x_s.stream()).0 as *const c_void,
                        w_s.slice(w_offset..).device_ptr(w_s.stream()).0 as *const c_void,
                        cs_s.slice(cs_offset..).device_ptr(cs_s.stream()).0 as *mut c_void,
                        output_buf.device_ptr(output_buf.stream()).0 as *mut c_void,
                        batch_arg,
                        conv_dim_arg,
                        kernel_size_arg,
                        dtype_code,
                        stream,
                    );
                }
            }

            let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone());
            let output = Tensor::from((
                candle::Storage::Cuda(output_storage),
                (batch_size, conv_dim, 1usize),
            ));

            Ok((output, conv_state_new))
        } else {
            // Full path: allocate new conv_state and output.
            // SAFETY: both buffers are handed to the kernels as outputs
            // before being wrapped in tensors; the conv kernel stores one
            // value per (batch, channel, position) and the state kernel one
            // per (batch, channel, kernel slot).
            let output_buf = unsafe { dev.alloc::<T>(batch_size * conv_dim * seq_len) }?;
            let cs_buf = unsafe { dev.alloc::<T>(batch_size * conv_dim * kernel_size) }?;

            // SAFETY: input pointers come from live CUDA slices of element
            // type T validated above; output pointers come from the buffers
            // just allocated with the sizes the kernels index.
            unsafe {
                crate::cuda::ffi::causal_conv1d_full(
                    x_s.slice(x_offset..).device_ptr(x_s.stream()).0 as *const c_void,
                    w_s.slice(w_offset..).device_ptr(w_s.stream()).0 as *const c_void,
                    cs_buf.device_ptr(cs_buf.stream()).0 as *mut c_void,
                    output_buf.device_ptr(output_buf.stream()).0 as *mut c_void,
                    batch_arg,
                    conv_dim_arg,
                    seq_len_arg,
                    kernel_size_arg,
                    dtype_code,
                    stream,
                );
            }

            let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone());
            let output = Tensor::from((
                candle::Storage::Cuda(output_storage),
                (batch_size, conv_dim, seq_len),
            ));

            let cs_storage = candle::CudaStorage::wrap_cuda_slice(cs_buf, dev.clone());
            let new_conv_state = Tensor::from((
                candle::Storage::Cuda(cs_storage),
                (batch_size, conv_dim, kernel_size),
            ));

            Ok((output, new_conv_state))
        }
    }

    // The dtype code must match the kernel launcher: 0 = f16, otherwise bf16.
    match x.dtype() {
        DType::F16 => cuda_fwd::<half::f16>(x, weight, conv_state, kernel_size, is_update, 0),
        DType::BF16 => cuda_fwd::<half::bf16>(x, weight, conv_state, kernel_size, is_update, 1),
        other => candle_core::bail!("causal_conv1d_cuda only supports f16/bf16, got {:?}", other),
    }
}
/// Fallback compiled when the `cuda` feature is disabled.
///
/// Ignores every argument and always returns an error stating that the
/// `cuda` feature is required. It has the same signature as the CUDA
/// implementation, so call sites compile either way.
#[cfg(not(feature = "cuda"))]
// The parameters are intentionally unused in this build configuration.
#[allow(unused)]
pub fn causal_conv1d_cuda(
_x: &Tensor,
_weight: &Tensor,
_conv_state: &Tensor,
_kernel_size: usize,
_is_update: bool,
) -> Result<(Tensor, Tensor)> {
candle_core::bail!("causal_conv1d_cuda requires the cuda feature")
}
/// CUDA-accelerated fused GDN gating computation.
///
/// Computes: beta = sigmoid(b), g = -exp(a_log) * softplus(a + dt_bias)
///
///   b, a: [total_elements] in f16/bf16
///   a_log, dt_bias: [num_heads] in f32
///
/// The head index of an element is its flat index modulo `num_heads`, so the
/// innermost dimension of `b` / `a` must be the head dimension.
///
/// Returns: (beta, g) in original dtype
///
/// # Errors
///
/// Fails on an unsupported dtype, a non-CUDA or non-contiguous tensor,
/// mismatched element counts between `b`/`a` or `a_log`/`dt_bias`, an empty
/// `a_log`, a `b` whose element count is not a multiple of the head count, a
/// size that does not fit the kernel's `i32` arguments, or an allocation
/// failure.
#[cfg(feature = "cuda")]
pub fn fused_gdn_gating_cuda(
    b: &Tensor,
    a: &Tensor,
    a_log: &Tensor,
    dt_bias: &Tensor,
) -> Result<(Tensor, Tensor)> {
    use candle::cuda_backend::cudarc::driver::DevicePtr;
    use candle_core as candle;
    use core::ffi::c_void;

    /// Converts a size to the `i32` the kernel expects, rejecting values
    /// that would otherwise be silently truncated.
    fn dim_i32(name: &str, value: usize) -> Result<i32> {
        match i32::try_from(value) {
            Ok(v) => Ok(v),
            Err(_) => candle_core::bail!(
                "fused_gdn_gating_cuda: {} = {} does not fit in i32",
                name,
                value
            ),
        }
    }

    fn cuda_fwd<
        T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
    >(
        b: &Tensor,
        a: &Tensor,
        a_log: &Tensor,
        dt_bias: &Tensor,
        dtype_code: i32,
    ) -> Result<(Tensor, Tensor)> {
        let total_elements = b.elem_count();
        let num_heads = a_log.elem_count();
        let shape = b.shape().clone();

        // The kernel indexes raw device memory by flat index and by
        // `idx % num_heads`, so validate layouts and sizes first.
        for (name, tensor) in [("b", b), ("a", a), ("a_log", a_log), ("dt_bias", dt_bias)] {
            if !tensor.is_contiguous() {
                candle::bail!("fused_gdn_gating_cuda: {} must be contiguous", name);
            }
        }
        if num_heads == 0 {
            // Would be a modulo by zero inside the kernel.
            candle::bail!("fused_gdn_gating_cuda: a_log must not be empty");
        }
        if a.elem_count() != total_elements {
            candle::bail!(
                "fused_gdn_gating_cuda: a has {} elements but b has {}",
                a.elem_count(),
                total_elements
            );
        }
        if dt_bias.elem_count() != num_heads {
            candle::bail!(
                "fused_gdn_gating_cuda: dt_bias has {} elements but a_log has {}",
                dt_bias.elem_count(),
                num_heads
            );
        }
        if total_elements % num_heads != 0 {
            candle::bail!(
                "fused_gdn_gating_cuda: {} elements is not a multiple of num_heads = {}",
                total_elements,
                num_heads
            );
        }
        let total_arg = dim_i32("total_elements", total_elements)?;
        let num_heads_arg = dim_i32("num_heads", num_heads)?;

        let dev = b.device().as_cuda_device()?;

        let (b_s, b_l) = b.storage_and_layout();
        let b_s = match &*b_s {
            candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
            _ => candle::bail!("b must be a cuda tensor"),
        };
        let b_offset = b_l.start_offset();

        let (a_s, a_l) = a.storage_and_layout();
        let a_s = match &*a_s {
            candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
            _ => candle::bail!("a must be a cuda tensor"),
        };
        let a_offset = a_l.start_offset();

        let (alog_s, alog_l) = a_log.storage_and_layout();
        let alog_s = match &*alog_s {
            candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
            _ => candle::bail!("a_log must be a cuda tensor"),
        };
        let alog_offset = alog_l.start_offset();

        let (dtb_s, dtb_l) = dt_bias.storage_and_layout();
        let dtb_s = match &*dtb_s {
            candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
            _ => candle::bail!("dt_bias must be a cuda tensor"),
        };
        let dtb_offset = dtb_l.start_offset();

        // SAFETY: both buffers are handed to the kernel as outputs before
        // being wrapped in tensors; the kernel stores one beta and one g per
        // input element.
        let beta_buf = unsafe { dev.alloc::<T>(total_elements) }?;
        let g_buf = unsafe { dev.alloc::<T>(total_elements) }?;

        let stream = dev.cuda_stream().cu_stream() as i64;

        // SAFETY: input pointers come from live CUDA slices of the checked
        // element types, offset to each tensor's start and validated above
        // for contiguity and size; output pointers come from the buffers just
        // allocated with `total_elements` entries each.
        unsafe {
            crate::cuda::ffi::fused_gdn_gating(
                b_s.slice(b_offset..).device_ptr(b_s.stream()).0 as *const c_void,
                a_s.slice(a_offset..).device_ptr(a_s.stream()).0 as *const c_void,
                alog_s.slice(alog_offset..).device_ptr(alog_s.stream()).0 as *const f32,
                dtb_s.slice(dtb_offset..).device_ptr(dtb_s.stream()).0 as *const f32,
                beta_buf.device_ptr(beta_buf.stream()).0 as *mut c_void,
                g_buf.device_ptr(g_buf.stream()).0 as *mut c_void,
                total_arg,
                num_heads_arg,
                dtype_code,
                stream,
            );
        }

        // Both results take the shape of `b`.
        let beta_storage = candle::CudaStorage::wrap_cuda_slice(beta_buf, dev.clone());
        let beta = Tensor::from((candle::Storage::Cuda(beta_storage), shape.clone()));

        let g_storage = candle::CudaStorage::wrap_cuda_slice(g_buf, dev.clone());
        let g = Tensor::from((candle::Storage::Cuda(g_storage), shape));

        Ok((beta, g))
    }

    // The dtype code must match the kernel launcher: 0 = f16, otherwise bf16.
    match b.dtype() {
        DType::F16 => cuda_fwd::<half::f16>(b, a, a_log, dt_bias, 0),
        DType::BF16 => cuda_fwd::<half::bf16>(b, a, a_log, dt_bias, 1),
        other => candle_core::bail!(
            "fused_gdn_gating_cuda only supports f16/bf16, got {:?}",
            other
        ),
    }
}
/// Fallback compiled when the `cuda` feature is disabled.
///
/// Ignores every argument and always returns an error stating that the
/// `cuda` feature is required. It has the same signature as the CUDA
/// implementation, so call sites compile either way.
#[cfg(not(feature = "cuda"))]
// The parameters are intentionally unused in this build configuration.
#[allow(unused)]
pub fn fused_gdn_gating_cuda(
_b: &Tensor,
_a: &Tensor,
_a_log: &Tensor,
_dt_bias: &Tensor,
) -> Result<(Tensor, Tensor)> {
candle_core::bail!("fused_gdn_gating_cuda requires the cuda feature")
}

View File

@@ -1,15 +0,0 @@
//! CUDA kernels and their Rust wrappers.
//!
//! Currently scoped to what we need for Qwen3-Next (`qwen3_5`)
//! inference performance — the Gated DeltaNet kernels ported from
//! `EricLBuehler/mistral.rs` (MIT). Each kernel lives in a `.cu`
//! file alongside this module; `build.rs` compiles them all into a
//! static lib via `cudaforge` and links it under the `cuda` feature.
//!
//! When we absorb more upstream kernels (MoE GEMM, top-k, Mamba SSM,
//! etc.) they land here in their own `.cu` + `.rs` pairs.
#[cfg(feature = "cuda")]
pub mod ffi;
#[cfg(feature = "cuda")]
pub mod gdn;

View File

@@ -100,87 +100,6 @@ pub fn parse_health_info(csv_output: &str) -> Result<Vec<DeviceHealth>> {
Ok(devices)
}
// ── Driver/library mismatch preflight (#19) ─────────────────────────
/// Decide whether failed nvidia-smi output is the "Driver/library version
/// mismatch" condition (userspace libraries updated while the old kernel
/// module is still loaded).
///
/// Returns `None` when the mismatch marker is absent, so unrelated failures
/// (no devices, permissions, ...) are never reported as a mismatch.
/// Otherwise returns the text following the first
/// `NVML library version:` line, or `"unknown"` when there is no such line or
/// that text is blank.
pub fn classify_driver_mismatch(combined_output: &str) -> Option<String> {
    const MISMATCH_MARKER: &str = "Driver/library version mismatch";
    const VERSION_PREFIX: &str = "NVML library version:";
    const UNKNOWN: &str = "unknown";

    if !combined_output.contains(MISMATCH_MARKER) {
        return None;
    }

    // Only the first version line counts, even if its value is blank.
    for line in combined_output.lines() {
        if let Some(rest) = line.trim().strip_prefix(VERSION_PREFIX) {
            let version = rest.trim();
            let reported = if version.is_empty() { UNKNOWN } else { version };
            return Some(reported.to_string());
        }
    }
    Some(UNKNOWN.to_string())
}
/// Pull the loaded kernel module's driver version out of the contents of
/// `/proc/driver/nvidia/version`. A typical first line looks like:
///
/// ```text
/// NVRM version: NVIDIA UNIX Open Kernel Module for x86_64  580.159.03  Release Build  (...)
/// ```
///
/// The result is the first whitespace-separated token on the first line
/// starting with `NVRM version:` whose first two dot-separated components are
/// non-empty runs of ASCII digits. Returns `None` when there is no such line
/// or no such token.
pub fn parse_kernel_module_version(proc_contents: &str) -> Option<String> {
    fn all_digits(part: &str) -> bool {
        !part.is_empty() && part.bytes().all(|byte| byte.is_ascii_digit())
    }

    let nvrm_line = proc_contents
        .lines()
        .find(|candidate| candidate.starts_with("NVRM version:"))?;

    for token in nvrm_line.split_whitespace() {
        let mut components = token.split('.');
        if let (Some(major), Some(minor)) = (components.next(), components.next()) {
            if all_digits(major) && all_digits(minor) {
                return Some(token.to_owned());
            }
        }
    }
    None
}
/// Build the operator-facing description of a driver/library mismatch: the
/// userspace NVML version, the loaded kernel module version (`"unknown"` when
/// `kernel_module` is `None`), and the instruction to reboot the host.
pub fn mismatch_reason(userspace: &str, kernel_module: Option<&str>) -> String {
    let kmod = kernel_module.unwrap_or("unknown");
    format!(
        "host NVIDIA driver/library mismatch (userspace NVML {userspace} vs loaded kernel \
         module {kmod}) — reboot the host to reload the kernel module; all CUDA inference is \
         unavailable until then"
    )
}
/// Outcome of an nvidia-smi invocation, distinguishing "binary not
/// present" (CPU-only host, not an error) from "present but failing"
/// (possible driver mismatch — worth classifying).
enum SmiOutcome {
    /// The command exited successfully; carries its stdout (lossily decoded
    /// as UTF-8).
    Ok(String),
    /// The command could not be run to a successful exit. Carries either
    /// stdout and stderr joined by a newline (non-zero exit), or a
    /// description of the spawn error when the failure was anything other
    /// than "executable not found".
    Failed(String),
    /// The `nvidia-smi` executable was not found.
    Absent,
}

/// Runs `nvidia-smi` with `args` and classifies the result.
///
/// Only a "not found" spawn error is reported as [`SmiOutcome::Absent`].
/// Other spawn errors (for example a permission problem) mean the binary is
/// there but unusable, so they are reported as [`SmiOutcome::Failed`] rather
/// than being mistaken for a CPU-only host.
async fn run_nvidia_smi(args: &[&str]) -> SmiOutcome {
    let output = match tokio::process::Command::new("nvidia-smi")
        .args(args)
        .output()
        .await
    {
        Ok(output) => output,
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => return SmiOutcome::Absent,
        Err(err) => return SmiOutcome::Failed(format!("failed to execute nvidia-smi: {err}")),
    };

    if output.status.success() {
        return SmiOutcome::Ok(String::from_utf8_lossy(&output.stdout).into_owned());
    }

    // Diagnostics may land on either stream, so keep both for classification.
    let mut combined = String::from_utf8_lossy(&output.stdout).into_owned();
    combined.push('\n');
    combined.push_str(&String::from_utf8_lossy(&output.stderr));
    SmiOutcome::Failed(combined)
}
// ── Command execution wrappers ──────────────────────────────────────
async fn run_command(cmd: &str, args: &[&str]) -> Result<String> {
@@ -220,42 +139,23 @@ pub async fn discover_system() -> Result<DiscoveryResponse> {
.trim()
.to_string();
let (devices, driver_version, cuda_unavailable_reason) = match run_nvidia_smi(&[
&format!("--query-gpu={NVIDIA_SMI_DISCOVERY_QUERY}"),
"--format=csv,noheader,nounits",
])
let (devices, driver_version) = match run_command_optional(
"nvidia-smi",
&[
&format!("--query-gpu={NVIDIA_SMI_DISCOVERY_QUERY}"),
"--format=csv,noheader,nounits",
],
)
.await
{
SmiOutcome::Ok(output) => {
Some(output) => {
let devs = parse_gpu_info(&output).unwrap_or_default();
let driver = parse_driver_version(&output);
(devs, driver, None)
(devs, driver)
}
SmiOutcome::Absent => {
None => {
tracing::info!("nvidia-smi not found — no GPU devices discovered");
(vec![], None, None)
}
SmiOutcome::Failed(combined) => {
// nvidia-smi exists but can't talk to the driver. The case
// worth diagnosing precisely is the userspace↔kernel-module
// version skew after an un-rebooted driver update (#19) —
// every CUDA call on the host fails until a reboot, and
// without this classification it surfaces as a cryptic
// NCCL/cuInit error deep inside the first model load.
let reason = classify_driver_mismatch(&combined).map(|userspace| {
let kmod = std::fs::read_to_string("/proc/driver/nvidia/version")
.ok()
.as_deref()
.and_then(parse_kernel_module_version);
mismatch_reason(&userspace, kmod.as_deref())
});
if reason.is_none() {
tracing::warn!(
output = %combined.trim(),
"nvidia-smi present but failing — no GPU devices discovered"
);
}
(vec![], None, reason)
(vec![], None)
}
};
@@ -272,7 +172,6 @@ pub async fn discover_system() -> Result<DiscoveryResponse> {
driver_version,
devices,
harnesses: vec![], // populated by harness registry in Phase 8
cuda_unavailable_reason,
})
}
@@ -373,63 +272,4 @@ mod tests {
assert_eq!(health[1].vram_used_mb, 4096);
assert_eq!(health[1].temp_c, 58);
}
// ── #19 driver/library mismatch preflight ────────────────────────
#[test]
fn classify_driver_mismatch_detects_and_extracts_nvml_version() {
    // Same shape as nvidia-smi's failure output on a host whose userspace
    // libraries were updated without a reboot: the mismatch line followed by
    // the NVML library version line.
    let smi_failure = concat!(
        "Failed to initialize NVML: Driver/library version mismatch\n",
        "NVML library version: 580.159\n",
    );
    let userspace = classify_driver_mismatch(smi_failure);
    assert_eq!(userspace.as_deref(), Some("580.159"));
}
#[test]
fn classify_driver_mismatch_without_version_line() {
    // A mismatch with no NVML version line still counts as a mismatch, with
    // the version reported as "unknown".
    let userspace =
        classify_driver_mismatch("Failed to initialize NVML: Driver/library version mismatch\n");
    assert_eq!(userspace.as_deref(), Some("unknown"));
}
#[test]
fn classify_driver_mismatch_ignores_other_failures() {
    // Failures that are not the version mismatch must never be diagnosed as
    // one, so healthy or merely odd hosts produce no false positives.
    let unrelated_failures = [
        "No devices were found\n",
        "Failed to initialize NVML: Insufficient Permissions\n",
        "NVIDIA-SMI has failed because it couldn't communicate with the NVIDIA driver.\n",
        "",
    ];
    for out in unrelated_failures {
        let verdict = classify_driver_mismatch(out);
        assert_eq!(verdict, None, "false positive on: {out:?}");
    }
}
#[test]
fn parse_kernel_module_version_from_proc() {
    // Two-line sample: the NVRM line carrying the version, then a GCC line.
    let proc_contents = concat!(
        "NVRM version: NVIDIA UNIX Open Kernel Module for x86_64  580.159.03  Release Build  (dvs-builder@U22-I3-AE24-12-2)  Tue May 12 21:03:35 UTC 2026\n",
        "GCC version:  gcc version 15.2.1 20251022 (Red Hat 15.2.1-3) (GCC)\n",
    );
    let version = parse_kernel_module_version(proc_contents);
    assert_eq!(version.as_deref(), Some("580.159.03"));
}
#[test]
fn parse_kernel_module_version_absent() {
    // Neither empty input nor input lacking an NVRM line yields a version.
    let empty = parse_kernel_module_version("");
    assert_eq!(empty, None);

    let gcc_only = parse_kernel_module_version("GCC version: gcc 15\n");
    assert_eq!(gcc_only, None);
}
#[test]
fn mismatch_reason_is_operator_actionable() {
    // The message must name both versions and tell the operator to reboot.
    let reason = mismatch_reason("580.159", Some("580.159.03"));
    for needle in ["580.159", "580.159.03", "reboot"] {
        assert!(reason.contains(needle));
    }
}
}

View File

@@ -1,23 +0,0 @@
//! Custom architecture implementations.
//!
//! When candle-transformers ships a model family unchanged
//! (`models::llama`, `models::qwen3`, `models::qwen3_moe`, etc.), the
//! handler in `harness/candle.rs` just wraps the upstream type in a
//! `ModelArch` variant.
//!
//! When candle has nothing for the architecture and we have to write
//! it from scratch — Qwen3-Next / Qwen3.6 (`qwen3_5`) being the
//! motivating example — the implementation lands here, one file per
//! architecture.
//!
//! Each architecture module is expected to expose:
//! - A `Config` type deserialised from the model's `config.json`
//! (some architectures nest the real hyperparams under `text_config`,
//! in which case the module owns the unwrapping).
//! - A `ForCausalLM` struct with `new`, `forward(&mut self, x, offset)
//! -> Result<Tensor>`, and `clear_kv_cache(&mut self)`.
//!
//! TP-aware analogues live in `harness/tp/tp_<family>.rs` and follow
//! the pattern set by `tp_qwen3.rs`.
pub mod qwen3_5;

View File

@@ -1,152 +0,0 @@
//! Qwen3-Next decoder layer.
//!
//! Standard pre-norm transformer block (LN → attention → residual →
//! LN → MLP → residual) where the attention slot dispatches on the
//! per-layer `layer_types[i]` value in the config:
//!
//! - `"full_attention"` → [`Qwen3_5Attention`] (GQA causal + output
//! gate + RoPE + KV cache).
//! - `"linear_attention"` → [`GatedDeltaNet`] (recurrent delta rule +
//! causal conv + per-head state).
//!
//! In Qwen3.6-27B every 4th layer is full_attention; the rest are
//! linear_attention. `full_attention_interval` in the config is a
//! hint; `layer_types` is authoritative.
use anyhow::Result;
use candle_core::{Module, Tensor};
use candle_nn::var_builder::ShardedVarBuilder;
use std::sync::Arc;
use super::TextConfig;
use super::full_attn::Qwen3_5Attention;
use super::linear_attn::GatedDeltaNet;
use super::mlp::Qwen3_5MLP;
use super::rmsnorm::Qwen3_5RmsNorm;
use super::rope::RotaryEmbedding;
use super::snapshot::LayerKvSnapshot;
/// One of the two attention flavours sitting in a decoder layer's
/// attention slot. Full-attention layers need the rotary table and
/// take an attention mask; linear-attention layers carry their own
/// recurrent state and ignore the mask.
///
/// Private to this module: `Qwen3_5DecoderLayer` picks the variant at
/// load time from `layer_types[layer_idx]` and matches on it in every
/// forward / cache operation.
enum AttentionKind {
/// Selected for `"full_attention"` layers; weights are loaded from
/// the layer's `self_attn` prefix.
Full(Qwen3_5Attention),
/// Selected for `"linear_attention"` layers; weights are loaded from
/// the layer's `linear_attn` prefix.
Linear(GatedDeltaNet),
}
/// One pre-norm decoder block: norm → attention → residual → norm →
/// MLP → residual. The attention slot is either full attention or
/// Gated DeltaNet, chosen per layer at load time.
pub struct Qwen3_5DecoderLayer {
/// Norm applied to the block input before the attention slot.
input_layernorm: Qwen3_5RmsNorm,
/// Norm applied to the post-attention residual before the MLP.
post_attention_layernorm: Qwen3_5RmsNorm,
/// Feed-forward block applied after the second norm.
mlp: Qwen3_5MLP,
/// The attention flavour for this layer index.
attention: AttentionKind,
}
impl Qwen3_5DecoderLayer {
pub fn load(
cfg: &TextConfig,
rotary: Arc<RotaryEmbedding>,
layer_idx: usize,
vb: &ShardedVarBuilder,
) -> Result<Self> {
let layer_type = cfg
.layer_types
.get(layer_idx)
.map(String::as_str)
.ok_or_else(|| {
anyhow::anyhow!(
"layer_types[{layer_idx}] missing (have {} entries)",
cfg.layer_types.len()
)
})?;
let attention = match layer_type {
"full_attention" => {
AttentionKind::Full(Qwen3_5Attention::load(cfg, rotary, &vb.pp("self_attn"))?)
}
"linear_attention" => {
AttentionKind::Linear(GatedDeltaNet::load(cfg, &vb.pp("linear_attn"))?)
}
other => anyhow::bail!(
"unknown layer_type '{other}' for layer {layer_idx} (expected \
'full_attention' or 'linear_attention')"
),
};
let mlp = Qwen3_5MLP::load(cfg, &vb.pp("mlp"))?;
let input_layernorm =
Qwen3_5RmsNorm::load(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?;
let post_attention_layernorm = Qwen3_5RmsNorm::load(
&vb.pp("post_attention_layernorm"),
cfg.hidden_size,
cfg.rms_norm_eps,
)?;
Ok(Self {
input_layernorm,
post_attention_layernorm,
mlp,
attention,
})
}
/// Run one decoder block over `x`.
///
/// `attn_mask`, `cos` and `sin` are only consumed by full-attention
/// layers; a linear-attention layer receives just the normed input.
pub fn forward(
    &mut self,
    x: &Tensor,
    attn_mask: Option<&Tensor>,
    cos: &Tensor,
    sin: &Tensor,
) -> candle_core::Result<Tensor> {
    // Attention sub-block with residual.
    let normed = self.input_layernorm.forward(x)?;
    let mixed = match &mut self.attention {
        // Linear attention takes neither mask nor rope.
        AttentionKind::Linear(net) => net.forward(&normed)?,
        AttentionKind::Full(attn) => attn.forward(&normed, attn_mask, cos, sin)?,
    };
    let after_attn = (x + mixed)?;

    // MLP sub-block with residual.
    let mlp_in = self.post_attention_layernorm.forward(&after_attn)?;
    let mlp_out = self.mlp.forward(&mlp_in)?;
    after_attn + mlp_out
}
/// Drop this layer's cached state by delegating to whichever
/// attention flavour the layer holds.
pub fn clear_kv_cache(&mut self) {
    match self.attention {
        AttentionKind::Linear(ref mut net) => net.clear_kv_cache(),
        AttentionKind::Full(ref mut attn) => attn.clear_kv_cache(),
    }
}
/// Take a snapshot of this layer's cache state, tagged with the
/// layer's attention kind so it can be checked on restore.
pub fn snapshot_kv(&self) -> candle_core::Result<LayerKvSnapshot> {
    let snap = match &self.attention {
        AttentionKind::Linear(net) => {
            let (conv_state, recurrent_state) = net.snapshot_state()?;
            LayerKvSnapshot::Linear {
                conv_state,
                recurrent_state,
            }
        }
        AttentionKind::Full(attn) => LayerKvSnapshot::Full(attn.snapshot_kv()),
    };
    Ok(snap)
}
/// Overwrite this layer's cache state with `snap`.
///
/// The snapshot variant has to agree with the layer's attention
/// kind; any other pairing is rejected with an error (it indicates
/// the snapshot was taken from a different model).
pub fn restore_kv(&mut self, snap: &LayerKvSnapshot) -> candle_core::Result<()> {
    match &mut self.attention {
        AttentionKind::Full(attn) => {
            if let LayerKvSnapshot::Full(kv) = snap {
                return attn.restore_kv(kv.as_ref());
            }
        }
        AttentionKind::Linear(net) => {
            if let LayerKvSnapshot::Linear {
                conv_state,
                recurrent_state,
            } = snap
            {
                return net.restore_state(conv_state.as_ref(), recurrent_state.as_ref());
            }
        }
    }
    // Fell through: layer kind and snapshot kind disagree.
    candle_core::bail!(
        "restore_kv: snapshot layer kind does not match this layer's attention kind"
    )
}
}

View File

@@ -1,201 +0,0 @@
//! Qwen3-Next's `full_attention` layer.
//!
//! Standard GQA causal attention with two Qwen3-Next-specific quirks:
//!
//! 1. **Output gate (`attn_output_gate=True`).** `q_proj` is widened
//! to `num_heads * head_dim * 2`. The second half is reshaped to
//! `(B, L, num_heads * head_dim)` and fed through a sigmoid; the
//! attention output is pointwise-multiplied by this gate before
//! `o_proj`. Effectively a per-head per-position attenuation on
//! the attention output.
//!
//! 2. **`(1 + w) * x` RmsNorm** on q and k (see `rmsnorm::Qwen3_5RmsNorm`).
//! candle_nn's RmsNorm applies `w * x`; the upstream Qwen3-Next
//! checkpoints expect the `(1 + w)` form.
//!
//! Otherwise: GQA with `num_attention_heads / num_key_value_heads`
//! repeat, q_norm + k_norm on the head dim, GLM-style rotary (see
//! `rope::RotaryEmbedding`), and the usual causal mask.
use anyhow::{Context, Result};
use candle_core::{Module, Tensor};
use candle_nn::Linear;
use candle_nn::kv_cache::ConcatKvCache;
use candle_nn::var_builder::ShardedVarBuilder;
use candle_transformers::utils::repeat_kv;
use std::sync::Arc;
use super::TextConfig;
use super::rmsnorm::Qwen3_5RmsNorm;
use super::rope::RotaryEmbedding;
/// Output-gated GQA attention layer with its own KV cache.
pub struct Qwen3_5Attention {
/// Bias-free projection `hidden_size -> num_heads * head_dim * 2`
/// (query plus gate, see `load`).
q_proj: Linear,
/// Bias-free projection `hidden_size -> num_kv_heads * head_dim`.
k_proj: Linear,
/// Bias-free projection `hidden_size -> num_kv_heads * head_dim`.
v_proj: Linear,
/// Bias-free projection `num_heads * head_dim -> hidden_size`.
o_proj: Linear,
/// Norm over the head dimension of the query.
q_norm: Qwen3_5RmsNorm,
/// Norm over the head dimension of the key.
k_norm: Qwen3_5RmsNorm,
/// `cfg.num_attention_heads`.
num_heads: usize,
/// `cfg.num_key_value_heads`.
num_kv_heads: usize,
/// `num_heads / num_kv_heads` — the GQA repeat factor.
num_kv_groups: usize,
/// `cfg.head_dim`.
head_dim: usize,
/// NOTE: this is `head_dim * num_heads` (the attention output width
/// fed to `o_proj`), not `cfg.hidden_size`.
hidden_size: usize,
/// Rotary helper shared across layers; applies the caller-built
/// cos/sin tables to q and k.
rotary: Arc<RotaryEmbedding>,
/// Concatenating KV cache, constructed with concat dim 2.
kv_cache: ConcatKvCache,
}
impl Qwen3_5Attention {
/// Load the attention weights found under `vb`.
///
/// Errors when `num_attention_heads` is not a positive multiple of
/// `num_key_value_heads`, or when any weight tensor is missing or
/// has the wrong shape.
pub fn load(
    cfg: &TextConfig,
    rotary: Arc<RotaryEmbedding>,
    vb: &ShardedVarBuilder,
) -> Result<Self> {
    let (num_heads, num_kv_heads, head_dim) = (
        cfg.num_attention_heads,
        cfg.num_key_value_heads,
        cfg.head_dim,
    );
    let evenly_grouped = num_kv_heads != 0 && num_heads.is_multiple_of(num_kv_heads);
    if !evenly_grouped {
        anyhow::bail!(
            "num_attention_heads ({num_heads}) must be a positive multiple of \
             num_key_value_heads ({num_kv_heads})"
        );
    }

    let model_width = cfg.hidden_size;
    let q_width = num_heads * head_dim;
    let kv_width = num_kv_heads * head_dim;
    let eps = cfg.rms_norm_eps;

    // Struct fields are evaluated top to bottom, so the tensors are
    // fetched in the order q, k, v, o, q_norm, k_norm.
    Ok(Self {
        // Twice as wide as the query alone: the second half of each
        // head's slice is the output gate.
        q_proj: load_linear_no_bias(vb, "q_proj", model_width, q_width * 2)?,
        k_proj: load_linear_no_bias(vb, "k_proj", model_width, kv_width)?,
        v_proj: load_linear_no_bias(vb, "v_proj", model_width, kv_width)?,
        o_proj: load_linear_no_bias(vb, "o_proj", q_width, model_width)?,
        q_norm: Qwen3_5RmsNorm::load(&vb.pp("q_norm"), head_dim, eps)?,
        k_norm: Qwen3_5RmsNorm::load(&vb.pp("k_norm"), head_dim, eps)?,
        num_heads,
        num_kv_heads,
        num_kv_groups: num_heads / num_kv_heads,
        head_dim,
        hidden_size: q_width,
        rotary,
        kv_cache: ConcatKvCache::new(2),
    })
}
/// Attention forward for one chunk of tokens.
///
/// `x` must be rank 3 (it is unpacked with `dims3` as `(b, l, _)`).
/// `attn_mask`, when given, is broadcast-added to the raw scores
/// before the softmax. `cos` / `sin` are handed unchanged to the
/// shared rotary helper. The new keys/values are appended to this
/// layer's KV cache, so the call mutates `self`.
///
/// Returns the result of `o_proj` applied to the gated context.
pub fn forward(
&mut self,
x: &Tensor,
attn_mask: Option<&Tensor>,
cos: &Tensor,
sin: &Tensor,
) -> candle_core::Result<Tensor> {
let (b, l, _) = x.dims3()?;
// 1. q_proj — widened output, split into (query, gate).
// Each head's `2 * head_dim` slice is [query | gate] on the last axis.
let q_raw = self
.q_proj
.forward(x)?
.reshape((b, l, self.num_heads, self.head_dim * 2))?;
let q = q_raw.narrow(3, 0, self.head_dim)?;
let gate = q_raw.narrow(3, self.head_dim, self.head_dim)?;
// Flatten the gate's head dim back into hidden_size for the
// post-attention pointwise multiply.
// (`narrow` yields a strided view, hence the `contiguous` first.)
let gate = gate
.contiguous()?
.reshape((b, l, self.num_heads * self.head_dim))?;
// 2. q_norm + k_norm + reshape to (B, H, L, D).
let q = self.q_norm.forward(&q.contiguous()?)?;
let q = q.transpose(1, 2)?.contiguous()?; // (B, H, L, D)
let k = self
.k_proj
.forward(x)?
.reshape((b, l, self.num_kv_heads, self.head_dim))?;
let k = self.k_norm.forward(&k.contiguous()?)?;
let k = k.transpose(1, 2)?.contiguous()?;
// v gets no norm: project, split heads, move heads ahead of length.
let v = self
.v_proj
.forward(x)?
.reshape((b, l, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
// 3. RoPE on q, k (cos/sin built once per forward by the model —
// interleaved M-RoPE for image tokens, plain for text).
let (q, k) = self.rotary.apply_cos_sin(&q, &k, cos, sin)?;
// 4. KV cache.
// The returned k/v replace the per-chunk ones from here on.
let (k, v) = self.kv_cache.append(&k, &v)?;
// 5. GQA repeat (cheap shape op).
let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;
// 6. Scaled dot-product + causal mask.
let scale = 1.0_f64 / (self.head_dim as f64).sqrt();
let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
if let Some(m) = attn_mask {
scores = scores.broadcast_add(m)?;
}
let probs = candle_nn::ops::softmax_last_dim(&scores)?;
let ctx = probs.matmul(&v)?; // (B, H, L, D)
// 7. Reshape back, apply the output gate, project.
// `self.hidden_size` is `num_heads * head_dim` here (set in `load`).
let ctx = ctx
.transpose(1, 2)?
.contiguous()?
.reshape((b, l, self.hidden_size))?;
let gate_sig = candle_nn::ops::sigmoid(&gate)?;
let gated = (ctx * gate_sig)?;
self.o_proj.forward(&gated)
}
/// Reset this layer's KV cache so the next `forward` starts from an
/// empty context.
pub fn clear_kv_cache(&mut self) {
self.kv_cache.reset();
}
/// Snapshot the current KV cache as a `(k, v)` pair, or `None` when
/// either tensor is absent.
///
/// The clones are shallow. That is sound because
/// `ConcatKvCache::append` concatenates into fresh allocations
/// instead of mutating stored tensors, so the snapshot is unaffected
/// by later appends to the live cache.
pub fn snapshot_kv(&self) -> Option<(Tensor, Tensor)> {
    let keys = self.kv_cache.k();
    let values = self.kv_cache.v();
    keys.zip(values).map(|(k, v)| (k.clone(), v.clone()))
}
/// Reset the live KV cache, then — if `snap` holds a pair — seed it
/// with the snapshot's `(k, v)` tensors.
pub fn restore_kv(&mut self, snap: Option<&(Tensor, Tensor)>) -> candle_core::Result<()> {
    self.kv_cache.reset();
    match snap {
        // The concatenated tensors returned by `append` are not needed.
        Some((k, v)) => self.kv_cache.append(k, v).map(|_| ()),
        None => Ok(()),
    }
}
}
/// Fetch `<vb>/<name>/weight` with shape `(out_dim, in_dim)` and wrap
/// it in a bias-free `Linear`. A failed fetch is annotated with the
/// tensor path.
fn load_linear_no_bias(
    vb: &ShardedVarBuilder,
    name: &str,
    in_dim: usize,
    out_dim: usize,
) -> Result<Linear> {
    let scoped = vb.pp(name);
    scoped
        .get((out_dim, in_dim), "weight")
        .map(|weight| Linear::new(weight, None))
        .with_context(|| format!("load '{}/{name}/weight'", vb.prefix()))
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,53 +0,0 @@
//! SwiGLU MLP block for Qwen3-Next.
//!
//! Identical to plain Qwen3's MLP: `down(silu(gate(x)) * up(x))` with
//! no bias on any of the three projections.
use anyhow::{Context, Result};
use candle_core::{Module, Tensor};
use candle_nn::Linear;
use candle_nn::var_builder::ShardedVarBuilder;
use super::TextConfig;
/// SwiGLU feed-forward block; all three projections are bias-free.
pub struct Qwen3_5MLP {
/// `hidden_size -> intermediate_size`; its output goes through SiLU.
gate_proj: Linear,
/// `hidden_size -> intermediate_size`; multiplied with the activated gate.
up_proj: Linear,
/// `intermediate_size -> hidden_size`; produces the block output.
down_proj: Linear,
}
impl Qwen3_5MLP {
    /// Load the gate / up / down projections from `vb`, sized from
    /// `cfg.hidden_size` and `cfg.intermediate_size`.
    pub fn load(cfg: &TextConfig, vb: &ShardedVarBuilder) -> Result<Self> {
        let (model_width, ff_width) = (cfg.hidden_size, cfg.intermediate_size);
        // Fields are evaluated in order: gate, up, down.
        Ok(Self {
            gate_proj: load_linear_no_bias(vb, "gate_proj", model_width, ff_width)?,
            up_proj: load_linear_no_bias(vb, "up_proj", model_width, ff_width)?,
            down_proj: load_linear_no_bias(vb, "down_proj", ff_width, model_width)?,
        })
    }
}
impl Module for Qwen3_5MLP {
    /// Computes `down(silu(gate(x)) * up(x))`.
    fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
        let gate = self.gate_proj.forward(x)?;
        let activated = candle_nn::ops::silu(&gate)?;
        let up = self.up_proj.forward(x)?;
        let hidden = (activated * up)?;
        self.down_proj.forward(&hidden)
    }
}
/// Read the `(out_dim, in_dim)` tensor at `<vb>/<name>/weight` and
/// return it as a `Linear` with no bias. The tensor path is attached
/// to the error when the read fails.
fn load_linear_no_bias(
    vb: &ShardedVarBuilder,
    name: &str,
    in_dim: usize,
    out_dim: usize,
) -> Result<Linear> {
    let shape = (out_dim, in_dim);
    let fetched = vb.pp(name).get(shape, "weight");
    let weight = fetched.with_context(|| format!("load '{}/{name}/weight'", vb.prefix()))?;
    Ok(Linear::new(weight, None))
}

View File

@@ -1,965 +0,0 @@
//! Qwen3-Next (`model_type = "qwen3_5"`) architecture — Qwen3.6's
//! upstream architecture revision.
//!
//! ## Naming
//!
//! The model release this targets is `Qwen/Qwen3.6-*` but the
//! architecture name in HuggingFace's `config.json` is `qwen3_5`.
//! mistralrs calls the same architecture `qwen3_next`; that label
//! ages poorly the next time Qwen ship a new arch, so we key on the
//! canonical `qwen3_5` from the model's own config.
//!
//! ## Status
//!
//! **Single-GPU dense path is real**. Both attention flavours
//! (`full_attention` with the output-gated GQA causal attention and
//! `linear_attention` with the Gated DeltaNet recurrent block) are
//! implemented. The model loads from upstream safetensors via the
//! existing `load_arch_dense` dispatch and runs forward end to end.
//!
//! Numerical correctness vs the reference Python is **not yet
//! validated** — the structural code path is right, weight tensor
//! names match the upstream layout, shapes flow through cleanly, but
//! the Tbilisi probe (and any other downstream test) is the next
//! step. Likely places a bug would surface:
//! - Per-rank vs per-token-position offsets in the recurrent delta
//! rule (`linear_attn.rs`).
//! - Off-by-one in the conv state continuation across decode steps.
//! - RoPE phase mismatch from MRoPE simplification (we treat the
//! three position grids as collapsed, which is correct only for
//! text-only inference).
//!
//! ## Submodules
//!
//! - [`rmsnorm`] — `Qwen3_5RmsNorm` (`(1+w)*x` variant), the
//! `Qwen3_5RmsNormGated` used after the delta rule, and the
//! `l2norm` helper.
//! - [`rope`] — text-side rotary embedding (mrope simplified, GLM
//! rotate-half).
//! - [`mlp`] — SwiGLU MLP (gate/up/down, no bias).
//! - [`full_attn`] — `Qwen3_5Attention` with the output-gate
//! widening on `q_proj`.
//! - [`linear_attn`] — `GatedDeltaNet` recurrent delta-rule block
//! (causal depthwise Conv1d → silu → split → L2norm → per-token
//! delta rule → RMSNormGated → out_proj).
//! - [`decoder`] — `Qwen3_5DecoderLayer` dispatching to one of the
//! two attention flavours per layer index.
//!
//! ## Open work
//!
//! - **TP variant.** `harness/tp/tp_qwen3_5.rs` is the next step.
//! Sharding strategy diverges by layer type:
//! - Full-attention layers: column-parallel q/k/v (including the
//! gate half of `q_proj`) + row-parallel `o_proj`, mirroring
//! `tp_qwen3.rs`.
//! - Linear-attention layers: the recurrent state is per-V-head, so
//! V-head-dimension sharding works cleanly — split `num_v_heads`
//! across ranks (`num_v_heads / world_size` per rank), shard
//! `in_proj_qkv` / `in_proj_z` / `in_proj_b` / `in_proj_a` along
//! the V-head dim, and row-parallel `out_proj`. The `A_log` /
//! `dt_bias` per-head params shard with the heads.
//!
//! - **Chunked delta-rule prefill.** `linear_attn.rs` runs the
//! per-token recurrent path for prefill too — correct but O(L).
//! Porting `torch_chunk_gated_delta_rule` (chunk_size=64) speeds
//! prefill substantially with no surface change.
use anyhow::{Context, Result};
use candle_core::{DType, Device, IndexOp, Module, Tensor};
use candle_nn::Embedding;
use candle_nn::Linear;
use candle_nn::var_builder::ShardedVarBuilder;
use serde::Deserialize;
use std::sync::Arc;
pub mod decoder;
pub mod full_attn;
pub mod linear_attn;
pub mod mlp;
pub mod rmsnorm;
pub mod rope;
pub mod snapshot;
pub mod vision;
use decoder::Qwen3_5DecoderLayer;
use rmsnorm::Qwen3_5RmsNorm;
use rope::RotaryEmbedding;
/// `model_type` we deserialise from `config.json`. Const so the
/// dispatch in `candle.rs::load_arch_dense` can pattern-match without
/// magic strings.
///
/// This is the architecture key, not the release name: the
/// `Qwen/Qwen3.6-*` checkpoints declare `qwen3_5` in their config.
pub const MODEL_TYPE: &str = "qwen3_5";
/// Top-level shape of Qwen3-Next's `config.json`. The real
/// hyperparameters live in `text_config`; the rest is multimodal /
/// tokeniser glue we don't need for the language-model forward.
///
/// `model_type` and `text_config` are required keys. `vision_config`
/// and `image_token_id` are `#[serde(default)]`, so they deserialise
/// to `None` when the key is absent.
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
/// Always `"qwen3_5"` for this architecture. Kept on the struct
/// so the (eventual) dispatch / logging code can show it without
/// re-parsing the JSON.
pub model_type: String,
/// The text-side hyperparameters. Everything we actually need.
pub text_config: TextConfig,
/// Vision tower hyperparameters. Present on multimodal
/// checkpoints (e.g. Qwen/Qwen3.6-27B); absent on text-only
/// variants. When present, `Qwen3_5ForCausalLM::new` loads the
/// vision tower alongside the language model so vision-bearing
/// requests can splice image embeddings at `<|image_pad|>` token
/// positions.
#[serde(default)]
pub vision_config: Option<vision::VisionConfig>,
/// Token id the chat template emits per image patch group.
/// Mirrors the LM tokenizer's `<|image_pad|>` id (248056 for
/// Qwen3.6). The runtime locates these in the prompt and splices
/// in `VisionTower::forward` output. `None` for text-only models.
#[serde(default)]
pub image_token_id: Option<u32>,
}
/// Inner config (the `text_config` block). Mirrors the Qwen3 layout
/// but with the extras Qwen3-Next adds (`attn_output_gate`,
/// `layer_types`, `full_attention_interval`, larger `head_dim`).
///
/// Fields without `#[serde(default)]` are required keys. Note that the
/// `linear_*` fields default to 0 when absent.
#[derive(Debug, Clone, Deserialize)]
pub struct TextConfig {
/// Row count of the embedding (and untied `lm_head`) weight.
pub vocab_size: usize,
/// Model width: embedding dim, norm width, and the input/output
/// width of the attention and MLP projections.
pub hidden_size: usize,
/// Inner width of the SwiGLU MLP.
pub intermediate_size: usize,
/// Number of decoder layers; must match `layer_types.len()`.
pub num_hidden_layers: usize,
/// Query head count for the full-attention layers.
pub num_attention_heads: usize,
/// Key/value head count for the full-attention layers; must be
/// non-zero and divide `num_attention_heads`.
pub num_key_value_heads: usize,
/// Per-head dimension for the full-attention layers.
pub head_dim: usize,
/// NOTE(review): presumably the maximum supported context length
/// (likely consumed by the rotary setup) — confirm against `rope.rs`.
pub max_position_embeddings: usize,
/// Nested RoPE settings. Qwen3-Next puts `rope_theta` and
/// `partial_rotary_factor` inside this block rather than at the
/// top level — important because the partial rotary means only
/// `head_dim * partial_rotary_factor` dims get RoPE applied (the
/// rest pass through unchanged).
pub rope_parameters: RopeParameters,
/// Epsilon passed to every RMS norm that is loaded.
pub rms_norm_eps: f64,
/// When true, `lm_head` reuses the embedding matrix instead of
/// loading `lm_head.weight`. Defaults to false.
#[serde(default)]
pub tie_word_embeddings: bool,
/// New in Qwen3-Next: a sigmoid gate multiplied into the attention
/// output before the o_proj. The Python reference applies it
/// pointwise after softmax+matmul.
#[serde(default)]
pub attn_output_gate: bool,
/// One entry per decoder layer; values are `"full_attention"` or
/// `"linear_attention"`. Length must equal `num_hidden_layers`.
/// `full_attention_interval` is a derived hint (every 4th layer
/// by default) — `layer_types` is authoritative.
#[serde(default)]
pub layer_types: Vec<String>,
/// Hint for the layer-type pattern (defaults to 4). Kept for
/// logging / validation; the forward dispatches on `layer_types`.
#[serde(default)]
pub full_attention_interval: Option<usize>,
/// Hidden activation (`"silu"` for Qwen3-Next). Used by the MLP
/// and the linear-attention conv1d.
#[serde(default = "default_hidden_act")]
pub hidden_act: String,
// --- Gated DeltaNet (linear-attention) hyperparams -----------------
/// Per-layer linear-attention V-head count (Qwen3.6-27B: 48).
/// More V-heads than K-heads is fine — query/key get
/// `repeat_interleave`'d to match before the delta rule.
#[serde(default)]
pub linear_num_value_heads: usize,
/// Per-layer linear-attention K-head count (Qwen3.6-27B: 16).
#[serde(default)]
pub linear_num_key_heads: usize,
/// Per-head key dimension for the linear-attention path
/// (Qwen3.6-27B: 128). Separate from `head_dim` which the
/// full-attention layers use.
#[serde(default)]
pub linear_key_head_dim: usize,
/// Per-head value dimension for the linear-attention path
/// (Qwen3.6-27B: 128).
#[serde(default)]
pub linear_value_head_dim: usize,
/// Causal Conv1d kernel size used before the delta rule
/// (Qwen3.6-27B: 4).
#[serde(default)]
pub linear_conv_kernel_dim: usize,
}
/// Serde fallback for `TextConfig::hidden_act` when the key is absent.
fn default_hidden_act() -> String {
    String::from("silu")
}
/// Nested `rope_parameters` block from a Qwen3-Next `config.json`.
///
/// For text-only inference the three MRoPE position grids carry
/// identical ids, so the interleave is a no-op and plain RoPE applies.
/// For vision inputs `mrope_section` + `mrope_interleaved` drive the
/// per-axis (text/height/width) rotary used by image tokens — see
/// `rope.rs`.
///
/// Every field has a serde default, so an empty `{}` block parses.
#[derive(Debug, Clone, Deserialize)]
pub struct RopeParameters {
/// Base for the inverse-frequency computation. Qwen3.6: 10_000_000.
/// Falls back to 10_000.0 when the key is absent.
#[serde(default = "default_rope_theta")]
pub rope_theta: f64,
/// Fraction of `head_dim` that gets the rotation applied. The
/// remaining `head_dim * (1 - partial_rotary_factor)` dims pass
/// through unchanged. Qwen3.6 / Qwen3.5: 0.25.
/// Falls back to 1.0 (rotate every dim) when the key is absent.
#[serde(default = "default_partial_rotary_factor")]
pub partial_rotary_factor: f32,
/// `"default"` for the standard inv_freq RoPE; other values (e.g.
/// `"linear"`, `"dynamic"`) are upstream-supported but not yet
/// implemented here.
#[serde(default)]
pub rope_type: Option<String>,
/// MRoPE per-axis section sizes `[text, height, width]` — e.g.
/// `[11, 11, 10]` for Qwen3.6, summing to the rotary half-dim.
/// Empty for models that don't declare MRoPE (→ plain RoPE).
#[serde(default)]
pub mrope_section: Vec<usize>,
/// Whether the three MRoPE axes are interleaved per-frequency
/// (Qwen3-VL / Qwen3.6 style, `true`) rather than block-concatenated
/// (Qwen2-VL style, `false`).
#[serde(default)]
pub mrope_interleaved: bool,
}
/// Serde fallback for `RopeParameters::rope_theta` when the key is absent.
fn default_rope_theta() -> f64 {
    1.0e4
}
/// Serde fallback for `RopeParameters::partial_rotary_factor` when
/// the key is absent: rotate the full head dimension.
fn default_partial_rotary_factor() -> f32 {
    1.0_f32
}
/// Splice rows from `img` into `h` at `positions`. Stage B helper.
///
/// `h`: `(1, L, hidden)` — the LM's input embedding tensor after
/// `embed_tokens.forward`.
/// `img`: `(N_img, hidden)` — image embeddings, one row per
/// `<|image_pad|>` token in the prompt. Must already be in `h.dtype()`.
/// `positions`: indices into the `L` axis where image rows go;
/// `positions.len() == N_img`. Rows of `img` are consumed in the order
/// `positions` lists them.
///
/// Approach: group `positions` into contiguous runs (because the chat
/// template emits `<|vision_start|><|image_pad|>×N<|vision_end|>` —
/// the pad tokens for each image land in one contiguous span), then
/// `slice_assign` per run. For typical Qwen3.6 requests this is one
/// or two runs per image; `slice_assign` does one tensor copy per
/// run, which is cheap relative to the decoder forward pass.
///
/// An empty `positions` is a no-op splice and returns a clone of `h`.
/// This used to be guarded only by a `debug_assert!`, so a release
/// build would panic on the `positions[0]` index instead.
///
/// # Errors
///
/// Propagates tensor errors from `narrow` / `reshape` / `slice_assign`,
/// e.g. when `img` has fewer rows than `positions` has entries.
pub(crate) fn splice_runs(
    h: &Tensor,
    img: &Tensor,
    positions: &[u32],
) -> candle_core::Result<Tensor> {
    let Some((&first, rest)) = positions.split_first() else {
        // Nothing to splice: the embeddings pass through untouched.
        return Ok(h.clone());
    };
    let hidden = h.dim(2)?;
    let mut out = h.clone();
    // Next unread row of `img`; advanced by `apply_run`.
    let mut img_offset = 0_usize;
    let mut run_start = first as usize;
    let mut run_end_exclusive = run_start + 1;
    for &p in rest {
        let p = p as usize;
        if p == run_end_exclusive {
            // Adjacent to the current run: extend it.
            run_end_exclusive = p + 1;
        } else {
            // Gap: flush the finished run and start a new one at `p`.
            apply_run(
                &mut out,
                img,
                &mut img_offset,
                run_start,
                run_end_exclusive,
                hidden,
            )?;
            run_start = p;
            run_end_exclusive = p + 1;
        }
    }
    // Flush the trailing run.
    apply_run(
        &mut out,
        img,
        &mut img_offset,
        run_start,
        run_end_exclusive,
        hidden,
    )?;
    Ok(out)
}
/// Copy the next `run_end_exclusive - run_start` rows of `img`
/// (starting at `*img_offset`) into `out[0, run_start..run_end_exclusive, ..]`,
/// then advance `*img_offset` past the rows consumed.
fn apply_run(
    out: &mut Tensor,
    img: &Tensor,
    img_offset: &mut usize,
    run_start: usize,
    run_end_exclusive: usize,
    hidden: usize,
) -> candle_core::Result<()> {
    let rows = run_end_exclusive - run_start;
    let first_row = *img_offset;
    let patch = img.narrow(0, first_row, rows)?;
    // Add the leading batch axis so the patch matches the target slice.
    let patch = patch.reshape((1, rows, hidden))?;
    let updated = out.slice_assign(&[0..1, run_start..run_end_exclusive, 0..hidden], &patch)?;
    *out = updated;
    *img_offset = first_row + rows;
    Ok(())
}
/// Qwen3-Next base transformer (embedding + decoder stack + final
/// norm). Public so a TP variant in `harness/tp/tp_qwen3_5.rs` can
/// also build on it later — for now only `Qwen3_5ForCausalLM` is the
/// loaded handle.
pub struct Qwen3_5Model {
/// Token embedding table, shape `(vocab_size, hidden_size)`; loaded
/// from `model.language_model.embed_tokens.weight`.
embed_tokens: Embedding,
/// Decoder stack, one entry per `num_hidden_layers`, in layer order.
layers: Vec<Qwen3_5DecoderLayer>,
/// Final norm applied to the last decoder layer's output.
norm: Qwen3_5RmsNorm,
/// Shared with every full-attention layer; the model uses it to
/// build the per-forward cos/sin (interleaved M-RoPE for image
/// tokens, plain for text) once, which the layers then apply.
rotary: Arc<RotaryEmbedding>,
/// `offset + rope_delta` is the text-axis position during decode.
/// 0 for text-only; set from `get_rope_index` during a vision
/// prefill (image tokens compress the position space, so text after
/// the image resumes from a smaller counter than the sequence
/// index). Reset in `clear_kv_cache`.
rope_delta: i64,
/// Device of the var builder the model was loaded from; masks and
/// position tensors are created here.
device: Device,
/// Dtype of the var builder; the causal mask and spliced image
/// embeddings are cast to it.
dtype: DType,
}
impl Qwen3_5Model {
/// Load the text core (embedding, decoder stack, final norm) from
/// the `model.language_model.*` tensors of `vb`.
///
/// Sibling prefixes in the checkpoint — `model.visual.*`, the
/// top-level `lm_head` and `mtp.*` — are not read here.
///
/// Fails when a tensor is missing or mis-shaped, or when
/// `cfg.layer_types` does not hold exactly `num_hidden_layers`
/// entries.
pub fn load(cfg: &TextConfig, vb: &ShardedVarBuilder) -> Result<Self> {
    let (dtype, device) = (vb.dtype(), vb.device().clone());
    let text_vb = vb.pp("model.language_model");

    // Token embedding table.
    let embed_vb = text_vb.pp("embed_tokens");
    let table_shape = (cfg.vocab_size, cfg.hidden_size);
    let embed_weight = embed_vb
        .get(table_shape, "weight")
        .with_context(|| format!("load '{}/weight'", embed_vb.prefix()))?;
    let embed_tokens = Embedding::new(embed_weight, cfg.hidden_size);

    // One rotary helper, shared by reference count across layers.
    let rotary = Arc::new(RotaryEmbedding::new(dtype, cfg, &device)?);

    let n_layers = cfg.num_hidden_layers;
    if cfg.layer_types.len() != n_layers {
        anyhow::bail!(
            "config.text_config.layer_types must have num_hidden_layers ({}) entries; \
             got {}",
            cfg.num_hidden_layers,
            cfg.layer_types.len()
        );
    }

    // Decoder stack; collecting into `Result` stops at the first
    // layer that fails to load.
    let layers_vb = text_vb.pp("layers");
    let layers = (0..n_layers)
        .map(|idx| Qwen3_5DecoderLayer::load(cfg, Arc::clone(&rotary), idx, &layers_vb.pp(idx)))
        .collect::<Result<Vec<_>>>()?;

    let norm = Qwen3_5RmsNorm::load(&text_vb.pp("norm"), cfg.hidden_size, cfg.rms_norm_eps)?;

    Ok(Self {
        embed_tokens,
        layers,
        norm,
        rotary,
        rope_delta: 0,
        device,
        dtype,
    })
}
/// Borrow the token-embedding matrix. `Qwen3_5ForCausalLM::new`
/// clones it as the `lm_head` weight when `tie_word_embeddings` is set.
pub fn embed_weight(&self) -> &Tensor {
self.embed_tokens.embeddings()
}
/// Reset every layer's cache state and zero `rope_delta`, ready for
/// a new request.
pub fn clear_kv_cache(&mut self) {
    self.layers
        .iter_mut()
        .for_each(|layer| layer.clear_kv_cache());
    // A fresh request has no image-compressed position offset; the
    // next vision prefill will set one if needed.
    self.rope_delta = 0;
}
/// Snapshot all per-layer cache state together with the rope
/// position counter, as one consistent prefix snapshot (#11).
///
/// Only meaningful at a token boundary, i.e. between forward calls —
/// which is the only point a caller can invoke it anyway.
pub fn snapshot_kv_cache(&self) -> candle_core::Result<snapshot::KvCacheSnapshot> {
    let mut layers = Vec::with_capacity(self.layers.len());
    for layer in &self.layers {
        layers.push(layer.snapshot_kv()?);
    }
    let rope_delta = self.rope_delta;
    Ok(snapshot::KvCacheSnapshot { layers, rope_delta })
}
/// Load a previously captured snapshot into the live cache state.
/// The snapshot is only borrowed, so it can be restored again later.
///
/// Errors when the snapshot's layer count differs from the model's,
/// or when any layer rejects its entry.
pub fn restore_kv_cache(
    &mut self,
    snap: &snapshot::KvCacheSnapshot,
) -> candle_core::Result<()> {
    let (snap_layers, model_layers) = (snap.layers.len(), self.layers.len());
    if snap_layers != model_layers {
        candle_core::bail!(
            "restore_kv_cache: snapshot has {} layers, model has {}",
            snap_layers,
            model_layers
        );
    }
    self.layers
        .iter_mut()
        .zip(&snap.layers)
        .try_for_each(|(layer, saved)| layer.restore_kv(saved))?;
    self.rope_delta = snap.rope_delta;
    Ok(())
}
/// Additive causal mask of shape `(b, 1, tgt, tgt + offset)` in the
/// model dtype: 0 where query row `i` may attend to key column `j`
/// (`j <= i + offset`), `-inf` elsewhere.
///
/// The mask values do not depend on the batch index, so only a single
/// `(1, 1, tgt, tgt + offset)` copy is materialised and then broadcast
/// across `b`. Previously the single-copy buffer was handed to
/// `Tensor::from_slice` with a `b`-sized shape, which fails with a
/// shape mismatch for any `b > 1`. For `b == 1` the result is
/// unchanged.
fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> candle_core::Result<Tensor> {
    let total = tgt + offset;
    let mask: Vec<f32> = (0..tgt)
        .flat_map(|i| {
            (0..total).map(move |j| if j <= i + offset { 0. } else { f32::NEG_INFINITY })
        })
        .collect();
    Tensor::from_slice(&mask, (1, 1, tgt, total), &self.device)?
        .to_dtype(self.dtype)?
        .broadcast_as((b, 1, tgt, total))
}
/// Text-only forward: no image splice and no explicit position ids,
/// so rotary positions are plain, starting at `offset + rope_delta`.
/// `input` must be rank 2 (it is unpacked with `dims2` downstream).
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
self.forward_inner(input, offset, None, None, &[], None)
}
/// Forward for a vision-prefill chunk: optional image-embedding
/// splice plus explicit interleaved-M-RoPE `position_ids` (the
/// chunk's slice of the full prompt's 3D positions). Mirrors the TP
/// `TpQwen3_5Model::forward_with_positions` — used by
/// `Qwen3_5ForCausalLM::prefill_with_images_chunked`, which computes
/// the positions once over the whole prompt and slices them per
/// chunk so the position counters stay consistent across chunk
/// boundaries (an image compresses the position space, so per-chunk
/// offset arithmetic would be wrong).
///
/// The splice only happens when both `image_embeds` and
/// `image_token_id` are `Some`. This path does not update
/// `rope_delta`.
pub fn forward_with_positions(
&mut self,
input: &Tensor,
offset: usize,
position_ids: &Tensor,
image_embeds: Option<&Tensor>,
image_token_id: Option<u32>,
) -> candle_core::Result<Tensor> {
self.forward_inner(
input,
offset,
image_embeds,
image_token_id,
// No grids: explicit `position_ids` take precedence, so the
// internal M-RoPE computation that would read them is skipped.
&[],
Some(position_ids),
)
}
/// Forward with image-embedding splice. Stage B of the vision plan.
///
/// `input_ids`: `(1, L)` token ids — same shape the text-only
/// `forward` accepts (single-batch; multi-batch vision is not in
/// scope today).
/// `image_embeds`: `(N_image_tokens, hidden_size)` — concatenation
/// of every image's post-merger embedding (`VisionTower::forward`
/// output), in the same order images appear in the input. The
/// caller has already done the per-image patch-count expansion of
/// `<|image_pad|>` tokens in `input_ids`, so `N_image_tokens`
/// equals the number of `image_token_id` positions in `input_ids`.
/// `image_token_id`: the sentinel token (e.g. 248056 for Qwen3.6).
///
/// The splice replaces the LM's text-side embedding at each
/// `image_token_id` position with the corresponding row from
/// `image_embeds`. After the splice the decoder runs the interleaved
/// M-RoPE path: `grids` carries each image's post-merge LM grid
/// `(lm_gh, lm_gw)` so `get_rope_index` assigns image tokens their 2D
/// coordinates (dynamic resolution, #14).
///
/// Side effect: stores the `rope_delta` returned by `get_rope_index`
/// on the model, for the decode steps that follow.
pub fn forward_with_vision(
&mut self,
input_ids: &Tensor,
offset: usize,
image_embeds: &Tensor,
image_token_id: u32,
grids: &[(usize, usize)],
) -> candle_core::Result<Tensor> {
self.forward_inner(
input_ids,
offset,
Some(image_embeds),
Some(image_token_id),
grids,
// No explicit position ids: `forward_inner` derives the M-RoPE
// positions from `input_ids` and `grids`.
None,
)
}
/// Shared forward. Splices image embeddings at `image_token_id`
/// positions when present, then builds the rotary cos/sin, in
/// precedence order: explicit `position_ids` (interleaved M-RoPE,
/// the chunked-vision path that slices a once-computed position
/// tensor) > internal M-RoPE from `grids` (single-shot vision) >
/// plain positions at `offset + rope_delta` (text / decode).
///
/// `input` must be rank 2 `(b, l)`. The token ids are copied to the
/// host at most once per call and shared between the splice and the
/// internal M-RoPE computation; the single-shot vision path used to
/// read them back twice.
///
/// Side effects: every layer's cache state advances, and on the
/// single-shot vision path `self.rope_delta` is overwritten.
///
/// # Errors
///
/// Fails when the number of `image_token_id` positions in `input`
/// differs from the row count of `image_embeds`, when
/// `get_rope_index` rejects the prompt, or on any tensor error.
fn forward_inner(
    &mut self,
    input: &Tensor,
    offset: usize,
    image_embeds: Option<&Tensor>,
    image_token_id: Option<u32>,
    grids: &[(usize, usize)],
    position_ids: Option<&Tensor>,
) -> candle_core::Result<Tensor> {
    let (b, l) = input.dims2()?;
    let mut h = self.embed_tokens.forward(input)?;

    // Host copy of the token ids. Needed by the splice (image
    // embeddings present) and by the internal M-RoPE path (no explicit
    // positions); both require an image token id. Skipped otherwise so
    // text / decode steps never pay for the read-back.
    let needs_host_ids =
        image_token_id.is_some() && (image_embeds.is_some() || position_ids.is_none());
    let host_ids: Option<Vec<u32>> = if needs_host_ids {
        Some(input.flatten_all()?.to_vec1()?)
    } else {
        None
    };

    // Splice image embeddings at `image_token_id` positions, when
    // this forward carries any. Independent of how cos/sin is built.
    if let (Some(img), Some(tok_id), Some(ids)) =
        (image_embeds, image_token_id, host_ids.as_ref())
    {
        let n_img_tokens = img.dim(0)?;
        let positions: Vec<u32> = ids
            .iter()
            .enumerate()
            .filter(|&(_, &id)| id == tok_id)
            .map(|(idx, _)| idx as u32)
            .collect();
        if positions.len() != n_img_tokens {
            candle_core::bail!(
                "forward_with_vision: chunk has {} image-token positions but \
                 image_embeds carries {} tokens — per-image patch-count expansion \
                 / chunk slicing mismatch",
                positions.len(),
                n_img_tokens,
            );
        }
        if !positions.is_empty() {
            // Cast image_embeds to the LM's dtype, then splice the
            // contiguous `<|image_pad|>` runs in place.
            let img = img.to_dtype(self.dtype)?;
            h = splice_runs(&h, &img, &positions)?;
        }
    }

    // Build interleaved M-RoPE cos/sin so image tokens carry their
    // 2D (lm_gh × lm_gw) grid coordinates. Text / decode take the
    // plain-RoPE fast path — bit-for-bit the pre-M-RoPE behaviour
    // when `rope_delta == 0`.
    let (cos, sin) = if let Some(pos) = position_ids {
        // Pre-computed positions sliced for this chunk by the caller;
        // `rope_delta` is neither read nor written on this path.
        self.rotary.mrope_cos_sin(pos)?
    } else if let (Some(tok_id), Some(ids)) = (image_token_id, host_ids.as_ref()) {
        // Single-shot vision: compute the whole prompt's M-RoPE here
        // and stash `rope_delta` for the decode that follows.
        // (`host_ids` is always populated when this arm is reachable.)
        let (text, height, width, delta) = rope::get_rope_index(ids, tok_id, grids)
            .map_err(|e| candle_core::Error::Msg(format!("get_rope_index: {e}")))?;
        self.rope_delta = delta;
        let pos = rope::mrope_position_tensor(&text, &height, &width, &self.device)?;
        self.rotary.mrope_cos_sin(&pos)?
    } else {
        // Clamp at 0 so a negative delta cannot wrap the cast.
        let base = (offset as i64 + self.rope_delta).max(0) as usize;
        self.rotary.plain_cos_sin(base, l)?
    };

    // Causal mask only needed for L > 1 prefill; full-attention
    // layers consume it via broadcast_add. Linear-attention layers
    // ignore the mask.
    let causal = if l == 1 {
        None
    } else {
        Some(self.causal_mask(b, l, offset)?)
    };
    for layer in &mut self.layers {
        h = layer.forward(&h, causal.as_ref(), &cos, &sin)?;
    }
    self.norm.forward(&h)
}
}
/// Loaded causal-LM handle: the base transformer, the output
/// projection, and (for multimodal checkpoints) the vision tower.
pub struct Qwen3_5ForCausalLM {
/// Embedding + decoder stack + final norm.
base: Qwen3_5Model,
/// Bias-free vocabulary projection. Reuses the embedding matrix when
/// `tie_word_embeddings` is set; otherwise loaded from
/// `lm_head.weight` with shape `(vocab_size, hidden_size)`.
lm_head: Linear,
/// Vision tower (Stage A4). `None` for text-only checkpoints or
/// when the operator has opted out. When present, the harness's
/// `Job::EncodeImage` dispatch path runs `vision.forward(image)`
/// and the LM forward (Stage B) splices the result at
/// `image_token_id` positions in the input embedding stream.
vision: Option<vision::VisionTower>,
/// Mirrors `Config::image_token_id`. Cached here so the runtime
/// doesn't have to round-trip through the parsed config struct.
image_token_id: Option<u32>,
}
impl Qwen3_5ForCausalLM {
/// Build the causal LM from a parsed `Config` and a weight source.
///
/// Loads the text decoder first, then the output head — reusing the
/// token-embedding matrix when `tie_word_embeddings` is set, otherwise
/// reading `lm_head/weight` with shape `(vocab_size, hidden_size)` — and
/// finally the vision tower when `config.vision_config` is present.
///
/// # Errors
/// Fails when the decoder, the untied `lm_head` weight, or the vision
/// tower cannot be loaded from `vb`.
pub fn new(config: Config, vb: ShardedVarBuilder) -> Result<Self> {
    let text_cfg = &config.text_config;
    let base = Qwen3_5Model::load(text_cfg, &vb)?;

    let lm_head = if !text_cfg.tie_word_embeddings {
        // Untied head: its own `(vocab, hidden)` matrix under `lm_head`.
        let weight = vb
            .pp("lm_head")
            .get((text_cfg.vocab_size, text_cfg.hidden_size), "weight")
            .with_context(|| format!("load '{}/lm_head/weight'", vb.prefix()))?;
        Linear::new(weight, None)
    } else {
        // Tied head: share the embedding matrix instead of loading one.
        Linear::new(base.embed_weight().clone(), None)
    };

    // Stage A4: the presence of `vision_config` is the single switch for
    // loading the tower. Text-only checkpoints leave it unset and end up
    // with `None`; a config that declares vision but whose
    // `model.visual.*` weights cannot be read surfaces as a load error.
    let vision = match config.vision_config.clone() {
        Some(vcfg) => {
            tracing::info!(
                depth = vcfg.depth,
                hidden_size = vcfg.hidden_size,
                "loading qwen3_5 vision tower"
            );
            let tower = vision::VisionTower::load(vcfg, vb.pp("model.visual"))
                .context("load qwen3_5 vision tower (model.visual.*)")?;
            Some(tower)
        }
        None => None,
    };

    let image_token_id = config.image_token_id;
    Ok(Self {
        base,
        lm_head,
        vision,
        image_token_id,
    })
}
/// Returns `true` when this checkpoint loaded a vision tower, i.e. the
/// `vision` field is `Some`.
///
/// Per the original author's note, the HTTP layer uses this to advertise
/// vision capability in `/v1/models` and to reject image-bearing
/// requests against text-only loads with a clean 400.
pub fn has_vision(&self) -> bool {
self.vision.is_some()
}
/// Borrowed handle to the vision tower, or `None` when the checkpoint
/// was loaded without one.
///
/// Per the original author's note, the device-worker `EncodeImage` job
/// dispatches to `vision.forward(image)` through this handle.
pub fn vision(&self) -> Option<&vision::VisionTower> {
self.vision.as_ref()
}
/// `<|image_pad|>` token id copied from `Config::image_token_id` at
/// construction time; `None` when the config did not carry one.
///
/// The Stage B prompt-builder uses this to count expansion targets
/// and the LM forward uses it to locate splice positions.
pub fn image_token_id(&self) -> Option<u32> {
self.image_token_id
}
/// Text-only forward pass.
///
/// `input` is a rank-2 token-id tensor `(B, L)` and `offset` is the KV
/// position the chunk starts at. The decoder runs over the whole chunk
/// but only the final position is projected through `lm_head`, so the
/// result is last-position logits of shape `(B, 1, vocab_size)` — the
/// same contract as `qwen3::ModelForCausalLM::forward`, which lets the
/// harness's `squeeze_to_vocab` helper treat both models uniformly.
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
    let seq_len = input.dims2()?.1;
    let states = self.base.forward(input, offset)?;
    // Slice before the head so the vocab projection touches one row per
    // batch entry instead of the full sequence.
    let tail = states.i((.., seq_len - 1.., ..))?;
    tail.apply(&self.lm_head)
}
/// Stage B forward pass with image-embedding splice.
///
/// Same shape contract as `forward` (rank-2 `input`, last-position
/// logits out), but the decoder is entered through
/// `Qwen3_5Model::forward_with_vision`, which receives `image_embeds`,
/// the `image_token_id` marking the placeholder positions, and the
/// per-image `grids` so the patch embeddings can be spliced into the
/// input embeddings before the decoder stack runs.
pub fn forward_with_vision(
    &mut self,
    input: &Tensor,
    offset: usize,
    image_embeds: &Tensor,
    image_token_id: u32,
    grids: &[(usize, usize)],
) -> candle_core::Result<Tensor> {
    let seq_len = input.dims2()?.1;
    let states = self
        .base
        .forward_with_vision(input, offset, image_embeds, image_token_id, grids)?;
    // Only the final position is projected to vocabulary logits.
    let tail = states.i((.., seq_len - 1.., ..))?;
    tail.apply(&self.lm_head)
}
/// Forward pass for one vision-prefill chunk.
///
/// Unlike `forward_with_vision`, the caller supplies the M-RoPE
/// `position_ids` for this chunk explicitly, and the image splice is
/// optional: pass `image_embeds` / `image_token_id` as `Some` only for
/// chunks that contain image placeholders. Delegates to
/// `Qwen3_5Model::forward_with_positions` and returns last-position
/// logits. Used by [`Self::prefill_with_images_chunked`].
pub fn forward_with_positions(
    &mut self,
    input: &Tensor,
    offset: usize,
    position_ids: &Tensor,
    image_embeds: Option<&Tensor>,
    image_token_id: Option<u32>,
) -> candle_core::Result<Tensor> {
    let seq_len = input.dims2()?.1;
    let states = self.base.forward_with_positions(
        input,
        offset,
        position_ids,
        image_embeds,
        image_token_id,
    )?;
    // Only the final position is projected to vocabulary logits.
    let tail = states.i((.., seq_len - 1.., ..))?;
    tail.apply(&self.lm_head)
}
/// Encode every preprocessed `(C, H, W)` image once through the
/// vision tower and concatenate along the patch axis →
/// `(sum_patches, hidden)`. Done once per prefill, not per chunk.
fn encode_images_concat(&self, image_pixels: &[Tensor]) -> candle_core::Result<Tensor> {
let tower = self.vision.as_ref().ok_or_else(|| {
candle_core::Error::Msg(
"encode_images_concat: loaded without a vision tower \
(config.json::vision_config absent or weights missing)"
.into(),
)
})?;
let mut per_image = Vec::with_capacity(image_pixels.len());
for (idx, img) in image_pixels.iter().enumerate() {
let embed = tower
.forward(img)
.map_err(|e| candle_core::Error::Msg(format!("encode image[{idx}]: {e:#}")))?;
per_image.push(embed);
}
Tensor::cat(&per_image.iter().collect::<Vec<_>>(), 0)
}
/// Chunked image prefill for the single-GPU path (#18) — parity with
/// `TpQwen3_5ForCausalLM::prefill_with_images_chunked`. Encodes the
/// image(s) once, then walks the (pre-expanded) prompt in
/// `chunk_size`-token windows — exactly like the text
/// `chunked_prefill_*` paths — splicing the patch embeddings into
/// whichever chunk(s) carry `<|image_pad|>` positions. Activation
/// memory is bounded by the chunk, not the full prompt, so a long
/// vision context no longer single-shot-OOMs.
///
/// The KV cache (and GDN recurrent state) accumulate across chunks
/// via the growing offset — the same per-chunk associativity the
/// text chunked prefill and prefix cache (#11/#23) rely on. Only the
/// final chunk's last-position logits are returned; intermediate
/// chunks just populate the cache. The caller is responsible for
/// clearing the cache first.
///
/// `base_offset` is the KV position the prefill starts at (0 for a
/// fresh request). `image_pixels` are device-resident `(C, H, W)`
/// tensors; grids and the interleaved-M-RoPE position ids are
/// recomputed here so an image's position compression is consistent
/// across chunk boundaries.
///
/// A `chunk_size` of 0 is treated as 1.
///
/// NOTE(review): the M-RoPE position ids are derived from `tokens` alone
/// and their counter starts at 0, while `base_offset` is only added to
/// the KV offset passed to each chunk. For `base_offset > 0` the rotary
/// positions therefore do not include the offset — confirm that callers
/// either pass 0 or intend this.
///
/// # Errors
/// Fails when `image_pixels` or `tokens` is empty, when no vision tower
/// is loaded, when an image tensor is not rank 3, when the image runs in
/// `tokens` do not match the computed grids, or when any chunk's forward
/// pass fails.
pub fn prefill_with_images_chunked(
&mut self,
tokens: &[u32],
base_offset: usize,
image_pixels: &[Tensor],
image_token_id: u32,
chunk_size: usize,
) -> candle_core::Result<Tensor> {
if image_pixels.is_empty() {
candle_core::bail!("prefill_with_images_chunked: called with zero images");
}
if tokens.is_empty() {
candle_core::bail!("prefill_with_images_chunked: empty prompt");
}
// Guard against a zero chunk size, which would never advance `start`.
let chunk_size = chunk_size.max(1);
let device = self.base.device.clone();
let image_embeds = self.encode_images_concat(image_pixels)?;
// Each image's LM grid (lm_gh, lm_gw) = (h/factor, w/factor),
// factor = patch×merge — recomputed from the pixel tensors (#14
// dynamic resolution).
let factor = self
.vision
.as_ref()
.map(|v| {
let c = v.config();
c.patch_size * c.spatial_merge_size
})
.ok_or_else(|| {
candle_core::Error::Msg(
"prefill_with_images_chunked: loaded without a vision tower".into(),
)
})?;
// Integer division: any remainder of `h` or `w` modulo `factor` is
// dropped here.
let grids: Vec<(usize, usize)> = image_pixels
.iter()
.map(|t| {
let (_, h, w) = t.dims3()?;
Ok::<(usize, usize), candle_core::Error>((h / factor, w / factor))
})
.collect::<candle_core::Result<Vec<_>>>()?;
// Interleaved-M-RoPE 3D positions for the whole prompt, computed
// once and sliced per chunk so image tokens get their grid
// coordinates and text after an image resumes from the
// compressed counter. `rope_delta` is stashed on the base model
// for the decode that follows this prefill.
let (text, height, width, delta) = rope::get_rope_index(tokens, image_token_id, &grids)
.map_err(|e| candle_core::Error::Msg(format!("get_rope_index: {e}")))?;
self.base.rope_delta = delta;
let full_pos = rope::mrope_position_tensor(&text, &height, &width, &device)?;
let mut last_logits: Option<Tensor> = None;
// Rows of `image_embeds` already spliced by earlier chunks. The
// `<|image_pad|>` run is contiguous, so chunks consume embedding
// rows in order.
let mut img_off = 0usize;
let mut start = 0usize;
while start < tokens.len() {
// Half-open window [start, end), clamped to the prompt length.
let end = (start + chunk_size).min(tokens.len());
let chunk = &tokens[start..end];
let input = Tensor::new(chunk, &device)?.unsqueeze(0)?;
// Columns of the (3, seq) position tensor that belong to this chunk.
let pos_slice = full_pos.narrow(1, start, end - start)?;
// Number of image placeholders in this chunk = embedding rows to splice.
let n_here = chunk.iter().filter(|&&t| t == image_token_id).count();
let logits = if n_here == 0 {
self.forward_with_positions(&input, base_offset + start, &pos_slice, None, None)?
} else {
// Splice the next `n_here` patch rows at this chunk's
// local image-pad positions.
let rows = image_embeds.narrow(0, img_off, n_here)?;
img_off += n_here;
self.forward_with_positions(
&input,
base_offset + start,
&pos_slice,
Some(&rows),
Some(image_token_id),
)?
};
// Keep only the most recent chunk's logits; earlier ones are dropped.
last_logits = Some(logits);
start = end;
}
last_logits
.ok_or_else(|| candle_core::Error::Msg("prefill_with_images_chunked: no chunks".into()))
}
/// Reset the decoder's cached state by delegating to
/// `Qwen3_5Model::clear_kv_cache`.
pub fn clear_kv_cache(&mut self) {
self.base.clear_kv_cache();
}
/// Capture the decoder's cache state as a `KvCacheSnapshot`.
/// Pure delegation — see [`Qwen3_5Model::snapshot_kv_cache`].
pub fn snapshot_kv_cache(&self) -> candle_core::Result<snapshot::KvCacheSnapshot> {
self.base.snapshot_kv_cache()
}
/// Restore the decoder's cache state from a previously taken snapshot.
/// Pure delegation — see [`Qwen3_5Model::restore_kv_cache`].
pub fn restore_kv_cache(
&mut self,
snap: &snapshot::KvCacheSnapshot,
) -> candle_core::Result<()> {
self.base.restore_kv_cache(snap)
}
}
#[cfg(test)]
mod tests {
use super::*;
/// Confirms we can deserialise the real upstream config shape.
/// Sample taken from `Qwen/Qwen3.6-27B/config.json`, trimmed to
/// the fields the architecture cares about. Note `rope_theta` and
/// `partial_rotary_factor` are nested under `rope_parameters` —
/// Qwen3-Next does NOT have a top-level `rope_theta`.
///
/// No `vision_config` block is included, so this exercises the text
/// side of `Config` only.
#[test]
fn config_deserialises_the_real_qwen3_6_shape() {
let raw = r#"{
"architectures": ["Qwen3_5ForConditionalGeneration"],
"model_type": "qwen3_5",
"image_token_id": 248056,
"language_model_only": false,
"text_config": {
"vocab_size": 248064,
"hidden_size": 5120,
"intermediate_size": 17408,
"num_hidden_layers": 64,
"num_attention_heads": 64,
"num_key_value_heads": 8,
"head_dim": 256,
"max_position_embeddings": 32768,
"rope_parameters": {
"mrope_interleaved": true,
"mrope_section": [11, 11, 10],
"partial_rotary_factor": 0.25,
"rope_theta": 10000000,
"rope_type": "default"
},
"rms_norm_eps": 1e-6,
"tie_word_embeddings": false,
"attn_output_gate": true,
"full_attention_interval": 4,
"layer_types": [
"linear_attention", "linear_attention",
"linear_attention", "full_attention"
]
}
}"#;
let cfg: Config = serde_json::from_str(raw).expect("parse Qwen3.6 config");
// Top-level and text-config scalars.
assert_eq!(cfg.model_type, "qwen3_5");
assert_eq!(cfg.text_config.hidden_size, 5120);
assert_eq!(cfg.text_config.head_dim, 256);
assert!(cfg.text_config.attn_output_gate);
assert_eq!(cfg.text_config.full_attention_interval, Some(4));
assert_eq!(cfg.text_config.layer_types.len(), 4);
// `rope_theta` is an integer literal in the JSON but must land as a float.
assert_eq!(cfg.text_config.rope_parameters.rope_theta, 10_000_000.0);
assert!((cfg.text_config.rope_parameters.partial_rotary_factor - 0.25).abs() < 1e-6);
}
/// Contiguous case: a `(1, 5, 2)` embedding tensor has rows 1 and 2
/// overwritten by the two rows of a `(2, 2)` image-embedding tensor, in
/// the order the positions are given; every other row is untouched.
#[test]
fn splice_runs_replaces_at_contiguous_positions() {
    use candle_core::{DType, Device};
    let cpu = Device::Cpu;
    // Row `p` of the text embeddings is `[10*(p+1), 10*(p+1) + 1]`, so a
    // replaced row is obvious in the flattened output.
    let text_rows: Vec<f32> = vec![10., 11., 20., 21., 30., 31., 40., 41., 50., 51.];
    let h = Tensor::from_vec(text_rows, (1, 5, 2), &cpu).unwrap();
    // One image emitting two patch tokens.
    let patch_rows: Vec<f32> = vec![-1., -2., -3., -4.];
    let img = Tensor::from_vec(patch_rows, (2, 2), &cpu).unwrap();

    let spliced = splice_runs(&h, &img, &[1, 2]).unwrap();

    let got: Vec<f32> = spliced.flatten_all().unwrap().to_vec1().unwrap();
    let want: Vec<f32> = vec![10., 11., -1., -2., -3., -4., 40., 41., 50., 51.];
    assert_eq!(got, want);
    // Keeps the `DType` import referenced.
    let _ = DType::F32;
}
/// Non-contiguous case: positions 1 and 3 are separate single-token
/// runs (two images, one patch each). Image row 0 must land at
/// position 1 and image row 1 at position 3.
#[test]
fn splice_runs_handles_non_contiguous_runs() {
    use candle_core::Device;
    let cpu = Device::Cpu;
    let text_rows: Vec<f32> = vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.];
    let h = Tensor::from_vec(text_rows, (1, 5, 2), &cpu).unwrap();
    let patch_rows: Vec<f32> = vec![-1., -2., -3., -4.];
    let img = Tensor::from_vec(patch_rows, (2, 2), &cpu).unwrap();

    let spliced = splice_runs(&h, &img, &[1, 3]).unwrap();

    let got: Vec<f32> = spliced.flatten_all().unwrap().to_vec1().unwrap();
    let want: Vec<f32> = vec![1., 1., -1., -2., 3., 3., -3., -4., 5., 5.];
    assert_eq!(got, want);
}
}

View File

@@ -1,161 +0,0 @@
//! Norm primitives for Qwen3-Next.
//!
//! Two reasons we can't reuse `candle_nn::RmsNorm` directly:
//!
//! 1. **`(1.0 + weight)` scaling.** Qwen3-Next's `Qwen3_5RMSNorm`
//! initialises `weight` to zeros and applies `(1.0 + weight)` to
//! the normalised vector. `candle_nn::RmsNorm` applies `weight`
//! directly. The two are equivalent only when the operator has
//! pre-shifted the weights — the upstream checkpoints have not. See
//! `huggingface/transformers#29402` for the upstream PR that
//! introduced the `(1 + w)` form to recover from the zero-init.
//!
//! 2. **Gated variant.** The linear-attention layer post-normalises
//! its output by an RMSNorm *gated* with a per-element SiLU on
//! a sibling `z` projection — fused for numerical reasons (the
//! norm's float32 promotion has to happen before the SiLU
//! multiply). Not a single existing candle op.
//!
//! Both ops accept inputs in any compute dtype; promotion to f32 for
//! the variance calculation matches the Python reference.
use anyhow::{Context, Result};
use candle_core::{D, Module, Tensor};
use candle_nn::var_builder::ShardedVarBuilder;
/// L2-normalise `x` along its last dimension: `x * rsqrt(sum(x*x) + eps)`.
///
/// Mirrors the `l2norm` helper in
/// `transformers/models/qwen3_5/modeling_qwen3_5.py`. The
/// linear-attention path applies it to Q and K before the delta rule
/// when `use_qk_l2norm_in_kernel=True` (which Qwen3-Next always sets).
///
/// The arithmetic runs in f32 regardless of the input dtype; the result
/// is cast back to `x`'s dtype. `eps` sits inside the square root, so an
/// all-zero vector yields zeros rather than NaN.
pub fn l2norm(x: &Tensor, eps: f32) -> candle_core::Result<Tensor> {
    let out_dtype = x.dtype();
    let wide = x.to_dtype(candle_core::DType::F32)?;
    let sum_sq = wide.sqr()?.sum_keepdim(D::Minus1)?;
    let inv_norm = (sum_sq + f64::from(eps))?.sqrt()?.recip()?;
    wide.broadcast_mul(&inv_norm)?.to_dtype(out_dtype)
}
/// Qwen3-Next's RMSNorm. Stores the raw weight tensor; forward applies
/// `(1.0 + weight) * x_normed`.
pub struct Qwen3_5RmsNorm {
/// Raw per-channel weight as loaded (length `size`); the `+ 1.0` shift
/// is applied at forward time, not stored.
weight: Tensor,
/// Epsilon added to the mean square before the reciprocal square root.
eps: f32,
/// Length of the normalised (last) dimension.
size: usize,
}
impl Qwen3_5RmsNorm {
    /// Read the norm's `weight` vector (length `size`) from `vb`, which
    /// must already be `.pp(...)`-ed to the norm's tensor prefix.
    ///
    /// `eps` is narrowed to f32 for storage.
    ///
    /// # Errors
    /// Fails, with the builder prefix in the context message, when the
    /// weight cannot be loaded.
    pub fn load(vb: &ShardedVarBuilder, size: usize, eps: f64) -> Result<Self> {
        let eps = eps as f32;
        vb.get(size, "weight")
            .with_context(|| format!("load '{}/weight'", vb.prefix()))
            .map(|weight| Self { weight, eps, size })
    }

    /// Length of the dimension this norm was loaded for.
    pub fn size(&self) -> usize {
        self.size
    }
}
impl Module for Qwen3_5RmsNorm {
    /// RMS-normalise `x` over its last dimension and scale by
    /// `(1.0 + weight)`.
    ///
    /// Both the variance and the scale are computed in f32; the result is
    /// cast back to `x`'s dtype.
    fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
        let out_dtype = x.dtype();
        let wide = x.to_dtype(candle_core::DType::F32)?;
        let mean_sq = wide.sqr()?.mean_keepdim(D::Minus1)?;
        let inv_rms = (mean_sq + f64::from(self.eps))?.sqrt()?.recip()?;
        let normed = wide.broadcast_mul(&inv_rms)?;
        // The weight is promoted to f32 *before* the `+ 1.0` shift so the
        // shift is not performed in the (possibly half-precision)
        // storage dtype.
        let scale = (self.weight.to_dtype(candle_core::DType::F32)? + 1.0_f64)?;
        normed.broadcast_mul(&scale)?.to_dtype(out_dtype)
    }
}
/// Gated RMSNorm used at the tail of `Qwen3_5GatedDeltaNet`. Equivalent
/// to `x_normed * weight * silu(gate)` but with both the norm and the
/// gate evaluated in float32 to avoid mid-pipeline underflow.
///
/// Note: unlike `Qwen3_5RmsNorm`, this variant matches the Python
/// reference's `Qwen3_5RMSNormGated` which uses `weight` directly (not
/// `1.0 + weight`).
pub struct Qwen3_5RmsNormGated {
/// Per-channel weight (length `size`), multiplied in as-is.
weight: Tensor,
/// Epsilon added to the mean square before the reciprocal square root.
eps: f32,
/// Length of the normalised (last) dimension.
size: usize,
}
impl Qwen3_5RmsNormGated {
    /// Read the norm's `weight` vector (length `size`) from `vb`, which
    /// is expected to be prefixed to the norm's tensor path already.
    ///
    /// # Errors
    /// Fails, with the builder prefix in the context message, when the
    /// weight cannot be loaded.
    pub fn load(vb: &ShardedVarBuilder, size: usize, eps: f64) -> Result<Self> {
        let eps = eps as f32;
        vb.get(size, "weight")
            .with_context(|| format!("load '{}/weight'", vb.prefix()))
            .map(|weight| Self { weight, eps, size })
    }

    /// Test-only constructor that wraps an existing weight tensor, taking
    /// `size` from its first dimension, so a layer can be built without
    /// a VarBuilder.
    #[cfg(test)]
    pub(crate) fn from_weight(weight: Tensor, eps: f64) -> Self {
        let size = weight.dims()[0];
        let eps = eps as f32;
        Self { weight, eps, size }
    }

    /// Length of the dimension this norm was built for.
    pub fn size(&self) -> usize {
        self.size
    }

    /// Compute `rms_norm(x) * weight * silu(gate)`.
    ///
    /// `x` and `gate` share the same last-dim shape (`size`). The norm,
    /// the weight multiply and the SiLU all run in f32; only the final
    /// product is cast back to `x`'s dtype.
    pub fn forward(&self, x: &Tensor, gate: &Tensor) -> candle_core::Result<Tensor> {
        let out_dtype = x.dtype();
        let wide = x.to_dtype(candle_core::DType::F32)?;
        let mean_sq = wide.sqr()?.mean_keepdim(D::Minus1)?;
        let inv_rms = (mean_sq + f64::from(self.eps))?.sqrt()?.recip()?;
        let normed = wide.broadcast_mul(&inv_rms)?;
        // Unlike `Qwen3_5RmsNorm`, the weight is used directly (no `+ 1`).
        let weighted = normed.broadcast_mul(&self.weight.to_dtype(candle_core::DType::F32)?)?;
        // Gate: SiLU evaluated on the f32-promoted sibling projection.
        let gate_wide = gate.to_dtype(candle_core::DType::F32)?;
        let activated = candle_nn::ops::silu(&gate_wide)?;
        (weighted * activated)?.to_dtype(out_dtype)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use candle_core::Device;
/// 3-4-5 triangle: `[3, 4]` has norm 5, so the unit vector is
/// `[0.6, 0.8]` (the epsilon is far below the tolerance).
#[test]
fn l2norm_matches_hand_calc() {
    let input = Tensor::new(&[3.0_f32, 4.0_f32], &Device::Cpu).unwrap();
    let unit: Vec<f32> = l2norm(&input, 1e-6).unwrap().to_vec1().unwrap();
    let expected = [0.6_f32, 0.8_f32];
    assert!((unit[0] - expected[0]).abs() < 1e-4);
    assert!((unit[1] - expected[1]).abs() < 1e-4);
}
/// The epsilon inside the square root must keep an all-zero input from
/// producing NaN or infinity.
#[test]
fn l2norm_zero_vector_is_safe_via_epsilon() {
    let zeros = Tensor::new(&[0.0_f32, 0.0_f32], &Device::Cpu).unwrap();
    let normed: Vec<f32> = l2norm(&zeros, 1e-6).unwrap().to_vec1().unwrap();
    assert!(normed.iter().all(|x| x.is_finite()));
}
}

View File

@@ -1,579 +0,0 @@
//! Rotary position embedding for Qwen3-Next's full-attention layers.
//!
//! Qwen3.6 declares **interleaved M-RoPE** (multimodal RoPE): the
//! rotary half-dimension is split across three position axes —
//! `[text, height, width]` per `mrope_section` (`[11,11,10]` for
//! Qwen3.6) — interleaved per-frequency. For **text** every token's
//! three axes carry the same position id, so the interleave is a no-op
//! and this reduces exactly to plain RoPE. For **image** tokens the
//! height/width axes carry the patch's 2D grid coordinates, which is
//! how the model reads the 14×14 patch layout (without it, all patches
//! share a height position and the image reads as vertical repetition).
//!
//! Two cos/sin builders feed a shared [`RotaryEmbedding::apply_cos_sin`]:
//! - [`RotaryEmbedding::plain_cos_sin`] narrows the precomputed tables
//! at a scalar position — the text / decode fast path.
//! - [`RotaryEmbedding::mrope_cos_sin`] builds per-token cos/sin from a
//! `(3, seq)` position-id tensor, blending the three axes' frequencies
//! at the interleave index sets — the vision-prefill path.
//!
//! Rotation flavour: **GLM-style** rotate-half (candle's `rope_slow`),
//! matching the reference Python's `apply_rotary_pos_emb` + `rotate_half`.
use anyhow::Result;
use candle_core::{DType, Device, IndexOp, Tensor};
use super::TextConfig;
/// Rotary-embedding state for one model: precomputed plain-RoPE tables
/// plus the pieces needed to build M-RoPE cos/sin on demand.
#[derive(Debug, Clone)]
pub struct RotaryEmbedding {
/// Precomputed `sin(pos * inv_freq)` table, shape
/// `(max_position_embeddings, rotary_dim/2)`, stored in `dtype`.
sin: Tensor,
/// Precomputed `cos(pos * inv_freq)` table, same shape and dtype as `sin`.
cos: Tensor,
/// Inverse frequencies, shape `(1, rotary_dim/2)`. Retained (beyond
/// the precomputed `sin`/`cos` tables) so [`Self::mrope_cos_sin`] can
/// build cos/sin from arbitrary per-axis position ids.
inv_freq: Tensor,
/// Per-axis column masks over the rotary half-dim, shape `(1, half)`,
/// f32 0/1. `mask_t + mask_h + mask_w` partitions the columns; a
/// column belongs to exactly one axis. For a non-MRoPE config
/// `mask_t` is all-ones and the others all-zero (→ plain RoPE).
mask_t: Tensor,
mask_h: Tensor,
mask_w: Tensor,
/// Dtype the cos/sin tensors are cast to before being handed out.
dtype: DType,
/// Number of dims at the head's leading edge that the rotation
/// covers. The remaining `head_dim - rotary_dim` dims pass through
/// unchanged. Qwen3-Next uses `partial_rotary_factor = 0.25`, so
/// for `head_dim = 256` only 64 dims rotate.
rotary_dim: usize,
/// Full per-head dimension, copied from `TextConfig::head_dim`.
head_dim: usize,
}
/// Build the per-axis 0/1 column masks over the rotary half-dim from
/// `mrope_section`. Returns `(temporal, height, width)`, each of length
/// `half`.
///
/// Temporal is the complement of height ∪ width, so the three masks
/// always partition `0..half`. Any column not claimed by the height or
/// width section — including a shortfall when the sections sum to less
/// than `half` — stays on the temporal axis, and the result is
/// all-temporal (plain RoPE) when `section` does not have exactly three
/// entries.
///
/// * `interleaved = true` (Qwen3-VL layout): height takes columns
///   1, 4, 7, … and width takes 2, 5, 8, …, each limited to its section
///   length and to `half`.
/// * `interleaved = false` (Qwen2-VL layout): contiguous blocks
///   `[text | height | width]` sized by the three section entries and
///   clamped to `half`.
fn mrope_masks(
    half: usize,
    section: &[usize],
    interleaved: bool,
) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
    let mut mh = vec![0f32; half];
    let mut mw = vec![0f32; half];
    match (section, interleaved) {
        (&[_, height_len, width_len], true) => {
            // Every third column starting at 1 (height) / 2 (width); the
            // two index sets are disjoint by construction.
            for m in mh.iter_mut().skip(1).step_by(3).take(height_len) {
                *m = 1.0;
            }
            for m in mw.iter_mut().skip(2).step_by(3).take(width_len) {
                *m = 1.0;
            }
        }
        (&[text_len, height_len, width_len], false) => {
            // Saturating adds keep absurd section values from wrapping;
            // the clamps keep every bound inside `0..=half`.
            let h_start = text_len.min(half);
            let h_end = text_len.saturating_add(height_len).min(half);
            // Width covers only its own section: columns past
            // `h_end + width_len` fall through to the temporal mask
            // instead of being swallowed by the width axis.
            let w_end = h_end.saturating_add(width_len).min(half);
            mh[h_start..h_end].fill(1.0);
            mw[h_end..w_end].fill(1.0);
        }
        // No usable section: everything stays temporal.
        _ => {}
    }
    let mt: Vec<f32> = mh
        .iter()
        .zip(&mw)
        .map(|(&h, &w)| if h == 0.0 && w == 0.0 { 1.0 } else { 0.0 })
        .collect();
    (mt, mh, mw)
}
impl RotaryEmbedding {
/// Precompute the rotary tables and M-RoPE axis masks for `cfg`.
///
/// `rotary_dim` is `head_dim * partial_rotary_factor`, truncated to an
/// integer; it must be even and non-zero. The plain-RoPE `sin`/`cos`
/// tables cover positions `0..max_position_embeddings` and are stored
/// in `dtype`; `inv_freq` and the masks stay in f32 on `dev`.
///
/// # Errors
/// Fails when `rotary_dim` is odd or zero, or when any tensor
/// construction on `dev` fails.
pub fn new(dtype: DType, cfg: &TextConfig, dev: &Device) -> Result<Self> {
    let rope = &cfg.rope_parameters;
    let head_dim = cfg.head_dim;
    let rotary_dim = (head_dim as f32 * rope.partial_rotary_factor) as usize;
    if !rotary_dim.is_multiple_of(2) {
        anyhow::bail!(
            "rotary_dim = head_dim * partial_rotary_factor = {head_dim} * {} = {rotary_dim} \
             must be even (cos/sin are paired)",
            rope.partial_rotary_factor
        );
    }
    if rotary_dim == 0 {
        anyhow::bail!(
            "rotary_dim = 0 (partial_rotary_factor = {} too small)",
            rope.partial_rotary_factor
        );
    }

    // One inverse frequency per rotated pair: theta^(-i / rotary_dim)
    // for even `i`. The power is taken in f64 and narrowed to f32.
    let mut inv_freq: Vec<f32> = Vec::with_capacity(rotary_dim / 2);
    for i in (0..rotary_dim).step_by(2) {
        let exponent = i as f64 / rotary_dim as f64;
        inv_freq.push(1f32 / rope.rope_theta.powf(exponent) as f32);
    }
    let half = inv_freq.len();
    let inv_freq = Tensor::from_vec(inv_freq, (1, half), dev)?.to_dtype(DType::F32)?;

    // Outer product of positions (max_seq_len, 1) and inverse
    // frequencies (1, half) → angle table (max_seq_len, half).
    let max_seq_len = cfg.max_position_embeddings;
    let positions = Tensor::arange(0u32, max_seq_len as u32, dev)?
        .to_dtype(DType::F32)?
        .reshape((max_seq_len, 1))?;
    let freqs = positions.matmul(&inv_freq)?;

    // MRoPE axis masks. `sum(mrope_section)` should equal `half`; a
    // mismatch is tolerated rather than rejected here.
    let (mt, mh, mw) = mrope_masks(half, &rope.mrope_section, rope.mrope_interleaved);
    let mask_t = Tensor::from_vec(mt, (1, half), dev)?;
    let mask_h = Tensor::from_vec(mh, (1, half), dev)?;
    let mask_w = Tensor::from_vec(mw, (1, half), dev)?;

    let sin = freqs.sin()?.to_dtype(dtype)?;
    let cos = freqs.cos()?.to_dtype(dtype)?;
    Ok(Self {
        sin,
        cos,
        inv_freq,
        mask_t,
        mask_h,
        mask_w,
        dtype,
        rotary_dim,
        head_dim,
    })
}
/// Plain-RoPE cos/sin for `seq_len` consecutive positions starting at
/// `pos`, taken as row slices of the precomputed tables. This is the
/// text / decode path, where all three M-RoPE axes are equal.
///
/// Returns `(cos, sin)`, each of shape `(seq_len, rotary_dim/2)`.
///
/// # Errors
/// Fails when `pos + seq_len` runs past the precomputed table.
pub fn plain_cos_sin(
    &self,
    pos: usize,
    seq_len: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
    let rows = |table: &Tensor| table.narrow(0, pos, seq_len);
    let cos = rows(&self.cos)?;
    let sin = rows(&self.sin)?;
    Ok((cos, sin))
}
/// M-RoPE cos/sin from explicit per-token position ids.
///
/// `position_ids` has shape `(3, seq_len)` with rows
/// (text, height, width). Each axis is turned into its own angle table
/// (`pos ⊗ inv_freq`), then the three tables are combined through the
/// 0/1 axis masks so every frequency column is driven by exactly one
/// axis. With all three rows equal this is the same computation as
/// [`Self::plain_cos_sin`].
///
/// Returns `(cos, sin)`, each `(seq_len, rotary_dim/2)` in `self.dtype`.
/// The three-axis requirement is only a `debug_assert`.
pub fn mrope_cos_sin(&self, position_ids: &Tensor) -> candle_core::Result<(Tensor, Tensor)> {
    let pos = position_ids.to_dtype(DType::F32)?;
    let (axes, seq_len) = pos.dims2()?;
    debug_assert_eq!(axes, 3, "mrope position_ids must have 3 axes");

    // Angle table for one axis: (seq, 1) @ (1, half) → (seq, half).
    let axis_freqs = |axis: usize| -> candle_core::Result<Tensor> {
        pos.i(axis)?.reshape((seq_len, 1))?.matmul(&self.inv_freq)
    };
    let text_freqs = axis_freqs(0)?;
    let height_freqs = axis_freqs(1)?;
    let width_freqs = axis_freqs(2)?;

    // The masks partition the columns, so the sum of the masked tables
    // selects one axis per column rather than mixing them.
    let from_text = text_freqs.broadcast_mul(&self.mask_t)?;
    let from_height = height_freqs.broadcast_mul(&self.mask_h)?;
    let from_width = width_freqs.broadcast_mul(&self.mask_w)?;
    let blended = from_text.add(&from_height)?.add(&from_width)?;

    let cos = blended.cos()?.to_dtype(self.dtype)?;
    let sin = blended.sin()?.to_dtype(self.dtype)?;
    Ok((cos, sin))
}
/// Rotate `q` and `k` (each `(B, H, L, head_dim)`) with the supplied
/// `cos`/`sin` of shape `(L, rotary_dim/2)`.
///
/// When `rotary_dim < head_dim` only the leading `rotary_dim` features
/// of each head are rotated; the remaining tail is concatenated back
/// unchanged. Returns `(q_embed, k_embed)` with the input shapes.
pub fn apply_cos_sin(
    &self,
    q: &Tensor,
    k: &Tensor,
    cos: &Tensor,
    sin: &Tensor,
) -> candle_core::Result<(Tensor, Tensor)> {
    let (_, _, _seq_len, head_dim_in) = q.dims4()?;
    debug_assert_eq!(head_dim_in, self.head_dim, "q head_dim mismatch");

    // The same transformation is applied to q and k independently.
    let rotate = |t: &Tensor| -> candle_core::Result<Tensor> {
        if self.rotary_dim == self.head_dim {
            // Full rotation: no tail to preserve.
            return candle_nn::rotary_emb::rope_slow(&t.contiguous()?, cos, sin);
        }
        // Partial rotation: narrow → rotate → cat the untouched tail.
        let tail_len = self.head_dim - self.rotary_dim;
        let lead = t
            .narrow(candle_core::D::Minus1, 0, self.rotary_dim)?
            .contiguous()?;
        let tail = t.narrow(candle_core::D::Minus1, self.rotary_dim, tail_len)?;
        let rotated = candle_nn::rotary_emb::rope_slow(&lead, cos, sin)?;
        let tail = tail.contiguous()?;
        let joined = Tensor::cat(&[&rotated, &tail], candle_core::D::Minus1)?;
        Ok(joined)
    };

    let q_embed = rotate(q)?;
    let k_embed = rotate(k)?;
    Ok((q_embed, k_embed))
}
}
/// Runtime kill switch for interleaved M-RoPE on image tokens.
///
/// Reads the `NEURON_MROPE` environment variable on every call. M-RoPE
/// is **on** unless the variable, after trimming whitespace and
/// ASCII-lowercasing, equals `0`, `false`, `no` or `off`. An unset or
/// unreadable variable leaves it on. When off, [`get_rope_index`] falls
/// back to plain sequential positions for image tokens (the pre-M-RoPE
/// behaviour).
///
/// Default-on rationale, per the original author's note: Qwen3.6 was
/// trained with interleaved M-RoPE and this implementation was checked
/// column-for-column against the HF `apply_interleaved_mrope` /
/// `get_rope_index` reference.
///
/// For context, the position ids produced on the enabled path (see
/// [`compute_mrope_index`]) are:
/// - text tokens advance a single running counter `c`, all three axes
///   equal (`[c, c, c]`);
/// - each contiguous run of `image_token_id` is one image; its tokens get
///   `[base, base + h, base + w]` in row-major order, where `base` is the
///   counter at the run's start, and afterwards the counter resumes from
///   `base + max(grid_h, grid_w)`;
/// - `rope_delta = final_counter - seq_len`, to be added to a plain
///   decode offset so text resumes from the counter after the
///   (position-compressed) image blocks.
pub(crate) fn mrope_enabled() -> bool {
    match std::env::var("NEURON_MROPE") {
        Ok(raw) => {
            let normalised = raw.trim().to_ascii_lowercase();
            !matches!(normalised.as_str(), "0" | "false" | "no" | "off")
        }
        // Unset (or not valid Unicode): keep M-RoPE enabled.
        Err(_) => true,
    }
}
/// Position ids for the forward path, gated by [`mrope_enabled`].
///
/// With M-RoPE enabled this is [`compute_mrope_index`]. With it
/// disabled, all three axes are the identity sequence `0..len` and
/// `rope_delta` is 0, so `mrope_cos_sin` reduces to plain RoPE and the
/// rest of the forward path needs no special-casing. On the disabled
/// path `image_token_id` and `grids` are not consulted, so no
/// run-length/grid validation takes place.
pub(crate) fn get_rope_index(
    input_ids: &[u32],
    image_token_id: u32,
    grids: &[(usize, usize)],
) -> Result<MRopeIndex> {
    if mrope_enabled() {
        return compute_mrope_index(input_ids, image_token_id, grids);
    }
    let identity: Vec<i64> = (0..input_ids.len() as i64).collect();
    Ok((identity.clone(), identity.clone(), identity, 0))
}
/// The real interleaved-M-RoPE position-id computation (always active in
/// unit tests; gated behind [`get_rope_index`] at runtime).
///
/// `grids` carries the post-merge LM grid `(lm_gh, lm_gw)` for each image
/// run, in prompt order — a run length alone cannot recover its
/// factorisation, so the grids must be passed (#14 dynamic resolution).
/// Each image is a still frame (`grid_t = 1`); its tokens get
/// `[base, base + hh, base + ww]` row-major and the shared counter
/// resumes at `base + max(lm_gh, lm_gw)`. Multi-image is correct because
/// the counter threads across images and interleaved text.
pub(crate) fn compute_mrope_index(
input_ids: &[u32],
image_token_id: u32,
grids: &[(usize, usize)],
) -> Result<MRopeIndex> {
let n = input_ids.len();
let mut text = Vec::with_capacity(n);
let mut height = Vec::with_capacity(n);
let mut width = Vec::with_capacity(n);
let mut counter: i64 = 0;
let mut i = 0;
let mut k = 0; // index into `grids`, one per image run
while i < n {
if input_ids[i] == image_token_id {
let start = i;
while i < n && input_ids[i] == image_token_id {
i += 1;
}
let run = i - start;
let (grid_h, grid_w) = *grids.get(k).ok_or_else(|| {
anyhow::anyhow!(
"get_rope_index: image run #{k} (len {run}) has no matching grid \
({} grids supplied)",
grids.len()
)
})?;
k += 1;
if grid_h * grid_w != run {
anyhow::bail!(
"get_rope_index: image run #{} length {run} != grid {grid_h}×{grid_w} = {}",
k - 1,
grid_h * grid_w
);
}
let base = counter;
for hh in 0..grid_h {
for ww in 0..grid_w {
text.push(base); // grid_t = 1 → temporal axis const
height.push(base + hh as i64);
width.push(base + ww as i64);
}
}
counter = base + grid_h.max(grid_w) as i64;
} else {
text.push(counter);
height.push(counter);
width.push(counter);
counter += 1;
i += 1;
}
}
if k != grids.len() {
anyhow::bail!(
"get_rope_index: prompt has {k} image run(s) but {} grid(s) were supplied",
grids.len()
);
}
let delta = counter - n as i64;
Ok((text, height, width, delta))
}
/// `(text_pos, height_pos, width_pos, rope_delta)` returned by
/// [`get_rope_index`] and [`compute_mrope_index`]. The three vectors have
/// one entry per input token and combine (via [`mrope_position_tensor`])
/// into the `(3, seq)` MRoPE position-id tensor; `rope_delta` is the
/// final position counter minus the sequence length.
pub(crate) type MRopeIndex = (Vec<i64>, Vec<i64>, Vec<i64>, i64);
/// Pack the three axis vectors into the `(3, seq)` position-id tensor
/// consumed by [`RotaryEmbedding::mrope_cos_sin`]: row 0 is `text`,
/// row 1 `height`, row 2 `width`. `seq` is taken from `text.len()`.
///
/// The tensor is built directly as **f32** (positions are small
/// integers, exact in f32 well past any context length): the freqs
/// matmul needs float anyway, and this avoids an i64 tensor and an
/// i64→f32 cast on the device.
///
/// # Errors
/// Fails when the three slices do not total `3 * text.len()` elements
/// or the tensor cannot be created on `dev`.
pub(crate) fn mrope_position_tensor(
    text: &[i64],
    height: &[i64],
    width: &[i64],
    dev: &Device,
) -> candle_core::Result<Tensor> {
    let seq = text.len();
    // Row-major layout: the three axes are simply laid end to end.
    let flat: Vec<f32> = text
        .iter()
        .chain(height)
        .chain(width)
        .map(|&p| p as f32)
        .collect();
    Tensor::from_vec(flat, (3, seq), dev)
}
#[cfg(test)]
mod tests {
use super::*;
use candle_core::IndexOp;
/// Minimal `TextConfig` carrying Qwen3.6's rope parameters: head_dim 256
/// with partial factor 0.25 gives rotary_dim 64 (half 32), and the
/// interleaved section is `[11, 11, 10]`. The other fields are small
/// placeholders; `max_position_embeddings` is 64 to keep the tables tiny.
fn qwen36_cfg() -> TextConfig {
    let raw = serde_json::json!({
        "hidden_size": 5120,
        "num_hidden_layers": 1,
        "num_attention_heads": 64,
        "num_key_value_heads": 8,
        "head_dim": 256,
        "intermediate_size": 1,
        "vocab_size": 10,
        "rms_norm_eps": 1e-6,
        "max_position_embeddings": 64,
        "layer_types": ["full_attention"],
        "rope_parameters": {
            "rope_theta": 10000000.0,
            "partial_rotary_factor": 0.25,
            "mrope_section": [11, 11, 10],
            "mrope_interleaved": true
        }
    });
    serde_json::from_value(raw).expect("cfg")
}
/// For Qwen3.6's `[11, 11, 10]` interleaved section over 32 columns, the
/// three masks must cover every column exactly once, have the section
/// sizes as their column counts, and follow the 0/1/2-mod-3 interleave.
#[test]
fn mrope_masks_partition_the_half_dim() {
    let (mt, mh, mw) = mrope_masks(32, &[11, 11, 10], true);
    // Each column belongs to exactly one axis.
    for (i, ((t, h), w)) in mt.iter().zip(&mh).zip(&mw).enumerate() {
        let s = t + h + w;
        assert_eq!(s, 1.0, "column {i} covered {s} times");
    }
    let column_count = |mask: &[f32]| mask.iter().sum::<f32>();
    assert_eq!(column_count(&mt), 11.0);
    assert_eq!(column_count(&mh), 11.0);
    assert_eq!(column_count(&mw), 10.0);
    // Interleave: temporal 0,3,…; height 1,4,…; width 2,5,…
    assert_eq!(mt[0], 1.0);
    assert_eq!(mh[1], 1.0);
    assert_eq!(mw[2], 1.0);
    assert_eq!(mt[3], 1.0);
}
/// The load-bearing invariant: with identical ids on all three axes
/// (the text case), `mrope_cos_sin` must agree with `plain_cos_sin` to
/// within 1e-6, so enabling M-RoPE leaves text inference unchanged.
#[test]
fn mrope_reduces_to_plain_for_equal_axes() {
    let dev = Device::Cpu;
    let rope = RotaryEmbedding::new(DType::F32, &qwen36_cfg(), &dev).unwrap();
    // Positions 5, 6, 7 repeated for the text, height and width rows.
    let ids: Vec<i64> = vec![5i64, 6, 7].repeat(3);
    let pos = Tensor::from_vec(ids, (3, 3), &dev).unwrap();

    let (mc, ms) = rope.mrope_cos_sin(&pos).unwrap();
    let (pc, ps) = rope.plain_cos_sin(5, 3).unwrap();

    // Largest element-wise deviation between the two builders.
    let dcos = (mc - pc).unwrap().abs().unwrap().max_all().unwrap();
    let dsin = (ms - ps).unwrap().abs().unwrap().max_all().unwrap();
    assert!(
        dcos.to_scalar::<f32>().unwrap() < 1e-6,
        "cos mismatch {dcos:?}"
    );
    assert!(
        dsin.to_scalar::<f32>().unwrap() < 1e-6,
        "sin mismatch {dsin:?}"
    );
}
/// Hand-checked interleave: a width-axis column (index 2) must track
/// the WIDTH position, while a temporal column (index 0) tracks the
/// TEXT position, even when the axes differ. Expected values are
/// recomputed here as `cos(position * inv_freq[column])`.
#[test]
fn mrope_blends_axes_at_interleave_columns() {
let dev = Device::Cpu;
let rope = RotaryEmbedding::new(DType::F32, &qwen36_cfg(), &dev).unwrap();
// `inv_freq` is indexed as 2-D here: dim 1 is the number of frequency
// columns and row 0 holds the frequencies themselves.
let half = rope.inv_freq.dim(1).unwrap();
let inv: Vec<f32> = rope.inv_freq.i(0).unwrap().to_vec1().unwrap();
// One token: text=10, height=3, width=7 — all distinct.
let pos = Tensor::from_vec(vec![10i64, 3, 7], (3, 1), &dev).unwrap();
let (cos, _sin) = rope.mrope_cos_sin(&pos).unwrap();
let cos_row: Vec<f32> = cos.i(0).unwrap().to_vec1().unwrap();
assert_eq!(cos_row.len(), half);
// Column 0 (temporal) → text pos 10. Column 1 (height) → 3.
// Column 2 (width) → 7.
assert!((cos_row[0] - (10.0 * inv[0]).cos()).abs() < 1e-5);
assert!((cos_row[1] - (3.0 * inv[1]).cos()).abs() < 1e-5);
assert!((cos_row[2] - (7.0 * inv[2]).cos()).abs() < 1e-5);
// Column 3 wraps back to the temporal axis.
assert!((cos_row[3] - (10.0 * inv[3]).cos()).abs() < 1e-5);
}
/// With no image tokens (id 99 never appears) and no grids, all three
/// axes are the plain sequence 0..n and the decode delta is zero.
#[test]
fn get_rope_index_text_only_is_sequential() {
let (t, h, w, delta) = compute_mrope_index(&[1, 2, 3, 4], 99, &[]).unwrap();
assert_eq!(t, vec![0, 1, 2, 3]);
assert_eq!(h, vec![0, 1, 2, 3]);
assert_eq!(w, vec![0, 1, 2, 3]);
assert_eq!(delta, 0, "no image → delta 0 → plain decode positions");
}
/// Text, one 2×2 image run, then text: checks the per-axis positions
/// inside the image, where the text counter resumes afterwards, and the
/// resulting (negative) decode delta.
#[test]
fn get_rope_index_text_image_text() {
// [text, image(2x2 run of 4), text]. image_token = 99, grid (2,2).
let ids = [1u32, 99, 99, 99, 99, 2];
let (t, h, w, delta) = compute_mrope_index(&ids, 99, &[(2, 2)]).unwrap();
// token 0: text → 0. image base=1, grid 2x2:
// t all = 1; h = base+row = [1,1,2,2]; w = base+col = [1,2,1,2].
// resume from base + max(2,2) = 3. trailing text → 3.
assert_eq!(t, vec![0, 1, 1, 1, 1, 3]);
assert_eq!(h, vec![0, 1, 1, 2, 2, 3]);
assert_eq!(w, vec![0, 1, 2, 1, 2, 3]);
// final counter = 4, seq_len = 6 → delta = -2 (the 4 image tokens
// advanced the counter by only 2).
assert_eq!(delta, -2);
// Decode after the prompt (offset = 6) → text position 6 + (-2) = 4.
assert_eq!(6 + delta, 4);
}
/// A non-square (2 rows × 3 cols) image: rows and columns must not be
/// swapped, and the counter resumes from base + the larger grid side.
#[test]
fn get_rope_index_nonsquare_single_image() {
// text + image(2 rows × 3 cols = 6 tokens). grid (2,3).
let ids = [1u32, 99, 99, 99, 99, 99, 99];
let (t, h, w, delta) = compute_mrope_index(&ids, 99, &[(2, 3)]).unwrap();
// base = 1; row-major h = [0,0,0,1,1,1]+1, w = [0,1,2,0,1,2]+1.
assert_eq!(t, vec![0, 1, 1, 1, 1, 1, 1]);
assert_eq!(h, vec![0, 1, 1, 1, 2, 2, 2]);
assert_eq!(w, vec![0, 1, 2, 3, 1, 2, 3]);
// resume from base + max(2,3) = 4; seq_len 7, counter 4 → delta -3.
assert_eq!(delta, 4 - 7);
}
/// Two image runs with different grids separated by one text token:
/// each run consumes its own grid entry, in order, and the second
/// run's base continues from where the counter stood after the text.
#[test]
fn get_rope_index_two_images_different_grids() {
// img(2x2)=4, text, img(1x3)=3. grids [(2,2),(1,3)].
let ids = [99, 99, 99, 99, 7, 99, 99, 99];
let (t, h, w, delta) = compute_mrope_index(&ids, 99, &[(2, 2), (1, 3)]).unwrap();
// img1 base=0 → t=0, h=[0,0,1,1], w=[0,1,0,1]; resume max(2,2)=2.
// text at counter 2. img2 base=3 → t=3, h=[3,3,3], w=[3,4,5];
// resume 3+max(1,3)=6.
assert_eq!(t, vec![0, 0, 0, 0, 2, 3, 3, 3]);
assert_eq!(h, vec![0, 0, 1, 1, 2, 3, 3, 3]);
assert_eq!(w, vec![0, 1, 0, 1, 2, 3, 4, 5]);
// final counter 6, seq_len 8.
assert_eq!(delta, 6 - 8);
}
/// The runtime entry point `get_rope_index` must yield the same
/// positions as the text/2×2-image/text case of `compute_mrope_index`.
/// NOTE(review): this assumes `NEURON_MROPE` is not set to a disabling
/// value in the environment the tests run in.
#[test]
fn get_rope_index_on_by_default() {
// With NEURON_MROPE unset (default ON), the runtime path returns
// the real interleaved-M-RoPE positions. (NEURON_MROPE=0 would fall
// back to identity; not asserted here since it depends on env.)
let (t, h, w, _delta) = get_rope_index(&[1, 99, 99, 99, 99, 2], 99, &[(2, 2)]).unwrap();
assert_eq!(t, vec![0, 1, 1, 1, 1, 3]);
assert_eq!(h, vec![0, 1, 1, 2, 2, 3]);
assert_eq!(w, vec![0, 1, 2, 1, 2, 3]);
}
/// Inconsistent image runs and grids must come back as `Err`, covering
/// three cases: wrong run length, too few grids, too many grids.
#[test]
fn get_rope_index_grid_mismatches_error() {
// run length != grid product.
assert!(compute_mrope_index(&[99u32; 6], 99, &[(2, 2)]).is_err());
// too few grids for the number of image runs.
assert!(compute_mrope_index(&[99, 99, 7, 99], 99, &[(1, 2)]).is_err());
// too many grids.
assert!(compute_mrope_index(&[99, 99], 99, &[(1, 2), (1, 1)]).is_err());
}
/// End-to-end plumbing check: positions from `compute_mrope_index` are
/// packed by `mrope_position_tensor` and fed to `mrope_cos_sin`; the
/// resulting cos table must reflect the image token's grid row in the
/// height column.
#[test]
fn position_tensor_round_trips_through_mrope_cos_sin() {
// compute_mrope_index → (3,seq) tensor → mrope_cos_sin, and confirm an
// image token's height column tracks its grid row (not the text
// counter), i.e. the end-to-end position plumbing is wired right.
let dev = Device::Cpu;
let rope = RotaryEmbedding::new(DType::F32, &qwen36_cfg(), &dev).unwrap();
let ids = [1u32, 99, 99, 99, 99]; // text + 2x2 image
let (t, h, w, _d) = compute_mrope_index(&ids, 99, &[(2, 2)]).unwrap();
let pos = mrope_position_tensor(&t, &h, &w, &dev).unwrap();
assert_eq!(pos.dims(), &[3, 5]);
let (cos, _sin) = rope.mrope_cos_sin(&pos).unwrap();
// One cos row per token, one column per rotary frequency.
assert_eq!(cos.dims(), &[5, rope.inv_freq.dim(1).unwrap()]);
let inv: Vec<f32> = rope.inv_freq.i(0).unwrap().to_vec1().unwrap();
// Last image token (index 4): grid (h=1, w=1) → base 1 → h=2, w=2.
// Height column (index 1) must track h-position 2, not text.
let last: Vec<f32> = cos.i(4).unwrap().to_vec1().unwrap();
assert!((last[1] - (2.0 * inv[1]).cos()).abs() < 1e-5);
}
/// A 196-token image run with a (14, 14) grid: spot-checks the first
/// and last image tokens rather than the full position vectors.
#[test]
fn get_rope_index_196_is_14x14() {
let mut ids = vec![1u32]; // one text token
ids.extend(std::iter::repeat_n(99u32, 196));
let (t, h, w, _delta) = compute_mrope_index(&ids, 99, &[(14, 14)]).unwrap();
// image base = 1. Last image token (index 196) is grid (h=13,w=13).
assert_eq!(*t.last().unwrap(), 1, "grid_t=1 → temporal const at base");
assert_eq!(h[1], 1, "first image row at base");
assert_eq!(w[1], 1, "first image col at base");
assert_eq!(h[196], 1 + 13, "last image row = base + 13");
assert_eq!(w[196], 1 + 13, "last image col = base + 13");
}
}

View File

@@ -1,299 +0,0 @@
//! Cache-state snapshots for prefix KV caching (#11).
//!
//! A snapshot captures everything `clear_kv_cache` would destroy, at
//! one consistent token boundary:
//!
//! - full-attention layers: the `ConcatKvCache` k/v tensors,
//! - linear-attention layers: the GatedDeltaNet `conv_state` +
//! `recurrent_state`,
//! - the model-level `rope_delta` position counter.
//!
//! The GatedDeltaNet recurrent state cannot be rewound to an earlier
//! token, so a snapshot is only reusable when its entire token
//! sequence is an exact prefix of an incoming prompt — matching policy
//! lives in `harness/prefix_cache.rs`; this module is just the state
//! capture.
//!
//! ## Copy semantics
//!
//! Attention k/v snapshots share storage with the live cache:
//! `ConcatKvCache::append` never mutates stored tensors in place (it
//! `cat`s into fresh allocations), so a shallow `Tensor` clone stays
//! valid after the live cache moves on. The GDN states are
//! **deep-copied** in both directions (`Tensor::copy`): the CUDA
//! delta-rule kernels update the recurrent-state buffer in place, and
//! `flatten`/`contiguous` on an already-contiguous tensor is a view —
//! a shared-storage snapshot would be corrupted by the next forward.
use candle_core::Tensor;
/// Per-layer captured state. Variant kind must match the layer's
/// `AttentionKind` on restore.
///
/// Every tensor slot is optional so that a layer which has not yet
/// produced state can still be represented.
pub enum LayerKvSnapshot {
/// `ConcatKvCache` contents as a `(k, v)` pair. `None` when the cache
/// was empty (a zero-token snapshot — valid but useless; the registry
/// never stores one).
Full(Option<(Tensor, Tensor)>),
/// GatedDeltaNet state. Either tensor is `None` before the first
/// forward touches it.
Linear {
// Convolution state of the linear-attention layer.
conv_state: Option<Tensor>,
// Recurrent (delta-rule) state of the linear-attention layer.
recurrent_state: Option<Tensor>,
},
}
/// One consistent cache snapshot of a `Qwen3_5Model` (or its TP
/// mirror `tp_qwen3_5::TpQwen3_5Model`, whose per-rank shard state
/// has the same shape) at a token boundary. Fields are `pub(crate)`
/// so the TP module can construct/consume the same type; holders
/// outside the harness only ever pass it back to `restore_kv_cache`.
pub struct KvCacheSnapshot {
// One entry per model layer (presumably in layer order — confirm
// against `snapshot_kv_cache`).
pub(crate) layers: Vec<LayerKvSnapshot>,
// Model-level `rope_delta` position counter at capture time.
pub(crate) rope_delta: i64,
}
impl KvCacheSnapshot {
    /// How many per-layer snapshots this capture holds (test/diagnostic
    /// helper).
    pub fn layer_count(&self) -> usize {
        self.layers.len()
    }

    /// Sum of the byte sizes of every tensor in this snapshot, used for
    /// the prefix-cache VRAM budget.
    ///
    /// Attention k/v tensors share storage with the live cache when the
    /// snapshot is taken, but that cache is cleared or replaced before
    /// the next request, so the full size is counted as the steady-state
    /// cost.
    pub fn size_bytes(&self) -> u64 {
        /// Bytes of one tensor: element count × element width.
        fn tensor_bytes(t: &Tensor) -> u64 {
            (t.elem_count() * t.dtype().size_in_bytes()) as u64
        }

        /// Bytes of an optional tensor; an absent tensor costs nothing.
        fn optional_bytes(slot: &Option<Tensor>) -> u64 {
            slot.as_ref().map_or(0, tensor_bytes)
        }

        let mut total = 0u64;
        for layer in &self.layers {
            total += match layer {
                LayerKvSnapshot::Full(None) => 0,
                LayerKvSnapshot::Full(Some((k, v))) => tensor_bytes(k) + tensor_bytes(v),
                LayerKvSnapshot::Linear {
                    conv_state,
                    recurrent_state,
                } => optional_bytes(conv_state) + optional_bytes(recurrent_state),
            };
        }
        total
    }
}
#[cfg(test)]
mod tests {
use super::super::{Qwen3_5Model, RopeParameters, TextConfig};
use candle_core::{DType, Device, Tensor};
use std::collections::HashMap;
/// Tiny two-layer config covering both attention kinds: layer 0 is
/// `linear_attention`, layer 1 is `full_attention`. Dimensions are kept
/// small so a randomly initialised model is cheap to build and run.
fn tiny_config() -> TextConfig {
TextConfig {
vocab_size: 32,
hidden_size: 16,
intermediate_size: 32,
num_hidden_layers: 2,
num_attention_heads: 2,
num_key_value_heads: 1,
head_dim: 8,
max_position_embeddings: 64,
// Plain rotary settings: no M-RoPE sections, no interleaving.
rope_parameters: RopeParameters {
rope_theta: 10000.0,
partial_rotary_factor: 0.5,
rope_type: None,
mrope_section: Vec::new(),
mrope_interleaved: false,
},
rms_norm_eps: 1e-6,
tie_word_embeddings: true,
attn_output_gate: true,
// Must list exactly `num_hidden_layers` entries; `tiny_model` writes
// one set of weights per entry.
layer_types: vec!["linear_attention".into(), "full_attention".into()],
full_attention_interval: Some(4),
hidden_act: "silu".into(),
linear_num_value_heads: 4,
linear_num_key_heads: 2,
linear_key_head_dim: 4,
linear_value_head_dim: 4,
linear_conv_kernel_dim: 4,
}
}
/// Build a Qwen3_5Model from random weights written to a temp
/// safetensors file — the same `ShardedVarBuilder` path the real
/// loader uses.
///
/// Weights are drawn with `Tensor::randn`, so every call produces a
/// different model; tests must compare runs of the *same* instance.
/// Panics on any failure (fixture bug).
fn tiny_model(cfg: &TextConfig) -> Qwen3_5Model {
let dev = Device::Cpu;
let randn = |shape: &[usize]| Tensor::randn(0f32, 0.2f32, shape, &dev).unwrap();
// Derived dimensions shared by the weight shapes below.
let h = cfg.hidden_size;
let inter = cfg.intermediate_size;
let key_dim = cfg.linear_key_head_dim * cfg.linear_num_key_heads;
let value_dim = cfg.linear_value_head_dim * cfg.linear_num_value_heads;
// Fused q/k/v projection width for the linear-attention layers.
let conv_dim = key_dim * 2 + value_dim;
let nv = cfg.linear_num_value_heads;
let hd = cfg.head_dim;
// q_proj output is twice the plain q width (presumably query plus the
// output gate enabled by `attn_output_gate` — confirm against the loader).
let q_out = cfg.num_attention_heads * hd * 2;
let kv_out = cfg.num_key_value_heads * hd;
// Tensor name → tensor, mirroring the checkpoint's key layout.
let mut t: HashMap<String, Tensor> = HashMap::new();
let p = "model.language_model";
t.insert(
format!("{p}.embed_tokens.weight"),
randn(&[cfg.vocab_size, h]),
);
t.insert(format!("{p}.norm.weight"), randn(&[h]));
for (i, kind) in cfg.layer_types.iter().enumerate() {
let lp = format!("{p}.layers.{i}");
// Norms and MLP are common to both layer kinds.
t.insert(format!("{lp}.input_layernorm.weight"), randn(&[h]));
t.insert(format!("{lp}.post_attention_layernorm.weight"), randn(&[h]));
t.insert(format!("{lp}.mlp.gate_proj.weight"), randn(&[inter, h]));
t.insert(format!("{lp}.mlp.up_proj.weight"), randn(&[inter, h]));
t.insert(format!("{lp}.mlp.down_proj.weight"), randn(&[h, inter]));
match kind.as_str() {
"linear_attention" => {
let ap = format!("{lp}.linear_attn");
t.insert(format!("{ap}.in_proj_qkv.weight"), randn(&[conv_dim, h]));
t.insert(format!("{ap}.in_proj_z.weight"), randn(&[value_dim, h]));
t.insert(format!("{ap}.in_proj_b.weight"), randn(&[nv, h]));
t.insert(format!("{ap}.in_proj_a.weight"), randn(&[nv, h]));
t.insert(format!("{ap}.out_proj.weight"), randn(&[h, value_dim]));
t.insert(
format!("{ap}.conv1d.weight"),
randn(&[conv_dim, 1, cfg.linear_conv_kernel_dim]),
);
t.insert(format!("{ap}.dt_bias"), randn(&[nv]));
t.insert(format!("{ap}.A_log"), randn(&[nv]));
t.insert(
format!("{ap}.norm.weight"),
randn(&[cfg.linear_value_head_dim]),
);
}
"full_attention" => {
let ap = format!("{lp}.self_attn");
t.insert(format!("{ap}.q_proj.weight"), randn(&[q_out, h]));
t.insert(format!("{ap}.k_proj.weight"), randn(&[kv_out, h]));
t.insert(format!("{ap}.v_proj.weight"), randn(&[kv_out, h]));
t.insert(
format!("{ap}.o_proj.weight"),
randn(&[h, cfg.num_attention_heads * hd]),
);
t.insert(format!("{ap}.q_norm.weight"), randn(&[hd]));
t.insert(format!("{ap}.k_norm.weight"), randn(&[hd]));
}
other => panic!("unexpected layer type {other}"),
}
}
// NOTE(review): `dir` is dropped (and the directory removed) when this
// function returns, while the builder below memory-maps a file inside
// it. This presumably relies on `load` having read the weights before
// then — confirm.
let dir = tempfile::tempdir().expect("tempdir");
let path = dir.path().join("model.safetensors");
candle_core::safetensors::save(&t, &path).expect("save safetensors");
// SAFETY: mmap of a file this test just wrote and nothing else
// mutates — same justification as the real loader.
let vb = unsafe {
candle_nn::var_builder::ShardedSafeTensors::var_builder(
&[path.clone()],
DType::F32,
&dev,
)
.expect("build ShardedVarBuilder")
};
Qwen3_5Model::load(cfg, &vb).expect("load tiny qwen3_5 model")
}
/// Run `tokens` through `model` as a single batch starting at position
/// `offset`, and return the hidden vector of the final token — the row
/// the lm_head would consume. Panics on any tensor error.
fn forward_tokens(model: &mut Qwen3_5Model, tokens: &[u32], offset: usize) -> Vec<f32> {
    let cpu = Device::Cpu;
    // (seq,) → (1, seq): add the batch axis.
    let batch = Tensor::new(tokens, &cpu).unwrap().unsqueeze(0).unwrap();
    let states = model.forward(&batch, offset).unwrap();
    let (_batch, seq_len, _width) = states.dims3().unwrap();
    // Keep only the last sequence position, then flatten to a plain Vec.
    let last_row = states.narrow(1, seq_len - 1, 1).unwrap();
    last_row.flatten_all().unwrap().to_vec1().unwrap()
}
/// Largest element-wise absolute difference between `a` and `b`.
///
/// Returns `0.0` for empty inputs. If any element-wise difference is
/// NaN (a NaN on either side, or `inf - inf`), the result is NaN, so a
/// caller's `diff < tolerance` check fails instead of passing vacuously.
/// A plain `f32::max` fold would drop NaN and report `0.0` for
/// NaN-filled outputs.
///
/// # Panics
///
/// Panics if the slices differ in length.
fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());
    let mut worst = 0f32;
    for (x, y) in a.iter().zip(b) {
        let gap = (x - y).abs();
        if gap.is_nan() {
            // Poison the result: NaN compares false against any tolerance.
            return f32::NAN;
        }
        worst = worst.max(gap);
    }
    worst
}
/// The gold test for #11: prefill a prefix, snapshot, perturb the
/// live state with unrelated tokens, restore, prefill only the
/// suffix — the result must match a fresh full prefill. Exercises
/// attention KV, GDN conv/recurrent state, and offset bookkeeping
/// in one pass; the perturbation step would corrupt a
/// shared-storage (non-deep-copied) GDN snapshot.
#[test]
fn restore_then_suffix_matches_full_prefill() {
let cfg = tiny_config();
let mut model = tiny_model(&cfg);
let prefix: &[u32] = &[1, 2, 3];
let suffix: &[u32] = &[4, 5, 6];
let full: Vec<u32> = prefix.iter().chain(suffix).copied().collect();
// Reference: one uninterrupted prefill of prefix + suffix.
model.clear_kv_cache();
let h_full = forward_tokens(&mut model, &full, 0);
// Prefill only the prefix and capture the state at that boundary.
model.clear_kv_cache();
forward_tokens(&mut model, prefix, 0);
let snap = model.snapshot_kv_cache().expect("snapshot");
assert_eq!(snap.layer_count(), 2);
assert!(snap.size_bytes() > 0);
// Advance the live state past the snapshot boundary — a
// different continuation, as a subsequent request would be.
forward_tokens(&mut model, &[9, 8], prefix.len());
model.restore_kv_cache(&snap).expect("restore");
let h_restored = forward_tokens(&mut model, suffix, prefix.len());
let diff = max_abs_diff(&h_full, &h_restored);
assert!(diff < 1e-4, "restored-prefix forward diverged: {diff}");
// The snapshot must survive restore + forward cycles (deep
// copy of the in-place-mutated GDN state): restore again and
// expect the identical result.
model.restore_kv_cache(&snap).expect("second restore");
let h_again = forward_tokens(&mut model, suffix, prefix.len());
let diff = max_abs_diff(&h_restored, &h_again);
assert!(diff < 1e-6, "second restore diverged: {diff}");
}
/// Restoring must fully replace the live state, not blend with it
/// — a divergent continuation after restore equals the same
/// continuation after a fresh prefill of the prefix.
#[test]
fn restore_replaces_live_state() {
let cfg = tiny_config();
let mut model = tiny_model(&cfg);
let prefix: &[u32] = &[7, 7, 2, 5];
let cont: &[u32] = &[11, 13];
// Reference: prefix then continuation with no snapshot involved.
model.clear_kv_cache();
forward_tokens(&mut model, prefix, 0);
let h_fresh = forward_tokens(&mut model, cont, prefix.len());
// Same prefix again, snapshot, then pollute the live state with a
// longer unrelated continuation before restoring.
model.clear_kv_cache();
forward_tokens(&mut model, prefix, 0);
let snap = model.snapshot_kv_cache().expect("snapshot");
forward_tokens(&mut model, &[3, 1, 4, 1, 5], prefix.len());
model.restore_kv_cache(&snap).expect("restore");
let h_restored = forward_tokens(&mut model, cont, prefix.len());
let diff = max_abs_diff(&h_fresh, &h_restored);
assert!(diff < 1e-5, "restore did not replace live state: {diff}");
}
}

View File

@@ -1,843 +0,0 @@
//! Qwen3.6 vision tower.
//!
//! 27 pre-norm ViT blocks with **LayerNorm** (with biases — not the
//! `(1+w)·x` RmsNorm the language model uses), fused QKV attention,
//! GELU-tanh MLP. Followed by a `merger` that LayerNorms each
//! 1152-dim vision token, spatially 2×2-merges them into 4608-dim
//! groups, and projects to the LM's 5120-dim hidden via
//! `linear_fc1 → GELU → linear_fc2`.
//!
//! Architecture spec sourced from beast's cached Qwen3.6-27B
//! safetensors header (Stage A0, see
//! `doc/vision-qwen3_6-spec.md`). All weight shapes confirmed
//! from the live `.safetensors` headers, not inferred.
//!
//! **Conv3d wrinkle.** The published `patch_embed.proj.weight` is 5D
//! `[1152, 3, 2, 16, 16]` — a 3D conv with kernel
//! `(t=2, h=16, w=16)`. Candle 0.10 has no Conv3d. For static images
//! we get away with a trick: when the temporal patch size is 2 and we
//! duplicate the still image along the temporal axis (`T = 2`,
//! frame_0 == frame_1), the Conv3d output equals a Conv2d run with
//! the *sum* of the two temporal weight slices:
//!
//! ```text
//! output = W_0 · frame_0 + W_1 · frame_1 + bias
//! = (W_0 + W_1) · frame + bias (static image)
//! ```
//!
//! So at load we sum-collapse the temporal axis and use a 4D
//! `Conv2d` kernel. Video support would have to do the real Conv3d
//! (different frames mean the trick fails) — tracked alongside the
//! dynamic-resolution work in issue #14.
//!
//! Forward signature (Stage A — no LM splice yet):
//!
//! ```text
//! fn forward(&self, image: &Tensor) -> Result<Tensor>
//! ```
//!
//! `image` is `(3, H, W)` f32, normalised by `preprocess::preprocess`.
//! Returns `(N_lm_tokens, out_hidden_size)` post-merger tokens ready
//! to splice into the LM's input embeddings at `<|image_pad|>`
//! positions. For Qwen3.6 at 448×448 → 28×28 patches → 14×14 = 196
//! LM tokens of dim 5120.
use anyhow::{Context, Result};
use candle_core::{D, DType, Device, IndexOp, Module, Tensor};
use candle_nn::var_builder::ShardedVarBuilder;
use candle_nn::{Conv2d, Conv2dConfig, Embedding, LayerNorm, Linear};
use serde::Deserialize;
/// Report whether environment variable `name` holds an "enabled" value.
///
/// The value is trimmed and compared case-insensitively against `1`,
/// `true`, `yes` and `on`. An unset variable, or one that cannot be read
/// as Unicode, counts as disabled.
fn env_truthy(name: &str) -> bool {
    const TRUTHY: [&str; 4] = ["1", "true", "yes", "on"];
    match std::env::var(name) {
        Ok(raw) => {
            let normalized = raw.trim().to_ascii_lowercase();
            TRUTHY.contains(&normalized.as_str())
        }
        Err(_) => false,
    }
}
/// A/B switch for position embeddings. A truthy
/// `NEURON_VISION_LEGACY_POS` selects the original Stage-A sequential
/// `pos_embed` lookup; otherwise (the default) the bilinear grid
/// interpolation is used.
fn vision_legacy_pos() -> bool {
    const FLAG: &str = "NEURON_VISION_LEGACY_POS";
    env_truthy(FLAG)
}
/// A/B switch for the ViT rotary. A truthy `NEURON_VISION_LEGACY_ROPE`
/// skips the 2D vision rotary in attention (the original Stage-A
/// behaviour); otherwise (the default) the rotary is applied.
fn vision_legacy_rope() -> bool {
    const FLAG: &str = "NEURON_VISION_LEGACY_ROPE";
    env_truthy(FLAG)
}
/// Qwen3.6 vision tower hyperparameters. Mirrors the `vision_config`
/// block of `config.json`. Only the fields we actually need are
/// captured; serde tolerates the rest.
///
/// Every field is required: none carries a serde default, so a config
/// missing one of them fails to deserialize.
#[derive(Debug, Clone, Deserialize)]
pub struct VisionConfig {
/// Number of ViT blocks (`depth: 27` for Qwen3.6).
pub depth: usize,
/// Vision-token dimension throughout the tower (1152 for Qwen3.6).
pub hidden_size: usize,
/// MLP intermediate dim (4304).
pub intermediate_size: usize,
/// Attention head count (16). `head_dim = hidden_size / num_heads`.
pub num_heads: usize,
/// Number of slots in the learned position embedding (2304).
/// Caps the maximum image patch count, and must be a perfect square
/// for the interpolated position-embedding path.
pub num_position_embeddings: usize,
/// Spatial patch edge in pixels (16). Also used as the patch-embed
/// convolution stride.
pub patch_size: usize,
/// Temporal kernel depth in the patch embed (2 for Qwen3.6 — we
/// collapse this into a single Conv2d for static-image inference;
/// see the module-level Conv3d wrinkle).
pub temporal_patch_size: usize,
/// Patches grouped per LM token by the merger (2 → 2×2 = 4
/// patches per LM token).
pub spatial_merge_size: usize,
/// Vision input channels (3, RGB).
pub in_channels: usize,
/// Merger output dim — matches the LM's `hidden_size` (5120 for
/// Qwen3.6). The merger projects from vision dim → LM dim.
pub out_hidden_size: usize,
}
/// Epsilon for the tower's LayerNorms (presumably consumed by the
/// `layer_norm` helper defined elsewhere in this file — confirm).
const LAYER_NORM_EPS: f64 = 1e-6;
/// Number of LM tokens emitted by the merger per vision-token group.
/// Used as the final multiplier in `VisionTower::lm_tokens_for`.
const LM_TOKENS_PER_MERGE_GROUP: usize = 1;
/// One ViT block: pre-LN → attn → residual; pre-LN → MLP → residual.
struct VisionBlock {
// Pre-attention LayerNorm over `hidden_size`.
norm1: LayerNorm,
// Fused q/k/v projection: `hidden_size → 3 * hidden_size`.
qkv: Linear,
// Attention output projection: `hidden_size → hidden_size`.
proj: Linear,
// Pre-MLP LayerNorm over `hidden_size`.
norm2: LayerNorm,
// MLP up projection: `hidden_size → intermediate_size`.
fc1: Linear,
// MLP down projection: `intermediate_size → hidden_size`.
fc2: Linear,
// Attention head count, copied from the config.
num_heads: usize,
// Per-head width, `hidden_size / num_heads`.
head_dim: usize,
}
impl VisionBlock {
    /// Load one block's weights from `vb`, which must already be scoped
    /// to that block. Sub-paths read: `norm1`, `attn.qkv`, `attn.proj`,
    /// `norm2`, `mlp.linear_fc1`, `mlp.linear_fc2`.
    fn load(cfg: &VisionConfig, vb: &ShardedVarBuilder) -> Result<Self> {
        let width = cfg.hidden_size;
        let mlp_width = cfg.intermediate_size;
        let head_dim = width / cfg.num_heads;
        // Field initialisers run top to bottom, so the tensors are
        // fetched in this order.
        Ok(Self {
            norm1: layer_norm(vb.pp("norm1"), width)?,
            qkv: linear(vb.pp("attn.qkv"), width, 3 * width)?,
            proj: linear(vb.pp("attn.proj"), width, width)?,
            norm2: layer_norm(vb.pp("norm2"), width)?,
            fc1: linear(vb.pp("mlp.linear_fc1"), width, mlp_width)?,
            fc2: linear(vb.pp("mlp.linear_fc2"), mlp_width, width)?,
            num_heads: cfg.num_heads,
            head_dim,
        })
    }

    /// Apply the block to un-batched tokens `x` of shape
    /// `(N, hidden_size)`; the output has the same shape.
    ///
    /// `rotary`, when present, is the `(cos, sin)` pair (each
    /// `(N, head_dim/2)`) of the 2D vision rotary, forwarded to the
    /// attention sub-layer.
    fn forward(&self, x: &Tensor, rotary: Option<&(Tensor, Tensor)>) -> Result<Tensor> {
        // Attention sub-layer with residual connection.
        let normed = self.norm1.forward(x)?;
        let after_attn = x.add(&self.attention(&normed, rotary)?)?;

        // MLP sub-layer (fc1 → GELU-tanh → fc2) with residual connection.
        let normed = self.norm2.forward(&after_attn)?;
        let expanded = self.fc1.forward(&normed)?;
        let activated = gelu_tanh(&expanded)?;
        let projected = self.fc2.forward(&activated)?;
        Ok(after_attn.add(&projected)?)
    }

    /// Unmasked multi-head self-attention: every patch attends to every
    /// other patch.
    ///
    /// With `rotary` present, q and k are rotated by the 2D vision rotary
    /// before scoring, via `rope_slow` (rotate-half form, matching HF
    /// `apply_rotary_pos_emb_vision`). Without it, q and k are used as
    /// projected.
    fn attention(&self, x: &Tensor, rotary: Option<&(Tensor, Tensor)>) -> Result<Tensor> {
        let (tokens, width) = x.dims2()?;

        // (N, 3*hidden) → (N, 3, heads, head_dim) → (3, heads, N, head_dim),
        // so indexing dim 0 yields per-head q, k, v.
        let fused = self
            .qkv
            .forward(x)?
            .reshape((tokens, 3, self.num_heads, self.head_dim))?
            .permute((1, 2, 0, 3))?
            .contiguous()?;
        let (query, key, value) = (fused.i(0)?, fused.i(1)?, fused.i(2)?);

        let (query, key) = if let Some((cos, sin)) = rotary {
            // `rope_slow` is fed a 4-D tensor: add a leading batch axis
            // for the call and drop it again afterwards.
            let rotate = |t: &Tensor| -> candle_core::Result<Tensor> {
                candle_nn::rotary_emb::rope_slow(&t.unsqueeze(0)?, cos, sin)?.squeeze(0)
            };
            (rotate(&query)?, rotate(&key)?)
        } else {
            (query, key)
        };

        // (heads, N, head_dim) @ (heads, head_dim, N) → (heads, N, N),
        // scaled by 1/sqrt(head_dim).
        let inv_sqrt_d = 1.0 / (self.head_dim as f64).sqrt();
        let logits = query.matmul(&key.transpose(D::Minus2, D::Minus1)?)?;
        let logits = (logits * inv_sqrt_d)?;
        let weights = candle_nn::ops::softmax_last_dim(&logits)?;

        // (heads, N, N) @ (heads, N, head_dim) → (heads, N, head_dim),
        // then fold the heads back into the hidden axis: (N, hidden).
        let mixed = weights.matmul(&value)?;
        let merged = mixed
            .permute((1, 0, 2))?
            .contiguous()?
            .reshape((tokens, width))?;
        Ok(self.proj.forward(&merged)?)
    }
}
/// `merger`: LayerNorm per token → spatial 2×2 merge (concat 4
/// adjacent tokens into one 4608-dim vector) → fc1 → GELU-tanh →
/// fc2. Output dim is the LM's hidden_size.
struct VisionMerger {
// Per-token LayerNorm over `hidden_size`, applied before merging.
norm: LayerNorm,
// `merge_input_dim → merge_input_dim`.
fc1: Linear,
// `merge_input_dim → out_hidden_size`.
fc2: Linear,
// Width of one merged group: `hidden_size * spatial_merge_size²`.
merge_input_dim: usize,
// Edge length of the square patch group folded into one LM token.
spatial_merge_size: usize,
}
impl VisionMerger {
    /// Load the merger weights from `vb`, which must already be scoped
    /// to the merger. Sub-paths read: `norm`, `linear_fc1`, `linear_fc2`.
    fn load(cfg: &VisionConfig, vb: &ShardedVarBuilder) -> Result<Self> {
        let spatial_merge_size = cfg.spatial_merge_size;
        let merge_input_dim = cfg.hidden_size * spatial_merge_size * spatial_merge_size;
        Ok(Self {
            norm: layer_norm(vb.pp("norm"), cfg.hidden_size)?,
            fc1: linear(vb.pp("linear_fc1"), merge_input_dim, merge_input_dim)?,
            fc2: linear(vb.pp("linear_fc2"), merge_input_dim, cfg.out_hidden_size)?,
            merge_input_dim,
            spatial_merge_size,
        })
    }

    /// Merge a `(grid_h, grid_w, hidden_size)` patch grid into LM tokens.
    ///
    /// Each `merge×merge` tile of neighbouring patches is normalised per
    /// token, concatenated into one `merge_input_dim` vector, and
    /// projected. The result is
    /// `(grid_h/merge × grid_w/merge, out_hidden_size)`, tiles in
    /// row-major order.
    ///
    /// Fails if either grid side is not a multiple of
    /// `spatial_merge_size`.
    ///
    /// NOTE(review): the module-level docs describe the merger activation
    /// as plain GELU, while this code applies `gelu_tanh`; confirm which
    /// form the reference uses.
    fn forward(&self, tokens: &Tensor) -> Result<Tensor> {
        let (gh, gw, h) = tokens.dims3()?;
        let m = self.spatial_merge_size;
        if !(gh.is_multiple_of(m) && gw.is_multiple_of(m)) {
            anyhow::bail!(
                "merger expects spatial dims divisible by merge_size={m}; got ({gh}, {gw})"
            );
        }
        let normed = self.norm.forward(tokens)?;

        let tile_rows = gh / m;
        let tile_cols = gw / m;
        // (gh, gw, h) → (rows, m, cols, m, h) → (rows, cols, m, m, h):
        // bring the two in-tile axes next to the hidden axis, so flattening
        // the last three concatenates each tile's patches.
        let tiles = normed
            .reshape((tile_rows, m, tile_cols, m, h))?
            .permute((0, 2, 1, 3, 4))?
            .contiguous()?;
        let grouped = tiles.reshape((tile_rows * tile_cols, self.merge_input_dim))?;

        let activated = gelu_tanh(&self.fc1.forward(&grouped)?)?;
        Ok(self.fc2.forward(&activated)?)
    }
}
/// 2D rotary position embedding for the vision tower. Each patch's
/// `head_dim` rotates by its `(row, col)` grid coordinates: the first
/// half of the rotary freqs are driven by the row position, the second
/// half by the column. Mirrors HF `Qwen3VLVisionRotaryEmbedding` +
/// `rot_pos_emb` (θ = 10000, `dim = head_dim/2`).
struct VisionRotaryEmbedding {
/// `(half,)` f32, `half = head_dim/4` freqs per spatial axis. The
/// same table is used for both the row and the column axis.
inv_freq: Vec<f32>,
}
impl VisionRotaryEmbedding {
    /// Build the inverse-frequency table for heads of width `head_dim`.
    ///
    /// Follows HF `Qwen3VLVisionRotaryEmbedding(head_dim // 2)` with
    /// θ = 10000: for every even `i` below `head_dim / 2`, the entry is
    /// `1 / θ^(i / (head_dim / 2))`.
    fn new(head_dim: usize) -> Self {
        let rotary_dim = head_dim / 2;
        let base = 10000f32;
        let mut inv_freq = Vec::new();
        for i in (0..rotary_dim).step_by(2) {
            inv_freq.push(1f32 / base.powf(i as f32 / rotary_dim as f32));
        }
        Self { inv_freq }
    }

    /// Rotary tables for a `gh×gw` patch grid, patches in **row-major**
    /// order.
    ///
    /// Returns `(cos, sin)`, each `(gh*gw, head_dim/2)` in `dtype` on
    /// `dev`. Each row holds `row·inv_freq` followed by `col·inv_freq`;
    /// `rope_slow` later duplicates these across the full head_dim.
    fn cos_sin(
        &self,
        gh: usize,
        gw: usize,
        dev: &Device,
        dtype: DType,
    ) -> candle_core::Result<(Tensor, Tensor)> {
        let per_axis = self.inv_freq.len();
        let patches = gh * gw;
        let mut angles: Vec<f32> = Vec::with_capacity(patches * 2 * per_axis);
        for row in 0..gh {
            for col in 0..gw {
                // Row-driven half first, then the column-driven half.
                angles.extend(self.inv_freq.iter().map(|&f| row as f32 * f));
                angles.extend(self.inv_freq.iter().map(|&f| col as f32 * f));
            }
        }
        // Angles are built in f32; only the cos/sin results are cast.
        let freqs = Tensor::from_vec(angles, (patches, 2 * per_axis), dev)?;
        let cos = freqs.cos()?.to_dtype(dtype)?;
        let sin = freqs.sin()?.to_dtype(dtype)?;
        Ok((cos, sin))
    }
}
/// The vision tower itself: patch embedding, learned position
/// embedding, 2D rotary, the stack of ViT blocks, and the merger that
/// projects to the LM's hidden size.
pub struct VisionTower {
/// Sum-collapsed temporal kernel (Conv2d, see module doc).
patch_embed: Conv2d,
// Learned position table, `(num_position_embeddings, hidden_size)`.
pos_embed: Embedding,
// Inverse-frequency table for the 2D vision rotary.
rotary: VisionRotaryEmbedding,
// The `depth` ViT blocks, in load order.
blocks: Vec<VisionBlock>,
// Final normalise / spatial-merge / project stage.
merger: VisionMerger,
// Hyperparameters the tower was loaded with.
config: VisionConfig,
// Weight dtype, taken from the var builder at load time.
dtype: DType,
// Device the weights live on, taken from the var builder at load time.
device: Device,
}
impl VisionTower {
/// Build the tower from a `ShardedVarBuilder` already scoped to the
/// safetensors `model.visual.` prefix. The caller applies the `pp` —
/// see `Qwen3_5ForCausalLM::new` (Stage A4).
///
/// The tower takes its dtype and device from `vb`. Fails with context
/// naming the tensor or block that could not be loaded.
pub fn load(cfg: VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
    let dtype = vb.dtype();
    let device = vb.device().clone();

    // `patch_embed.proj` is published as a 5D Conv3d kernel. Summing over
    // the temporal axis (dim 2, size = temporal_patch_size) yields a 4D
    // Conv2d kernel that is exact for static images, where the
    // temporal_patch_size frames are identical copies.
    let proj_vb = vb.pp("patch_embed.proj");
    let kernel_shape = (
        cfg.hidden_size,
        cfg.in_channels,
        cfg.temporal_patch_size,
        cfg.patch_size,
        cfg.patch_size,
    );
    let kernel_3d = proj_vb
        .get(kernel_shape, "weight")
        .context("load model.visual.patch_embed.proj.weight (5D Conv3d kernel)")?;
    // -> (hidden, in_channels, patch, patch)
    let kernel_2d = kernel_3d.sum(2)?;
    let bias = proj_vb
        .get(cfg.hidden_size, "bias")
        .context("load model.visual.patch_embed.proj.bias")?;
    // Stride equals the patch edge, so patches do not overlap.
    let patch_embed = Conv2d::new(
        kernel_2d,
        Some(bias),
        Conv2dConfig {
            stride: cfg.patch_size,
            ..Default::default()
        },
    );

    let pos_table = vb
        .pp("pos_embed")
        .get((cfg.num_position_embeddings, cfg.hidden_size), "weight")
        .context("load model.visual.pos_embed.weight")?;
    let pos_embed = Embedding::new(pos_table, cfg.hidden_size);

    let rotary = VisionRotaryEmbedding::new(cfg.hidden_size / cfg.num_heads);

    // Blocks live under `blocks.<i>`; loading stops at the first failure.
    let blocks_vb = vb.pp("blocks");
    let blocks = (0..cfg.depth)
        .map(|i| {
            VisionBlock::load(&cfg, &blocks_vb.pp(i))
                .with_context(|| format!("load vision block {i}"))
        })
        .collect::<Result<Vec<_>>>()?;

    let merger = VisionMerger::load(&cfg, &vb.pp("merger")).context("load vision merger")?;

    Ok(Self {
        patch_embed,
        pos_embed,
        rotary,
        blocks,
        merger,
        config: cfg,
        dtype,
        device,
    })
}
/// Borrow the hyperparameters this tower was loaded with.
pub fn config(&self) -> &VisionConfig {
&self.config
}
/// How many LM tokens the merger emits for an `(h, w)` pixel image:
/// `(h / patch_size / spatial_merge_size) * (w / patch_size / spatial_merge_size)`.
/// Both divisions truncate, so partial patches or merge groups are not
/// counted.
pub fn lm_tokens_for(&self, h: u32, w: u32) -> usize {
    // Merged cells along one image side.
    let cells_along = |pixels: u32| -> usize {
        (pixels as usize) / self.config.patch_size / self.config.spatial_merge_size
    };
    let rows = cells_along(h);
    let cols = cells_along(w);
    rows * cols * LM_TOKENS_PER_MERGE_GROUP
}
/// Bilinearly interpolate the learned `pos_embed` grid (a
/// `num_grid_per_side × num_grid_per_side` table, 48×48 for Qwen3.6)
/// onto the actual `gh × gw` patch grid, in **row-major** patch
/// order. Port of the HF `fast_pos_embed_interpolate`: for each patch
/// at fractional grid coord `(linspace(0, ngrid-1, gh)[hi],
/// linspace(0, ngrid-1, gw)[wi])`, blend the 4 surrounding grid
/// entries by bilinear weights. Returns `(gh*gw, hidden)` in
/// `self.dtype`.
///
/// Fails if `num_position_embeddings` is not a perfect square, or if
/// any tensor operation fails.
fn interpolated_pos_embed(&self, gh: usize, gw: usize) -> Result<Tensor> {
// Side length of the square learned grid.
let ngrid = (self.config.num_position_embeddings as f64).sqrt().round() as usize;
anyhow::ensure!(
ngrid * ngrid == self.config.num_position_embeddings,
"num_position_embeddings {} is not a perfect square",
self.config.num_position_embeddings
);
// Evenly-spaced fractional indices into the [0, ngrid-1] grid.
// A single sample (n <= 1) sits at coordinate 0.
let lin = |n: usize| -> Vec<f64> {
if n <= 1 {
vec![0.0]
} else {
let step = (ngrid - 1) as f64 / (n - 1) as f64;
(0..n).map(|i| i as f64 * step).collect()
}
};
let hs = lin(gh);
let ws = lin(gw);
let n = gh * gw;
// Four corner index sets + bilinear weight sets, row-major.
// Corner order: 0 = (floor, floor), 1 = (floor, ceil),
// 2 = (ceil, floor), 3 = (ceil, ceil), as (row, col).
let mut idx: [Vec<u32>; 4] = [
Vec::with_capacity(n),
Vec::with_capacity(n),
Vec::with_capacity(n),
Vec::with_capacity(n),
];
let mut wts: [Vec<f32>; 4] = [
Vec::with_capacity(n),
Vec::with_capacity(n),
Vec::with_capacity(n),
Vec::with_capacity(n),
];
for &hv in &hs {
let hf = hv as usize; // floor (hv >= 0)
// Ceil neighbour, clamped so the last row stays inside the table.
let hc = (hf + 1).min(ngrid - 1);
// Fractional distance from the floor row.
let dh = (hv - hf as f64) as f32;
for &wv in &ws {
let wf = wv as usize;
let wc = (wf + 1).min(ngrid - 1);
let dw = (wv - wf as f64) as f32;
// Flat table index is `row * ngrid + col`.
idx[0].push((hf * ngrid + wf) as u32);
wts[0].push((1.0 - dh) * (1.0 - dw));
idx[1].push((hf * ngrid + wc) as u32);
wts[1].push((1.0 - dh) * dw);
idx[2].push((hc * ngrid + wf) as u32);
wts[2].push(dh * (1.0 - dw));
idx[3].push((hc * ngrid + wc) as u32);
wts[3].push(dh * dw);
}
}
// Blend in f32 and cast once at the end — the reference keeps
// the bilinear weights f32 against bf16 table rows; rounding
// the weights to bf16 first costs a visible slice of fixture
// parity (#15).
let mut acc: Option<Tensor> = None;
for corner in 0..4 {
// `mem::take` moves each buffer into its tensor without cloning.
let idx_t = Tensor::from_vec(std::mem::take(&mut idx[corner]), (n,), &self.device)?;
let emb = self
.pos_embed
.forward(&idx_t)?
.to_dtype(candle_core::DType::F32)?; // (n, hidden)
// (n, 1) weights broadcast across the hidden axis.
let wt = Tensor::from_vec(std::mem::take(&mut wts[corner]), (n, 1), &self.device)?;
let term = emb.broadcast_mul(&wt)?;
acc = Some(match acc {
Some(a) => a.add(&term)?,
None => term,
});
}
// The loop above always runs four times, so `acc` is populated.
acc.expect("4 corners accumulated")
.to_dtype(self.dtype)
.map_err(Into::into)
}
/// Run the vision tower over a single image.
///
/// `image` must be a rank-3 `(C, H, W)` tensor with
/// `C == config.in_channels` and with `H` and `W` divisible by
/// `patch_size`; both conditions are checked and reported as errors.
/// The tensor is cast to `self.dtype` before patch embedding. Callers
/// are expected to pass the normalised output of
/// `preprocess::preprocess`, on `self.device`, with `H`/`W` aligned to
/// `patch_size * spatial_merge_size`.
/// NOTE(review): only the `patch_size` multiple is validated here —
/// merge-size alignment is presumably enforced by the merger; confirm.
///
/// The patch count `(H / patch_size) * (W / patch_size)` must not
/// exceed `config.num_position_embeddings`, otherwise an error is
/// returned.
///
/// Returns the merger's output: `(N_lm, out_hidden_size)` LM-side
/// image tokens, ready to be spliced into the language model's input
/// embeddings.
pub fn forward(&self, image: &Tensor) -> Result<Tensor> {
    let cfg = &self.config;
    let (c, h, w) = image.dims3()?;
    anyhow::ensure!(
        c == cfg.in_channels,
        "image must have {} channels, got {c}",
        cfg.in_channels
    );
    let patch = cfg.patch_size;
    anyhow::ensure!(
        h.is_multiple_of(patch) && w.is_multiple_of(patch),
        "image dims must be multiples of patch_size={patch}; got ({h}, {w})"
    );
    let (gh, gw) = (h / patch, w / patch);
    let n_patches = gh * gw;
    anyhow::ensure!(
        n_patches <= cfg.num_position_embeddings,
        "patch count {n_patches} exceeds pos_embed budget {}",
        cfg.num_position_embeddings
    );

    // Patch embedding: (C, H, W) → (1, C, H, W) → conv → (1, hidden, gh, gw)
    // → (hidden, gh, gw) → (gh, gw, hidden) → (n_patches, hidden).
    let batched = image.unsqueeze(0)?.to_dtype(self.dtype)?;
    let patches = self
        .patch_embed
        .forward(&batched)?
        .squeeze(0)?
        .permute((1, 2, 0))?
        .contiguous()?
        .reshape((n_patches, cfg.hidden_size))?;

    // Learned absolute position embeddings. The table is a square
    // `num_grid_per_side²` grid (48×48 for Qwen3.6) which the reference
    // (`fast_pos_embed_interpolate`) bilinearly resamples to `gh×gw`.
    // The legacy branch — a plain lookup of the first `n_patches` rows —
    // uses the wrong grid stride and scrambles spatial structure; it
    // survives only behind `NEURON_VISION_LEGACY_POS=1` for A/B runs.
    let pos = if vision_legacy_pos() {
        let positions = Tensor::arange(0u32, n_patches as u32, &self.device)?;
        self.pos_embed.forward(&positions)?
    } else {
        self.interpolated_pos_embed(gh, gw)?
    };
    let embedded = patches.add(&pos)?;

    // 2D rotary tables (row/col per patch) are built once and shared by
    // every block's attention; the legacy escape hatch disables them.
    let rotary = (!vision_legacy_rope())
        .then(|| self.rotary.cos_sin(gh, gw, &self.device, self.dtype))
        .transpose()?;

    let encoded = self
        .blocks
        .iter()
        .enumerate()
        .try_fold(embedded, |acc, (i, block)| {
            block
                .forward(&acc, rotary.as_ref())
                .with_context(|| format!("vision block {i}"))
        })?;

    // Restore the spatial layout the merger consumes: (gh, gw, hidden).
    let grid = encoded.reshape((gh, gw, cfg.hidden_size))?;
    self.merger.forward(&grid)
}
}
/// Build a `candle_nn::LayerNorm` by hand from a `ShardedVarBuilder`.
///
/// candle_nn's own `layer_norm` constructor accepts `crate::VarBuilder`
/// rather than `ShardedVarBuilder`, so the arch modules in this crate
/// fetch the tensors and assemble the struct themselves (compare
/// `full_attn::load_linear_no_bias`). Reads `weight` and `bias`, each of
/// length `size`, from under `vb`'s prefix and uses `LAYER_NORM_EPS` as
/// the epsilon. A failed load is reported with the tensor name and the
/// builder prefix.
fn layer_norm(vb: ShardedVarBuilder, size: usize) -> Result<LayerNorm> {
    let gamma = vb
        .get(size, "weight")
        .with_context(|| format!("load LayerNorm.weight at '{}'", vb.prefix()))?;
    let beta = vb
        .get(size, "bias")
        .with_context(|| format!("load LayerNorm.bias at '{}'", vb.prefix()))?;
    let norm = LayerNorm::new(gamma, beta, LAYER_NORM_EPS);
    Ok(norm)
}
/// Build a biased `candle_nn::Linear` by hand from a `ShardedVarBuilder`,
/// for the same reason `layer_norm` does its own loading.
///
/// Reads `weight` with shape `(out_dim, in_dim)` and `bias` with shape
/// `(out_dim,)` from under `vb`'s prefix. A failed load is reported with
/// the tensor name and the builder prefix.
fn linear(vb: ShardedVarBuilder, in_dim: usize, out_dim: usize) -> Result<Linear> {
    let w = vb
        .get((out_dim, in_dim), "weight")
        .with_context(|| format!("load Linear.weight at '{}'", vb.prefix()))?;
    let b = vb
        .get(out_dim, "bias")
        .with_context(|| format!("load Linear.bias at '{}'", vb.prefix()))?;
    let layer = Linear::new(w, Some(b));
    Ok(layer)
}
/// Tanh-approximated GELU (`gelu_pytorch_tanh` in PyTorch), the
/// activation the Qwen3.6 vision tower's `hidden_act` names. It is
/// evaluated term by term, in the input's dtype:
///
/// ```text
/// gelu_tanh(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
/// ```
///
/// NOTE(review): this was written on the assumption that candle's
/// `Tensor::gelu` is the exact erf-based GELU. Upstream candle appears
/// to implement `gelu` as this same tanh approximation (with
/// `gelu_erf` as the exact form) — confirm against the pinned candle
/// version before swapping in the built-in, and re-check fixture parity
/// if you do.
fn gelu_tanh(x: &Tensor) -> Result<Tensor> {
    // sqrt(2 / pi)
    const SQRT_2_OVER_PI: f64 = 0.7978845608028654;
    const CUBIC_COEFF: f64 = 0.044715;

    let cubic = (x.powf(3.0)? * CUBIC_COEFF)?;
    let scaled = ((x + cubic)? * SQRT_2_OVER_PI)?;
    let gate = (scaled.tanh()? + 1.0)?;
    let half_x = (x * 0.5)?;
    Ok(half_x.broadcast_mul(&gate)?)
}
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{DType, Device};

    /// Miniature `VisionConfig` for CPU tests. It keeps the Qwen3.6
    /// shape relations (multi-block stack, hidden divisible by
    /// num_heads, intermediate_size > hidden_size) at dimensions small
    /// enough for millisecond runs.
    fn tiny_config() -> VisionConfig {
        VisionConfig {
            depth: 2,
            hidden_size: 32,
            intermediate_size: 64,
            num_heads: 4,
            num_position_embeddings: 64,
            patch_size: 4,
            temporal_patch_size: 2,
            spatial_merge_size: 2,
            in_channels: 3,
            out_hidden_size: 48,
        }
    }

    /// Flatten a tensor into a host-side `Vec<f32>`.
    fn flat_f32(t: &Tensor) -> Vec<f32> {
        t.flatten_all().unwrap().to_vec1().unwrap()
    }

    /// Assemble a `VisionTower` field by field with random weights —
    /// the same approach as
    /// `linear_attn::tests::forward_smoke_with_tiny_dimensions`. The
    /// safetensors-backed `ShardedVarBuilder` cannot be built from
    /// in-memory tensors, so the loader is bypassed here; the real
    /// `VisionTower::load` is covered by the cuda-integration smoke
    /// test in Stage A6.
    fn tiny_tower(cfg: &VisionConfig) -> VisionTower {
        let device = Device::Cpu;
        let dtype = DType::F32;
        let zeros = |shape: &[usize]| Tensor::zeros(shape, dtype, &device).unwrap();
        let ones = |shape: &[usize]| Tensor::ones(shape, dtype, &device).unwrap();
        let randn = |shape: &[usize]| Tensor::randn(0_f32, 0.02, shape, &device).unwrap();
        // Identity LayerNorm (weight 1, bias 0) of the given width.
        let unit_norm =
            |width: usize| LayerNorm::new(ones(&[width]), zeros(&[width]), LAYER_NORM_EPS);
        // Random-weight, zero-bias Linear mapping `in_dim → out_dim`.
        let dense = |out_dim: usize, in_dim: usize| {
            Linear::new(randn(&[out_dim, in_dim]), Some(zeros(&[out_dim])))
        };

        let hidden = cfg.hidden_size;
        let patch_embed = Conv2d::new(
            randn(&[hidden, cfg.in_channels, cfg.patch_size, cfg.patch_size]),
            Some(zeros(&[hidden])),
            Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            },
        );
        let pos_embed = Embedding::new(randn(&[cfg.num_position_embeddings, hidden]), hidden);

        let blocks: Vec<VisionBlock> = (0..cfg.depth)
            .map(|_| VisionBlock {
                norm1: unit_norm(hidden),
                qkv: dense(3 * hidden, hidden),
                proj: dense(hidden, hidden),
                norm2: unit_norm(hidden),
                fc1: dense(cfg.intermediate_size, hidden),
                fc2: dense(hidden, cfg.intermediate_size),
                num_heads: cfg.num_heads,
                head_dim: hidden / cfg.num_heads,
            })
            .collect();

        let merge_input_dim = hidden * cfg.spatial_merge_size * cfg.spatial_merge_size;
        let merger = VisionMerger {
            norm: unit_norm(hidden),
            fc1: dense(merge_input_dim, merge_input_dim),
            fc2: dense(cfg.out_hidden_size, merge_input_dim),
            merge_input_dim,
            spatial_merge_size: cfg.spatial_merge_size,
        };
        let rotary = VisionRotaryEmbedding::new(hidden / cfg.num_heads);

        VisionTower {
            patch_embed,
            pos_embed,
            rotary,
            blocks,
            merger,
            config: cfg.clone(),
            dtype,
            device,
        }
    }

    #[test]
    fn forward_with_random_weights_produces_finite_output() {
        let cfg = tiny_config();
        let tower = tiny_tower(&cfg);
        // A 16×16 image with patch_size=4 gives a 4×4 patch grid; the
        // 2×2 merge turns that into 2×2 = 4 LM tokens of width
        // out_hidden_size.
        let image = Tensor::randn(0_f32, 1.0, (3, 16, 16), &Device::Cpu).unwrap();
        let out = tower.forward(&image).expect("forward");
        let (n_lm, hidden) = out.dims2().unwrap();
        assert_eq!(n_lm, 4);
        assert_eq!(hidden, cfg.out_hidden_size);
        // Neither NaN nor Inf may appear anywhere in the output.
        let all_finite = flat_f32(&out).iter().all(|v| v.is_finite());
        assert!(all_finite, "forward must produce finite values");
    }

    #[test]
    fn interpolated_pos_embed_reduces_to_sequential_at_native_grid() {
        // With gh = gw = ngrid the sample coordinates
        // linspace(0, ngrid-1, ngrid) are exactly the integers, so each
        // patch sits on a grid node (dh = dw = 0, all weight on corner
        // 0). The bilinear blend then returns the raw pos_embed rows in
        // row-major order, i.e. the legacy sequential lookup.
        let cfg = tiny_config();
        let tower = tiny_tower(&cfg);
        let ngrid = (cfg.num_position_embeddings as f64).sqrt() as usize; // 8
        let interp = tower.interpolated_pos_embed(ngrid, ngrid).unwrap();
        let positions = Tensor::arange(0u32, (ngrid * ngrid) as u32, &Device::Cpu).unwrap();
        let seq = tower.pos_embed.forward(&positions).unwrap();
        let a = flat_f32(&interp);
        let b = flat_f32(&seq);
        assert_eq!(a.len(), b.len());
        for (x, y) in a.iter().zip(b.iter()) {
            assert!((x - y).abs() < 1e-5, "interp {x} vs seq {y}");
        }
    }

    #[test]
    fn vision_rotary_row_col_structure() {
        // head_dim 8 → rotary dim 4 → inv_freq over [0,2] → 2 freqs/axis.
        let rot = VisionRotaryEmbedding::new(8);
        assert_eq!(rot.inv_freq.len(), 2);
        let (cos, sin) = rot.cos_sin(2, 2, &Device::Cpu, DType::F32).unwrap();
        assert_eq!(cos.dims(), &[4, 4]); // 4 patches, head_dim/2 = 4 cols
        // Patch (0,0): every angle is 0, so sin is 0 throughout.
        let origin: Vec<f32> = sin.i(0).unwrap().to_vec1().unwrap();
        assert!(origin.iter().all(|&s| s.abs() < 1e-6));
        // Patch index 2 is grid cell (1,0): row=1 drives the first half
        // of the vector, col=0 leaves the second half at zero.
        let below: Vec<f32> = sin.i(2).unwrap().to_vec1().unwrap();
        assert!(below[0].abs() > 1e-6, "row half must be non-zero");
        assert!(
            below[2].abs() < 1e-6 && below[3].abs() < 1e-6,
            "col half must be zero"
        );
    }

    #[test]
    fn lm_token_count_matches_grid() {
        let tower = tiny_tower(&tiny_config());
        // 16x16 image → 4x4 patches → 2x2 = 4 LM tokens
        assert_eq!(tower.lm_tokens_for(16, 16), 4);
        // 32x32 image → 8x8 patches → 4x4 = 16 LM tokens
        assert_eq!(tower.lm_tokens_for(32, 32), 16);
    }

    #[test]
    fn rejects_image_with_dims_not_multiple_of_patch() {
        let tower = tiny_tower(&tiny_config());
        let image = Tensor::randn(0_f32, 1.0, (3, 17, 17), &Device::Cpu).unwrap();
        let err = tower.forward(&image).unwrap_err();
        assert!(format!("{err:#}").contains("patch_size"));
    }

    #[test]
    fn rejects_image_with_wrong_channel_count() {
        let tower = tiny_tower(&tiny_config());
        let image = Tensor::randn(0_f32, 1.0, (4, 16, 16), &Device::Cpu).unwrap();
        let err = tower.forward(&image).unwrap_err();
        assert!(format!("{err:#}").contains("channels"));
    }

    #[test]
    fn gelu_tanh_matches_known_values() {
        // gelu_pytorch_tanh reference values taken from PyTorch:
        //   gelu_tanh(0.0)  = 0.0
        //   gelu_tanh(1.0)  ≈ 0.8411920071
        //   gelu_tanh(-1.0) ≈ -0.1588079929
        let input = Tensor::new(&[0.0_f32, 1.0, -1.0], &Device::Cpu).unwrap();
        let v: Vec<f32> = gelu_tanh(&input).unwrap().to_vec1().unwrap();
        assert!((v[0]).abs() < 1e-6, "gelu_tanh(0) ≈ 0, got {}", v[0]);
        assert!(
            (v[1] - 0.841_192_f32).abs() < 1e-5,
            "gelu_tanh(1) ≈ 0.84119, got {}",
            v[1]
        );
        assert!(
            (v[2] - -0.158_808_f32).abs() < 1e-5,
            "gelu_tanh(-1) ≈ -0.15881, got {}",
            v[2]
        );
    }
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,562 +0,0 @@
//! Chat-template rendering for the model-supplied Jinja templates
//! HuggingFace tokenizers ship in `tokenizer_config.json`.
//!
//! ## Background
//!
//! Every modern open-weight model bundles a `chat_template` field
//! in its `tokenizer_config.json` — a Jinja2 template string that
//! converts a sequence of `{role, content}` messages into the
//! exact prompt the model was trained on. Examples:
//!
//! - Qwen3-Coder: `<|im_start|>{role}\n{content}<|im_end|>\n…`
//! with conditional `enable_thinking` handling that injects an
//! empty `<think>\n\n</think>` block when set false.
//! - DeepSeek-R1: similar im_start framing with different special-
//! token names.
//! - Mistral / Magistral: a `[INST]` / `[/INST]` framing.
//! - Llama: another shape again.
//!
//! Rendering the model's own template is the only way to get the
//! *exact* prompt format the model was trained on plus the
//! model-specific kwargs (`enable_thinking`, `tools`, …) without
//! hardcoding per-model logic. The alternative — neuron's previous
//! `format_qwen3_prompt` — was a hardcoded Qwen3 ChatML glue that
//! ignored kwargs entirely.
//!
//! ## Scope
//!
//! This module is request-side only: it builds the prompt string
//! the tokenizer ingests before inference. The reasoning- and
//! tool-call-marker token routing (issues #6, #8) is response-side
//! and stays in `wire::openai_chat` / the streaming inference
//! loops.
//!
//! ## Fallback
//!
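//! When no usable template is found next to the tokenizer (no
//! `chat_template.jinja`, no `chat_template.json`, and
//! `tokenizer_config.json` is missing, doesn't parse, or lacks a
//! `chat_template`), or the template fails to render, the caller
//! falls back to `format_qwen3_prompt`. The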
//! `NEURON_USE_CHAT_TEMPLATE=false` env var is a global kill
//! switch — if a deploy goes sideways and the renderer is to
//! blame, an operator can flip the env and restart neuron without
//! shipping a new build.
use anyhow::{Context, Result};
use cortex_core::openai::{ChatMessage, MessageContent};
use minijinja::{Environment, Error as MjError, ErrorKind as MjErrorKind, Value as MjValue};
use serde_json::Value;
use std::path::Path;
/// Name of the environment variable that acts as the global
/// chat-template kill switch. Setting it to `false`, `0`, `no` or `off`
/// (any letter case, surrounding whitespace ignored) makes every model
/// skip its `chat_template` and use `format_qwen3_prompt` instead.
/// Leaving it unset means "use chat templates where available".
pub const KILL_SWITCH_ENV: &str = "NEURON_USE_CHAT_TEMPLATE";
/// Report whether chat templates are globally enabled.
///
/// Returns `false` only when [`KILL_SWITCH_ENV`] holds one of the
/// disabling spellings listed on that constant. Returns `true` when the
/// variable is unset, cannot be read as Unicode, or holds any other
/// value (including the empty string).
pub fn chat_templates_enabled() -> bool {
    const DISABLING: [&str; 4] = ["false", "0", "no", "off"];
    let Ok(raw) = std::env::var(KILL_SWITCH_ENV) else {
        return true;
    };
    let normalised = raw.trim().to_ascii_lowercase();
    !DISABLING.contains(&normalised.as_str())
}
/// Locate the model's chat template in the directory the tokenizer was
/// loaded from, using the HuggingFace `transformers` precedence:
///
/// 1. `chat_template.jinja` — raw template text;
/// 2. `chat_template.json` — `{"chat_template": "..."}`;
/// 3. the `chat_template` field of `tokenizer_config.json`.
///
/// `tokenizer_json_path` is only used for its parent directory. Returns
/// `None` when the path has no parent or none of the three sources
/// yields a template.
///
/// The precedence matters for multimodal models: Qwen3-VL / Qwen3.6 ship
/// their vision-aware template (the one that emits
/// `<|vision_start|><|image_pad|><|vision_end|>` per image) **only** in
/// `chat_template.jinja`, and may not ship a `tokenizer_config.json` at
/// all. Reading `tokenizer_config.json` alone returned `None`, which
/// dropped image content into the text-only `format_qwen3_prompt`
/// fallback — so image requests rendered zero `<|image_pad|>` tokens
/// and the vision path bailed on the count mismatch.
///
/// For the same reason a `chat_template.jinja` that exists but cannot
/// be read (permissions, invalid UTF-8, …) is logged as a warning rather
/// than skipped silently: losing it quietly reproduces exactly that
/// failure with nothing in the logs to explain it.
pub fn load_chat_template_alongside(tokenizer_json_path: &Path) -> Option<String> {
    let parent = tokenizer_json_path.parent()?;
    // 1. Standalone Jinja file — raw template text, highest priority.
    let jinja_path = parent.join("chat_template.jinja");
    match std::fs::read_to_string(&jinja_path) {
        Ok(text) if !text.trim().is_empty() => {
            tracing::info!(
                path = %jinja_path.display(),
                "chat_template: loaded standalone chat_template.jinja"
            );
            return Some(text);
        }
        Ok(_) => {
            tracing::warn!(
                path = %jinja_path.display(),
                "chat_template: chat_template.jinja present but empty; trying other sources"
            );
        }
        // Absent file is the common case — fall through quietly.
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
        // Anything else means the file is there but unusable; say so.
        Err(e) => {
            tracing::warn!(
                path = %jinja_path.display(),
                error = %e,
                "chat_template: chat_template.jinja unreadable; trying other sources"
            );
        }
    }
    // 2. Standalone JSON file — `{"chat_template": "..."}` form.
    let json_path = parent.join("chat_template.json");
    if json_path.exists()
        && let Some(t) = load_chat_template_from(&json_path)
    {
        tracing::info!(
            path = %json_path.display(),
            "chat_template: loaded standalone chat_template.json"
        );
        return Some(t);
    }
    // 3. The `chat_template` field inside tokenizer_config.json.
    let config_path = parent.join("tokenizer_config.json");
    load_chat_template_from(&config_path)
}
/// Best-effort load of the `chat_template` field from a JSON file —
/// normally a HuggingFace `tokenizer_config.json`.
///
/// Returns `None` when the file is absent or unreadable, doesn't parse
/// as JSON, or has no usable `chat_template`; in all of those cases the
/// caller falls back to `format_qwen3_prompt`. Read failures are logged
/// at debug level; parse failures and an array without any usable
/// template are logged as warnings so an operator can see why the
/// fallback fired.
///
/// Two shapes of the field are understood:
///
/// - a plain string — returned as is;
/// - an array of `{name, template}` objects (multi-template models with
///   e.g. a default and a tool-use variant). The entry named `default`
///   is preferred, matching how `transformers` selects a template when
///   none is requested by name; if there is no such entry, the first
///   entry carrying a string `template` is used.
pub fn load_chat_template_from(path: &Path) -> Option<String> {
    /// The `template` string of one array entry, if it has one.
    fn template_of(entry: &Value) -> Option<&str> {
        entry.get("template").and_then(Value::as_str)
    }

    let text = match std::fs::read_to_string(path) {
        Ok(t) => t,
        Err(e) => {
            tracing::debug!(
                path = %path.display(),
                error = %e,
                "chat_template: tokenizer_config.json absent or unreadable; falling back"
            );
            return None;
        }
    };
    let value: Value = match serde_json::from_str(&text) {
        Ok(v) => v,
        Err(e) => {
            tracing::warn!(
                path = %path.display(),
                error = %e,
                "chat_template: tokenizer_config.json failed to parse; falling back"
            );
            return None;
        }
    };
    match value.get("chat_template") {
        Some(Value::String(s)) => Some(s.clone()),
        Some(Value::Array(entries)) => {
            // Prefer the variant explicitly named "default"; array order
            // is not guaranteed to put it first.
            let named_default = entries
                .iter()
                .filter(|entry| entry.get("name").and_then(Value::as_str) == Some("default"))
                .find_map(template_of);
            let chosen = named_default.or_else(|| entries.iter().find_map(template_of));
            if chosen.is_none() {
                tracing::warn!(
                    path = %path.display(),
                    "chat_template: array form had no usable template entry; falling back"
                );
            }
            chosen.map(str::to_owned)
        }
        _ => None,
    }
}
/// Render a model's Jinja chat template into the prompt string.
///
/// - `template`: the raw Jinja source.
/// - `messages`: the conversation, in order. Each message is exposed to
///   the template as an object with `role`, `content` (a string for
///   text content, an array of content blocks for parts) and every key
///   of the message's `extra` object (e.g. `tool_calls`,
///   `tool_call_id`).
/// - `tools`: the request's `tools` value, bound as `tools` unless it is
///   `Value::Null`. It is forwarded without interpretation for
///   templates that emit native tool definitions (Qwen3-Coder's
///   tool-use template, Mistral's [TOOL_DEFINITIONS] frame).
/// - `kwargs`: the client's `chat_template_kwargs`. When it is an
///   object, each entry is bound in the template context, overriding
///   the defaults set here — so a request's
///   `add_generation_prompt: false` wins over the default `true`.
///
/// Returns an error when the template fails to compile or to render;
/// that includes a template calling `raise_exception(...)`.
pub fn render_chat_template(
    template: &str,
    messages: &[ChatMessage],
    tools: &Value,
    kwargs: &Value,
) -> Result<String> {
    // HF chat templates target Python's Jinja2, so two gaps are bridged:
    //
    // - `pycompat::unknown_method_callback` provides the Python
    //   str/list/dict methods minijinja lacks (`startswith`, `endswith`,
    //   `split`, `rstrip`, `lstrip`, …); the Qwen3.6 template relies on
    //   several for its think-block and tool-response handling.
    // - `raise_exception` is the global HF templates call to reject bad
    //   input (e.g. an image in a system message). It becomes a render
    //   error so the caller can fall back or surface it.
    let mut env = Environment::new();
    env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
    env.add_function(
        "raise_exception",
        |msg: String| -> Result<MjValue, MjError> {
            Err(MjError::new(MjErrorKind::InvalidOperation, msg))
        },
    );
    // Register under a fixed name so errors mention "chat_template"
    // instead of `<template>`.
    env.add_template("chat_template", template)
        .context("compile chat_template")?;
    let tmpl = env.get_template("chat_template").unwrap();

    // Translate each ChatMessage into the object shape HF templates
    // iterate. Templates cope with both content shapes through
    // `content is string` checks or by iterating the content array.
    let mut turns = Vec::with_capacity(messages.len());
    for turn in messages {
        let content = match &turn.content {
            MessageContent::Text(text) => Value::String(text.clone()),
            MessageContent::Parts(parts) => Value::Array(parts.clone()),
        };
        let mut obj = serde_json::Map::new();
        obj.insert("role".into(), Value::String(turn.role.clone()));
        obj.insert("content".into(), content);
        // Extras are flattened last, next to role/content, so templates
        // can read e.g. `message.tool_calls` directly.
        if let Value::Object(extras) = &turn.extra {
            obj.extend(extras.iter().map(|(k, v)| (k.clone(), v.clone())));
        }
        turns.push(Value::Object(obj));
    }

    // Defaults first, caller-supplied kwargs afterwards so they win.
    let mut ctx = serde_json::Map::new();
    ctx.insert("messages".into(), Value::Array(turns));
    ctx.insert("add_generation_prompt".into(), Value::Bool(true));
    if !tools.is_null() {
        ctx.insert("tools".into(), tools.clone());
    }
    if let Value::Object(overrides) = kwargs {
        ctx.extend(overrides.iter().map(|(k, v)| (k.clone(), v.clone())));
    }

    // `Template::render` accepts any `Serialize` value, and serde_json's
    // `Value` is one, so the context object goes in as is rather than
    // through the `context!` macro (which wants minijinja-native values).
    tmpl.render(Value::Object(ctx))
        .context("render chat_template")
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Plain-text message with the given role and no extras.
    fn text_msg(role: &str, text: &str) -> ChatMessage {
        ChatMessage {
            role: role.into(),
            content: MessageContent::Text(text.into()),
            extra: Value::Object(Default::default()),
        }
    }

    fn user_msg(text: &str) -> ChatMessage {
        text_msg("user", text)
    }

    fn assistant_msg(text: &str) -> ChatMessage {
        text_msg("assistant", text)
    }

    /// Single user turn carrying a text part followed by an OpenAI-style
    /// `image_url` part.
    fn text_plus_image_turn() -> Vec<ChatMessage> {
        vec![ChatMessage {
            role: "user".into(),
            content: MessageContent::Parts(vec![
                json!({"type": "text", "text": "what is this?"}),
                json!({"type": "image_url", "image_url": {"url": "data:image/png;base64,AAA="}}),
            ]),
            extra: Value::Object(Default::default()),
        }]
    }

    /// Render with neither tools nor kwargs.
    fn render_plain(template: &str, messages: &[ChatMessage]) -> Result<String> {
        render_chat_template(template, messages, &Value::Null, &Value::Null)
    }

    /// Cut-down Qwen3-style template: enough to check that role and
    /// content are threaded through, without loading a real model's
    /// tokenizer_config.json.
    const QWEN3_LIKE: &str = "{%- for message in messages -%}\
        <|im_start|>{{ message.role }}\n{{ message.content }}<|im_end|>\n\
        {%- endfor -%}\
        {%- if add_generation_prompt -%}<|im_start|>assistant\n{%- endif -%}";

    /// Exercises the Qwen3.6 vision template's image-insertion condition
    /// against the OpenAI `image_url` content-part shape the renderer
    /// forwards. Shows that minijinja's `'image_url' in item` matches a
    /// serde_json object carrying that key, i.e. the template *can* emit
    /// `<|image_pad|>` for our parts.
    #[test]
    fn image_url_part_renders_image_pad() {
        // Condition copied from doc/vision-qwen3_6-spec.md (lines 8-18
        // of the real chat_template.jinja).
        let template = "{%- for message in messages -%}\
            {%- if message.content is string -%}\
            {{ message.content }}\
            {%- else -%}\
            {%- for item in message.content -%}\
            {%- if 'image' in item or 'image_url' in item or item.type == 'image' -%}\
            <|vision_start|><|image_pad|><|vision_end|>\
            {%- elif item.type == 'text' -%}\
            {{ item.text }}\
            {%- endif -%}\
            {%- endfor -%}\
            {%- endif -%}\
            {%- endfor -%}";
        let out = render_plain(template, &text_plus_image_turn()).expect("render should succeed");
        assert!(
            out.contains("<|image_pad|>"),
            "expected the image_url part to emit <|image_pad|>; rendered: {out:?}"
        );
    }

    /// `chat_template.jinja` must beat the `chat_template` field of
    /// `tokenizer_config.json` — the transformers precedence Qwen3.6
    /// depends on, since its vision template ships only as the `.jinja`
    /// file.
    #[test]
    fn standalone_jinja_template_takes_precedence() {
        let dir = std::env::temp_dir().join(format!(
            "neuron_ct_precedence_{}_{}",
            std::process::id(),
            line!()
        ));
        std::fs::create_dir_all(&dir).unwrap();
        std::fs::write(dir.join("chat_template.jinja"), "FROM_JINJA").unwrap();
        std::fs::write(
            dir.join("tokenizer_config.json"),
            r#"{"chat_template": "FROM_CONFIG"}"#,
        )
        .unwrap();
        // The loader only takes the parent of this sibling path.
        let got = load_chat_template_alongside(&dir.join("tokenizer.json"));
        std::fs::remove_dir_all(&dir).ok();
        assert_eq!(got.as_deref(), Some("FROM_JINJA"));
    }

    /// Without a standalone file the tokenizer_config.json field is
    /// used, so the text-only path is unchanged.
    #[test]
    fn falls_back_to_tokenizer_config_when_no_standalone() {
        let dir = std::env::temp_dir().join(format!(
            "neuron_ct_fallback_{}_{}",
            std::process::id(),
            line!()
        ));
        std::fs::create_dir_all(&dir).unwrap();
        std::fs::write(
            dir.join("tokenizer_config.json"),
            r#"{"chat_template": "FROM_CONFIG"}"#,
        )
        .unwrap();
        let got = load_chat_template_alongside(&dir.join("tokenizer.json"));
        std::fs::remove_dir_all(&dir).ok();
        assert_eq!(got.as_deref(), Some("FROM_CONFIG"));
    }

    /// The *actual* Qwen3.6-27B `chat_template.jinja` (verbatim from
    /// beast's HF cache) must render in minijinja and emit exactly one
    /// `<|image_pad|>` for a text+image user turn. This is the real
    /// end-to-end check the unit tests above only approximate: it
    /// catches minijinja incompatibilities (namespace, macros, reverse
    /// slice, string methods) before they reach production.
    #[test]
    fn real_qwen3_6_template_renders_one_image_pad() {
        let template = include_str!("testdata/qwen3_6_chat_template.jinja");
        let out = render_plain(template, &text_plus_image_turn())
            .expect("real Qwen3.6 template should render in minijinja");
        let pads = out.matches("<|image_pad|>").count();
        assert_eq!(
            pads, 1,
            "expected exactly one <|image_pad|>; rendered:\n{out}"
        );
        assert!(out.contains("<|vision_start|>") && out.contains("<|vision_end|>"));
    }

    #[test]
    fn renders_basic_conversation() {
        let conversation = [user_msg("hello"), assistant_msg("hi"), user_msg("bye")];
        let prompt = render_plain(QWEN3_LIKE, &conversation).unwrap();
        // Structural checks only: exact whitespace is a Jinja-trim
        // matter that differs between real chat templates. What counts
        // is that each turn's role and content appear, and that the
        // generation cue closes the prompt.
        assert!(
            prompt.contains("<|im_start|>user\nhello<|im_end|>"),
            "first user turn missing: {prompt}"
        );
        assert!(
            prompt.contains("<|im_start|>assistant\nhi<|im_end|>"),
            "assistant turn missing: {prompt}"
        );
        assert!(
            prompt.contains("<|im_start|>user\nbye<|im_end|>"),
            "second user turn missing: {prompt}"
        );
        assert!(
            prompt.ends_with("<|im_start|>assistant")
                || prompt.ends_with("<|im_start|>assistant\n"),
            "generation cue missing at end: {prompt}"
        );
    }

    #[test]
    fn kwargs_are_threaded_into_template_context() {
        // Simplified replica of Qwen3's enable_thinking branch. With the
        // kwarg set to false the real template pre-fills an empty
        // `<think>...</think>` block ahead of the generation cue, so the
        // model answers directly.
        let template = "{%- if enable_thinking is defined and enable_thinking is false -%}\
            NO_THINK\
            {%- else -%}\
            THINK_OK\
            {%- endif -%}";
        let r_disabled = render_chat_template(
            template,
            &[],
            &Value::Null,
            &json!({ "enable_thinking": false }),
        )
        .unwrap();
        assert_eq!(r_disabled, "NO_THINK");
        let r_default = render_plain(template, &[]).unwrap();
        assert_eq!(r_default, "THINK_OK");
    }

    #[test]
    fn missing_template_field_returns_none() {
        let tmp = std::env::temp_dir().join("neuron-test-tokenizer-missing-field.json");
        std::fs::write(&tmp, r#"{"some_other_field": 1}"#).unwrap();
        assert!(load_chat_template_from(&tmp).is_none());
        let _ = std::fs::remove_file(tmp);
    }

    #[test]
    fn load_template_from_string_field() {
        let tmp = std::env::temp_dir().join("neuron-test-tokenizer-string.json");
        std::fs::write(
            &tmp,
            r#"{"chat_template": "hello {{ messages[0].content }}"}"#,
        )
        .unwrap();
        let t = load_chat_template_from(&tmp).expect("template loaded");
        assert!(t.contains("messages[0].content"));
        let _ = std::fs::remove_file(tmp);
    }

    #[test]
    fn load_template_from_array_form() {
        // Some HF models ship `chat_template` as `[{name, template}, ...]`.
        let tmp = std::env::temp_dir().join("neuron-test-tokenizer-array.json");
        std::fs::write(
            &tmp,
            r#"{"chat_template": [{"name": "default", "template": "ARR"}]}"#,
        )
        .unwrap();
        let t = load_chat_template_from(&tmp).expect("template loaded");
        assert_eq!(t, "ARR");
        let _ = std::fs::remove_file(tmp);
    }

    #[test]
    fn missing_file_returns_none_quietly() {
        let absent = std::path::PathBuf::from("/definitely/not/a/real/path.json");
        assert!(load_chat_template_from(&absent).is_none());
    }

    #[test]
    fn unparseable_returns_none() {
        let tmp = std::env::temp_dir().join("neuron-test-tokenizer-garbage.json");
        std::fs::write(&tmp, b"{not valid json").unwrap();
        assert!(load_chat_template_from(&tmp).is_none());
        let _ = std::fs::remove_file(tmp);
    }

    #[test]
    fn kill_switch_recognises_truthy_falsy_values() {
        // Runs against the real env var so the test sees what production
        // sees. Serialised through a mutex — see path_util.rs for the
        // pattern.
        use std::sync::Mutex;
        static LOCK: Mutex<()> = Mutex::new(());
        let _g = LOCK.lock().unwrap();
        let prior = std::env::var(KILL_SWITCH_ENV).ok();

        // SAFETY: env mutation is confined to this test, under LOCK.
        // NOTE(review): assumes no other test reads or writes this
        // variable concurrently — confirm.
        unsafe {
            std::env::remove_var(KILL_SWITCH_ENV);
        }
        assert!(chat_templates_enabled());

        for value in ["false", "0", "no", "off", "FALSE", " no "] {
            // SAFETY: as above.
            unsafe { std::env::set_var(KILL_SWITCH_ENV, value) };
            assert!(!chat_templates_enabled(), "value {value:?} should disable");
        }
        for value in ["true", "1", "yes", ""] {
            // SAFETY: as above.
            unsafe { std::env::set_var(KILL_SWITCH_ENV, value) };
            assert!(chat_templates_enabled(), "value {value:?} should enable");
        }

        // Put the variable back the way it was found.
        // SAFETY: as above.
        unsafe {
            match prior {
                Some(p) => std::env::set_var(KILL_SWITCH_ENV, p),
                None => std::env::remove_var(KILL_SWITCH_ENV),
            }
        }
    }

    #[test]
    fn message_extras_thread_through_for_tool_calls() {
        // HF templates read `tool_calls` on assistant turns and
        // `tool_call_id` on tool turns, so extras must be flattened into
        // the message object the template iterates.
        let mut extras = serde_json::Map::new();
        extras.insert(
            "tool_calls".into(),
            json!([{"id": "t1", "function": {"name": "x", "arguments": "{}"}}]),
        );
        let msg = ChatMessage {
            role: "assistant".into(),
            content: MessageContent::Text(String::new()),
            extra: Value::Object(extras),
        };
        let rendered = render_plain("{{ messages[0].tool_calls[0].id }}", &[msg]).unwrap();
        assert_eq!(rendered, "t1");
    }
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,330 +0,0 @@
//! Job variants accepted by the per-device worker thread.
//!
//! Each variant carries the inputs the synchronous dispatch handler
//! needs plus a `tokio::sync::oneshot::Sender` for the reply. The
//! async-side `DeviceWorkerHandle` constructs a job, sends it down the
//! `std::sync::mpsc` channel, and `await`s the oneshot for the reply.
use anyhow::Result;
use std::path::PathBuf;
use tokio::sync::oneshot;
/// Opaque handle to a `ModelArch` stored in the worker thread's state
/// slab. Cheap to copy; `Send + Sync` so it crosses task boundaries
/// freely. The actual `Box<ModelArch>` it points to is owned by the
/// worker thread for the duration of the handle's lifetime — the only
/// way to drop the model is to send `Job::DropArch { handle }` so the
/// `Drop` impl runs on the thread with the bound CUDA context (the
/// invariant the whole refactor exists to guarantee).
///
/// Copying or dropping the handle itself therefore never frees the
/// model. The wrapped `u64` is public.
/// NOTE(review): presumably the slab key assigned by the worker —
/// confirm where handles are minted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ArchHandle(pub u64);
/// Opaque handle to a `TpLeaderModel` stored in the worker thread's
/// state slab. Same shape as [`ArchHandle`] but in a separate
/// namespace so the two slabs can coexist without ambiguity. Phase 3
/// introduces it; Phase 4 may unify the two slabs after the TP forward
/// path proves out.
///
/// Being a distinct type, a `TpHandle` cannot be passed where an
/// [`ArchHandle`] is expected even though both wrap a `u64`. Only
/// compiled when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TpHandle(pub u64);
/// Opaque handle to a prefix-cache snapshot (#11) stored worker-side
/// next to the model slab.
///
/// Scoped to the `ArchHandle` it was captured from — `Job::DropArch`
/// drops every snapshot under its handle. The snapshot's tensors never
/// leave the worker thread; the async side holds only this id plus the
/// token sequence it covers.
///
/// Returned by `Job::SnapshotKv` and consumed by `Job::RestoreKv` /
/// `Job::DropKvSnapshot`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct KvSnapshotId(pub u64);
/// One image payload for `Job::ForwardLogitsWithImages` /
/// `Job::EncodeImage`. Pixels are row-major `(c, h, w)` f32 — the
/// shape `harness::preprocess::preprocess` produces. Carries the
/// shape inline since `Vec<f32>` is rank-1.
///
/// `Clone` so the vision-aware dispatch in `chat_completion` can
/// match `&vision_route` (carrying borrowed images) and still hand
/// owned `Vec<ImageInput>` to the worker job. The clone cost is one
/// pixel-buffer memcpy per image — now variable with dynamic resolution
/// (#14): `3 × h × w × 4` bytes.
///
/// NOTE(review): this comment previously quoted "up to ~6.3 MiB at the
/// default 1024² `max_pixels` budget". With three f32 channels a
/// 1024 × 1024 image is 3 × 1024 × 1024 × 4 = 12 MiB; ~6.3 MB would
/// match 2 bytes per element. Confirm the default budget and element
/// width and restate the upper bound.
///
/// `h`/`w` are the **resized** dims (factor-aligned), so the per-image LM
/// grid is `(h/factor, w/factor)` — derived downstream for the splice
/// and the interleaved-M-RoPE position ids.
#[derive(Clone)]
pub struct ImageInput {
/// Flat pixel buffer in row-major `(c, h, w)` order; expected length
/// is `c * h * w` (not checked by this type).
pub pixels: Vec<f32>,
/// Channel count.
pub c: usize,
/// Height after resizing.
pub h: usize,
/// Width after resizing.
pub w: usize,
}
/// One unit of work for the device worker.
///
/// Phase 1 had only `QueryVram` and `Shutdown`. Phase 2 adds the
/// single-GPU inference primitives: transfer-in a freshly-loaded
/// `ModelArch`, drop it, clear its KV cache, and run one forward step
/// returning CPU-side logits ready for sampling on the async caller.
///
/// Sampling stays on the async side intentionally. The worker copies
/// logits to CPU (`Vec<f32>`) before reply, so the device-resident
/// tensor never escapes the worker thread and the async caller's
/// `LogitsProcessor::sample` runs entirely on the CPU candle backend
/// — no incidental context binding on a tokio worker thread.
///
/// Reply conventions: every variant except `Shutdown` carries a
/// `reply` oneshot sender. Fallible work replies with `Result<_>`;
/// the drop-style variants (`DropArch`, `DropKvSnapshot`, `DropTp`,
/// `TpDropKvSnapshot`) reply with a bare `()`; the NCCL bring-up
/// variants reply with a `WorkerResponse`.
pub enum Job {
/// Query free / total VRAM on the device. Returns
/// `(free_mb, total_mb)`. CPU builds and contexts that failed to
/// initialise reply with `(0, 0)` — matches today's
/// `device_vram_mb` sentinel so the log field values don't change.
QueryVram {
reply: oneshot::Sender<Result<(u64, u64)>>,
},
/// Load a GGUF (pre-quantized) single-GPU model on the worker
/// thread. The dispatch handler opens the GGUF file, parses
/// metadata, dispatches on `general.architecture`, and inserts
/// the resulting `ModelArch` into the slab. Returns the fresh
/// `ArchHandle`.
LoadGguf {
gguf_path: PathBuf,
model_id: String,
reply: oneshot::Sender<Result<ArchHandle>>,
},
/// Load a dense safetensors single-GPU model on the worker
/// thread. The dispatch handler reads `config.json`, dispatches on
/// `model_type`, builds a `VarBuilder` over the mmap'd
/// safetensors, and inserts the resulting `ModelArch`. Replies
/// with the fresh `ArchHandle`, same as `LoadGguf`.
LoadDense {
config_path: PathBuf,
safetensors_paths: Vec<PathBuf>,
model_id: String,
reply: oneshot::Sender<Result<ArchHandle>>,
},
/// Remove the model from the slab and drop it. The `Drop` runs on
/// the worker thread so CUDA tensors release their memory on the
/// same context that allocated them. The reply is a bare `()`, so
/// no failure can be reported through it.
DropArch {
handle: ArchHandle,
reply: oneshot::Sender<()>,
},
/// Reset the KV cache for this model. Called at the start of every
/// chat completion so a new request doesn't attend over the
/// previous one's tokens.
ClearKv {
handle: ArchHandle,
reply: oneshot::Sender<Result<()>>,
},
/// Capture the model's live cache state (attention KV + GDN
/// recurrent state + position counters) as a prefix snapshot
/// (#11). The snapshot stays in the worker's state, keyed by the
/// returned id; the reply carries `(id, bytes)` so the async side
/// can do budget accounting without touching tensors. Errors on
/// archs without snapshot support.
SnapshotKv {
handle: ArchHandle,
reply: oneshot::Sender<Result<(KvSnapshotId, u64)>>,
},
/// Replace the model's live cache state with a stored snapshot,
/// instead of `ClearKv`, so prefill can resume at the snapshot's
/// token boundary. The snapshot remains stored (restorable again).
RestoreKv {
handle: ArchHandle,
snapshot: KvSnapshotId,
reply: oneshot::Sender<Result<()>>,
},
/// Drop one stored snapshot (prefix-cache eviction). Idempotent.
DropKvSnapshot {
handle: ArchHandle,
snapshot: KvSnapshotId,
reply: oneshot::Sender<()>,
},
/// Run one forward step and copy the resulting `[vocab]` logits to
/// CPU. The caller takes the returned `Vec<f32>`, wraps it in a
/// CPU `Tensor`, and runs `apply_repeat_penalty` + sampling
/// without touching the device context. `offset` is the KV-cache
/// position before this step (0 for prefill, `prompt_len + i` for
/// the i-th decode step).
ForwardLogits {
handle: ArchHandle,
tokens: Vec<u32>,
offset: usize,
reply: oneshot::Sender<Result<Vec<f32>>>,
},
/// Run the LM forward with vision splicing in one round-trip.
/// Stage B3 of the vision plan.
///
/// Inputs:
/// - `tokens`: prompt-expanded token ids (the caller has already
/// replaced each `<|image_pad|>` with N copies per the
/// per-image patch count, so `tokens` already contains exactly
/// `sum(n_i)` `image_token_id` entries across all images).
/// - `offset`: KV-cache position (same contract as `ForwardLogits`).
/// - `images`: one entry per image — preprocessed pixels plus the
/// `(c, h, w)` shape. Images are encoded on the worker via the
/// model's vision tower (`VisionTower::forward`), concatenated
/// in order, and spliced into the LM input embeddings at
/// `image_token_id` positions.
/// - `image_token_id`: the sentinel token (248056 for Qwen3.6).
///
/// Returns flat CPU `[vocab]` logits, same as `ForwardLogits`.
/// Image embeddings stay device-resident — they're never copied
/// to CPU. The "tensors don't escape the worker" invariant holds.
ForwardLogitsWithImages {
handle: ArchHandle,
tokens: Vec<u32>,
offset: usize,
images: Vec<ImageInput>,
image_token_id: u32,
reply: oneshot::Sender<Result<Vec<f32>>>,
},
/// Encode one image through the model's vision tower. Stage A5 of
/// the vision plan (`doc/vision-qwen3_6-spec.md`).
///
/// `pixels` is the CPU-side preprocessed image tensor in row-major
/// `(C, H, W)` f32 layout — what `harness::preprocess::preprocess`
/// produces. `c`, `h`, `w` carry the shape since `Vec<f32>` itself
/// is rank-1. The handler reconstructs the tensor on the worker's
/// device, runs `VisionTower::forward`, copies the resulting
/// `(N_lm_tokens, hidden_size)` embedding back to CPU as a flat
/// `Vec<f32>` (the caller knows the expected shape from
/// `VisionTower::lm_tokens_for(h, w) * hidden_size`).
///
/// Mirrors the `ForwardLogits` "tensors don't escape" invariant —
/// device-side image embeddings are dropped at handler return.
/// This variant pays a round-trip copy of the embeddings to CPU.
/// The Stage B follow-up that keeps them device-resident appears to
/// be `ForwardLogitsWithImages` above — confirm against the vision
/// plan before relying on that.
EncodeImage {
handle: ArchHandle,
pixels: Vec<f32>,
c: usize,
h: usize,
w: usize,
reply: oneshot::Sender<Result<Vec<f32>>>,
},
/// Initialize the leader's NCCL communicator. The worker's
/// `NcclState` mints the `Comm` here so its underlying
/// `ncclComm_t` and `CudaContext` live on the same thread as
/// every later `Comm::all_reduce` call. Reply is the worker
/// response shape used by the subprocess workers (`InitOk` on
/// success, `Error` on failure) so the calling
/// `WorkerPool::init_nccl` orchestration stays uniform.
///
/// Available on both cuda and no-cuda builds — the dispatch
/// handler calls `NcclState::init` which has a no-cuda stub that
/// replies with `cuda_feature_not_enabled`. Keeping the Job
/// variant ungated lets `WorkerPool::init_nccl` stay uniform.
NcclInit {
cfg: crate::harness::tp::worker::WorkerConfig,
comm_id_hex: String,
reply: oneshot::Sender<crate::harness::tp::rpc::WorkerResponse>,
},
/// Run NCCL's all_reduce sanity check on the leader's rank 0.
/// Same response shape as `NcclInit`; also available on both
/// builds via the no-cuda `NcclState::sanity_check` stub.
NcclSanity {
reply: oneshot::Sender<crate::harness::tp::rpc::WorkerResponse>,
},
/// Hand a clonable handle to the leader's NCCL `Comm` back to the
/// async side, so the TP step watchdog can call `ncclCommAbort` on
/// it from a *different* thread to unblock a wedged collective
/// (#17 Stage 2). Fetched once at init while the worker thread is
/// still responsive — a thread already wedged in a collective can't
/// service this job, which is exactly why the handle is cached
/// up front. Replies `None` before `NcclInit` has run.
#[cfg(feature = "cuda")]
GetLeaderComm {
reply: oneshot::Sender<Option<crate::harness::tp::nccl_state::SendComm>>,
},
/// Load the leader's TP shard on the worker thread. The dispatch
/// handler reads `state.nccl.comm()` directly (no cross-thread
/// `Arc<Comm>` transfer, no `SendComm` wrapper) and builds the
/// `TpLeaderModel` against that Comm. The model's embedded
/// `Arc<Comm>` clones, `CudaContext`, and all per-rank CUDA
/// tensors live on this thread for the model's lifetime.
/// Inserts into the TP slab and returns the fresh `TpHandle`.
#[cfg(feature = "cuda")]
TpLoadShard {
model_id: String,
config_json: String,
safetensors_paths: Vec<PathBuf>,
dtype: candle_core::DType,
quant: Option<String>,
world_size: u32,
reply: oneshot::Sender<Result<TpHandle>>,
},
/// Drop the TP leader model on the worker thread. CUDA tensors
/// and `Arc<Comm>` clones held inside the model release on the
/// thread that allocated them.
#[cfg(feature = "cuda")]
DropTp {
handle: TpHandle,
reply: oneshot::Sender<()>,
},
/// Reset the leader's KV cache for a TP model. Mirrors `ClearKv`
/// for single-GPU.
#[cfg(feature = "cuda")]
TpClearKv {
handle: TpHandle,
reply: oneshot::Sender<Result<()>>,
},
/// Capture the leader's TP cache state as a prefix snapshot (#11),
/// stored worker-side under the pool-minted `snapshot_id` (shared
/// with the subprocess ranks, so all ranks key the same snapshot
/// identically). Replies with the leader shard's snapshot bytes.
#[cfg(feature = "cuda")]
TpSnapshotKv {
handle: TpHandle,
snapshot_id: u64,
reply: oneshot::Sender<Result<u64>>,
},
/// Replace the leader's live TP cache state with a stored
/// snapshot. Mirrors `RestoreKv` for single-GPU.
#[cfg(feature = "cuda")]
TpRestoreKv {
handle: TpHandle,
snapshot_id: u64,
reply: oneshot::Sender<Result<()>>,
},
/// Drop one stored leader TP snapshot (eviction). Idempotent.
#[cfg(feature = "cuda")]
TpDropKvSnapshot {
handle: TpHandle,
snapshot_id: u64,
reply: oneshot::Sender<()>,
},
/// Run one TP forward step on the leader's shard. Returns CPU-side
/// logits as a `Vec<f32>` so the async caller can sample without
/// holding a device tensor. The caller is also responsible for
/// fan-out to subprocess ranks and drain — only the leader's
/// forward moves into the worker thread.
#[cfg(feature = "cuda")]
TpForwardLogits {
handle: TpHandle,
tokens: Vec<u32>,
offset: usize,
reply: oneshot::Sender<Result<Vec<f32>>>,
},
/// Image-bearing leader (rank 0) forward for the single-shot vision
/// prefill. The handler preprocesses each `image_data_uris` entry
/// (the same deterministic path every rank runs), encodes through
/// the leader's replicated tower, splices at `image_token_id`, and
/// returns CPU-side `[vocab]` logits. Image tensors never escape the
/// worker thread. Caller fans out `GenerateStepWithImages` to the
/// subprocess ranks and drains them; only the leader forward moves
/// here.
///
/// NOTE(review): `chunk_size` is not described anywhere in this
/// file; presumably the prefill chunk length — confirm against the
/// dispatch handler and document it here.
#[cfg(feature = "cuda")]
TpForwardLogitsWithImages {
handle: TpHandle,
tokens: Vec<u32>,
offset: usize,
image_token_id: u32,
image_data_uris: Vec<String>,
chunk_size: usize,
reply: oneshot::Sender<Result<Vec<f32>>>,
},
/// Tell the worker to break its dispatch loop and exit. Any jobs
/// queued after this in the channel reply `Err` to their oneshot
/// senders (the senders are dropped on the worker's exit, which
/// the async-side `Receiver::await` maps to `WorkerError::Gone`).
/// The only variant without a `reply` field.
Shutdown,
}

View File

@@ -1,948 +0,0 @@
//! Per-device CUDA worker thread.
//!
//! One dedicated OS thread per CUDA device the leader uses. The thread
//! binds the device's `CudaContext` once at startup and owns it for the
//! daemon's lifetime; all GPU operations and VRAM queries for that
//! device route through a `std::sync::mpsc` channel into this thread.
//! Tensors never escape the thread alive — replies cross the channel
//! as plain values (`Vec<f32>` logits, `(u64, u64)` MB numbers, opaque
//! `u64` handles, `()`).
//!
//! Rationale, in order of weight:
//!
//! 1. **Context locality.** cudarc binds the CUDA context per OS thread
//! via `cuCtxSetCurrent`. With `tokio::task::spawn_blocking`, the
//! blocking thread chosen is arbitrary, so the context gets bound
//! onto a different thread each time and `device_vram_mb()` from an
//! async task binds it again on the *caller's* thread as a side
//! effect. Pinning the context to one named thread ends that.
//!
//! 2. **Drop safety.** `cudarc::driver::CudaContext`, every `CudaSlice`
//! inside a `Tensor`, and every `cudarc::nccl::Comm` call `cuMemFree`
//! / `cuCtxDestroy` / `ncclCommDestroy` during `Drop`. These must
//! run with the right context current. Owning everything in this
//! thread's state slab and dropping it via `Job::DropArch` /
//! `Job::Shutdown` is the only safe pattern.
//!
//! 3. **Poisoning blast radius.** When a CUDA driver error (illegal
//! address, OOM cascade) makes the context unrecoverable, today the
//! spawn_blocking thread carrying that bad state simply returns to
//! tokio's pool — invisible. With the per-device thread, the
//! poisoned flag lives on the thread itself; subsequent
//! `submit()` calls fast-reject at the channel boundary with a
//! clear "device worker is poisoned" error before any further CUDA
//! work is attempted.
//!
//! The TP worker subprocesses (`harness/tp/worker.rs`) are already this
//! pattern, just out-of-process. The in-process variant uses the same
//! discipline for rank 0.
//!
//! Phase 1 of the refactor exposed only `Job::QueryVram` + `Job::Shutdown`.
//! Forward, KV-cache clear, model load, and NCCL bring-up moved in during
//! later phases and are now part of `Job`. The original design notes live
//! in the author's local plan file,
//! `/home/grenade/.claude/plans/plan-the-per-device-worker-abstract-micali.md`.
pub mod dispatch;
pub mod jobs;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::{self, Sender};
use std::thread::JoinHandle;
use tokio::sync::oneshot;
#[cfg(feature = "cuda")]
pub use jobs::TpHandle;
pub use jobs::{ArchHandle, Job, KvSnapshotId};
/// Errors returned by `DeviceWorkerHandle` submit methods.
///
/// `Poisoned` and `Gone` describe the worker itself and carry the
/// device index for log context; `Job` wraps an error the dispatched
/// job replied with.
#[derive(Debug, thiserror::Error)]
pub enum WorkerError {
/// The worker's CUDA context was poisoned by an earlier driver
/// error. The thread is still alive (dropping it would re-touch
/// the broken context); it returns this error for every job
/// submitted until the daemon is restarted.
#[error(
"device worker for device {device_index} is poisoned \
(a prior CUDA driver error left the context unrecoverable); \
restart the daemon to recover"
)]
Poisoned { device_index: u32 },
/// The worker thread has exited (`Job::Shutdown` was processed or
/// the thread panicked). Subsequent `submit()` calls fail here
/// rather than blocking forever. Produced when either the job
/// channel or a job's reply channel is found closed.
#[error("device worker for device {device_index} is no longer running")]
Gone { device_index: u32 },
/// The dispatched job returned an `Err`. Forwarded verbatim
/// (`transparent`, so display and source come from the inner
/// `anyhow::Error`).
#[error(transparent)]
Job(#[from] anyhow::Error),
}
/// Shared handle to a per-device CUDA worker thread.
///
/// Cloning the `Arc` lets multiple `LoadedModel`s (and `TpLoadedModel`s)
/// share the same worker — there's one worker per CUDA device index,
/// not one per model.
pub struct DeviceWorkerHandle {
/// CUDA device index this worker was spawned for; echoed into every
/// `WorkerError::Poisoned` / `WorkerError::Gone`.
device_index: u32,
/// Sending half of the job channel whose receiver is handed to the
/// worker thread at spawn.
tx: Sender<Job>,
/// Poisoned flag, shared with the worker thread (a clone of this
/// `Arc` is passed to `dispatch::run`).
poisoned: Arc<AtomicBool>,
/// `Mutex<Option<JoinHandle>>` so `shutdown()` can take the handle
/// out without `&mut self` and so the inevitable `Drop` after
/// `shutdown()` doesn't double-join. The mutex is uncontended in
/// practice: only one caller ever takes the handle.
join: std::sync::Mutex<Option<JoinHandle<()>>>,
}
impl DeviceWorkerHandle {
/// Start the worker thread for CUDA device `device_index` and return a
/// shared handle to it.
///
/// The thread is named `cuda-dev-N` (N = `device_index`) so it is easy
/// to pick out in `top -H`, `pidstat -t`, and gdb backtraces. Its body
/// is `dispatch::run`, which receives the job receiver and a clone of
/// the poisoned flag. Context binding (and the CPU-build fallback)
/// happens inside `dispatch::run`, not in this function.
///
/// # Errors
/// Returns the OS error if the thread could not be spawned.
pub fn spawn(device_index: u32) -> anyhow::Result<Arc<Self>> {
    let poisoned = Arc::new(AtomicBool::new(false));
    let worker_flag = Arc::clone(&poisoned);
    let (tx, rx) = mpsc::channel::<Job>();

    let thread_name = format!("cuda-dev-{device_index}");
    let worker_thread = std::thread::Builder::new()
        .name(thread_name)
        .spawn(move || {
            dispatch::run(device_index, rx, worker_flag);
        })?;

    let handle = Self {
        device_index,
        tx,
        poisoned,
        join: std::sync::Mutex::new(Some(worker_thread)),
    };
    Ok(Arc::new(handle))
}
/// CUDA device index this worker was spawned for.
pub fn device_index(&self) -> u32 {
self.device_index
}
/// Whether the shared poisoned flag is currently set.
///
/// Uses an `Acquire` load, pairing with the `Release` store in
/// `set_poisoned`.
pub fn is_poisoned(&self) -> bool {
self.poisoned.load(Ordering::Acquire)
}
/// Mark the worker's context as poisoned. Future `submit()` calls
/// short-circuit to `WorkerError::Poisoned` before sending. The
/// dispatch loop also flips into drain-only mode when it sees this
/// flag, so any jobs already in flight on the channel reply with
/// the same error without touching CUDA.
///
/// The flag is only ever set here, never cleared; recovery is a
/// daemon restart (see `WorkerError::Poisoned`).
// NOTE(review): no caller is visible in this part of the file, which
// is presumably why `dead_code` is allowed — confirm and either wire
// it up or state the reason next to the attribute.
#[allow(dead_code)]
pub(crate) fn set_poisoned(&self) {
self.poisoned.store(true, Ordering::Release);
}
/// Ask the worker for the device's free and total VRAM.
///
/// Submits `Job::QueryVram` and awaits the oneshot reply.
///
/// # Returns
/// The `(free_mb, total_mb)` pair the worker replied with. Per the
/// `Job::QueryVram` contract that is `(0, 0)` on CPU builds or when
/// the device has no bound context.
///
/// # Errors
/// - `WorkerError::Poisoned` if the poisoned flag is set (nothing is
///   sent).
/// - `WorkerError::Gone` if the job channel or the reply channel is
///   closed.
/// - `WorkerError::Job` if the query itself failed on the worker.
pub async fn query_vram(&self) -> Result<(u64, u64), WorkerError> {
    let device_index = self.device_index;
    if self.poisoned.load(Ordering::Acquire) {
        return Err(WorkerError::Poisoned { device_index });
    }

    let (done_tx, done_rx) = oneshot::channel();
    let job = Job::QueryVram { reply: done_tx };
    if self.tx.send(job).is_err() {
        return Err(WorkerError::Gone { device_index });
    }

    // A dropped reply sender means the worker went away without
    // answering.
    let outcome = done_rx
        .await
        .map_err(|_| WorkerError::Gone { device_index })?;
    outcome.map_err(WorkerError::from)
}
/// Fetch a clonable handle to the leader's NCCL `Comm` (#17 Stage 2).
///
/// The TP step watchdog caches this at init so it can later call
/// `ncclCommAbort` from outside the worker thread to unblock a wedged
/// collective.
///
/// # Returns
/// `Some(comm)` when the worker replied with one. `None` in every
/// other case — the poisoned flag is set, the job channel is closed,
/// the reply channel was dropped, or the worker replied `None` (per
/// `Job::GetLeaderComm`, that is the reply before `NcclInit` has run).
/// The caller treats a missing handle as "can't abort" and logs it.
#[cfg(feature = "cuda")]
pub async fn get_leader_comm(&self) -> Option<crate::harness::tp::nccl_state::SendComm> {
    if self.poisoned.load(Ordering::Acquire) {
        return None;
    }

    let (done_tx, done_rx) = oneshot::channel();
    self.tx.send(Job::GetLeaderComm { reply: done_tx }).ok()?;

    match done_rx.await {
        Ok(maybe_comm) => maybe_comm,
        Err(_) => None,
    }
}
/// Load a GGUF (pre-quantized) single-GPU model on the worker thread.
///
/// The caller resolves the file first (e.g. through hf-hub) and passes
/// the local `gguf_path` together with the spec's `model_id`. Opening
/// the file and building the `ModelArch` happen on the worker side of
/// `Job::LoadGguf`.
///
/// # Returns
/// The `ArchHandle` the worker assigned to the loaded model.
///
/// # Errors
/// - `WorkerError::Poisoned` if the poisoned flag is set (nothing is
///   sent).
/// - `WorkerError::Gone` if the job channel or the reply channel is
///   closed.
/// - `WorkerError::Job` if the load failed on the worker.
pub async fn load_gguf(
    &self,
    gguf_path: std::path::PathBuf,
    model_id: String,
) -> Result<ArchHandle, WorkerError> {
    let device_index = self.device_index;
    if self.poisoned.load(Ordering::Acquire) {
        return Err(WorkerError::Poisoned { device_index });
    }

    let (done_tx, done_rx) = oneshot::channel();
    let job = Job::LoadGguf {
        gguf_path,
        model_id,
        reply: done_tx,
    };
    if self.tx.send(job).is_err() {
        return Err(WorkerError::Gone { device_index });
    }

    let outcome = done_rx
        .await
        .map_err(|_| WorkerError::Gone { device_index })?;
    outcome.map_err(WorkerError::from)
}
/// Load a dense safetensors single-GPU model on the worker thread.
///
/// Forwards `config_path`, `safetensors_paths` and `model_id` in a
/// `Job::LoadDense`; reading the config and building the model happen
/// on the worker.
///
/// # Returns
/// The `ArchHandle` the worker assigned to the loaded model.
///
/// # Errors
/// - `WorkerError::Poisoned` if the poisoned flag is set (nothing is
///   sent).
/// - `WorkerError::Gone` if the job channel or the reply channel is
///   closed.
/// - `WorkerError::Job` if the load failed on the worker.
pub async fn load_dense(
    &self,
    config_path: std::path::PathBuf,
    safetensors_paths: Vec<std::path::PathBuf>,
    model_id: String,
) -> Result<ArchHandle, WorkerError> {
    let device_index = self.device_index;
    if self.poisoned.load(Ordering::Acquire) {
        return Err(WorkerError::Poisoned { device_index });
    }

    let (done_tx, done_rx) = oneshot::channel();
    let job = Job::LoadDense {
        config_path,
        safetensors_paths,
        model_id,
        reply: done_tx,
    };
    if self.tx.send(job).is_err() {
        return Err(WorkerError::Gone { device_index });
    }

    let outcome = done_rx
        .await
        .map_err(|_| WorkerError::Gone { device_index })?;
    outcome.map_err(WorkerError::from)
}
/// Ask the worker to drop the `ModelArch` behind `handle` on the
/// worker thread, so CUDA tensors are released on the right context.
///
/// The reply is a bare `()`, so an unknown handle cannot be reported
/// as an error through this call; the only failure surfaced is the
/// worker not running.
///
/// # Errors
/// `WorkerError::Gone` if the job channel or the reply channel is
/// closed.
pub async fn drop_arch(&self, handle: ArchHandle) -> Result<(), WorkerError> {
    // Deliberately no poisoned check: even on a poisoned context the
    // caller must unblock and finish its unload bookkeeping. How the
    // dispatch handler treats the model itself under poison is decided
    // in `dispatch`, not here.
    // NOTE(review): the earlier comment said the drop "happens via
    // mem::forget at thread exit"; `mem::forget` skips `Drop`, so that
    // presumably meant the model is leaked rather than dropped —
    // confirm against the poison protocol.
    let device_index = self.device_index;

    let (done_tx, done_rx) = oneshot::channel();
    let job = Job::DropArch {
        handle,
        reply: done_tx,
    };
    if self.tx.send(job).is_err() {
        return Err(WorkerError::Gone { device_index });
    }

    done_rx
        .await
        .map_err(|_| WorkerError::Gone { device_index })
}
/// Reset the KV cache of the model behind `handle`.
///
/// Meant to run at the start of every chat completion so the new
/// prompt does not attend over the previous request's tokens.
///
/// # Errors
/// - `WorkerError::Poisoned` if the poisoned flag is set (nothing is
///   sent).
/// - `WorkerError::Gone` if the job channel or the reply channel is
///   closed.
/// - `WorkerError::Job` if the worker replied with an error.
pub async fn clear_kv_cache(&self, handle: ArchHandle) -> Result<(), WorkerError> {
    let device_index = self.device_index;
    if self.poisoned.load(Ordering::Acquire) {
        return Err(WorkerError::Poisoned { device_index });
    }

    let (done_tx, done_rx) = oneshot::channel();
    let job = Job::ClearKv {
        handle,
        reply: done_tx,
    };
    if self.tx.send(job).is_err() {
        return Err(WorkerError::Gone { device_index });
    }

    let outcome = done_rx
        .await
        .map_err(|_| WorkerError::Gone { device_index })?;
    outcome.map_err(WorkerError::from)
}
/// Capture the live cache state of the model behind `handle` as a
/// worker-side prefix snapshot (#11).
///
/// # Returns
/// `(snapshot_id, bytes)`: the id to pass to `restore_kv` /
/// `drop_kv_snapshot`, and the snapshot's size for async-side budget
/// accounting. Only these two values cross the channel.
///
/// # Errors
/// - `WorkerError::Poisoned` if the poisoned flag is set (nothing is
///   sent).
/// - `WorkerError::Gone` if the job channel or the reply channel is
///   closed.
/// - `WorkerError::Job` if the worker replied with an error (per
///   `Job::SnapshotKv`, archs without snapshot support do).
pub async fn snapshot_kv(
    &self,
    handle: ArchHandle,
) -> Result<(jobs::KvSnapshotId, u64), WorkerError> {
    let device_index = self.device_index;
    if self.poisoned.load(Ordering::Acquire) {
        return Err(WorkerError::Poisoned { device_index });
    }

    let (done_tx, done_rx) = oneshot::channel();
    let job = Job::SnapshotKv {
        handle,
        reply: done_tx,
    };
    if self.tx.send(job).is_err() {
        return Err(WorkerError::Gone { device_index });
    }

    let outcome = done_rx
        .await
        .map_err(|_| WorkerError::Gone { device_index })?;
    outcome.map_err(WorkerError::from)
}
/// Swap the model's live cache state for the stored `snapshot`.
///
/// Used in place of [`Self::clear_kv_cache`] on a prefix-cache hit.
/// Per `Job::RestoreKv` the snapshot stays stored and can be restored
/// again.
///
/// # Errors
/// - `WorkerError::Poisoned` if the poisoned flag is set (nothing is
///   sent).
/// - `WorkerError::Gone` if the job channel or the reply channel is
///   closed.
/// - `WorkerError::Job` if the worker replied with an error.
pub async fn restore_kv(
    &self,
    handle: ArchHandle,
    snapshot: jobs::KvSnapshotId,
) -> Result<(), WorkerError> {
    let device_index = self.device_index;
    if self.poisoned.load(Ordering::Acquire) {
        return Err(WorkerError::Poisoned { device_index });
    }

    let (done_tx, done_rx) = oneshot::channel();
    let job = Job::RestoreKv {
        handle,
        snapshot,
        reply: done_tx,
    };
    if self.tx.send(job).is_err() {
        return Err(WorkerError::Gone { device_index });
    }

    let outcome = done_rx
        .await
        .map_err(|_| WorkerError::Gone { device_index })?;
    outcome.map_err(WorkerError::from)
}
/// Drop one stored prefix snapshot (eviction).
///
/// Same shape as [`Self::drop_arch`]: no poisoned check and a
/// unit reply, so eviction bookkeeping always unblocks.
///
/// # Errors
/// `WorkerError::Gone` if the job channel or the reply channel is
/// closed.
pub async fn drop_kv_snapshot(
    &self,
    handle: ArchHandle,
    snapshot: jobs::KvSnapshotId,
) -> Result<(), WorkerError> {
    let device_index = self.device_index;

    let (done_tx, done_rx) = oneshot::channel();
    let job = Job::DropKvSnapshot {
        handle,
        snapshot,
        reply: done_tx,
    };
    if self.tx.send(job).is_err() {
        return Err(WorkerError::Gone { device_index });
    }

    done_rx
        .await
        .map_err(|_| WorkerError::Gone { device_index })
}
/// Run one forward step on the worker and get the logits back on CPU.
///
/// `offset` follows the `Job::ForwardLogits` contract: the KV-cache
/// position before this step.
///
/// # Returns
/// The `[vocab]` logits as a flat `Vec<f32>`, so the caller can sample
/// on a CPU tensor without binding the device context on its tokio
/// thread.
///
/// # Errors
/// - `WorkerError::Poisoned` if the poisoned flag is set (nothing is
///   sent).
/// - `WorkerError::Gone` if the job channel or the reply channel is
///   closed.
/// - `WorkerError::Job` if the forward failed on the worker.
pub async fn forward_logits(
    &self,
    handle: ArchHandle,
    tokens: Vec<u32>,
    offset: usize,
) -> Result<Vec<f32>, WorkerError> {
    let device_index = self.device_index;
    if self.poisoned.load(Ordering::Acquire) {
        return Err(WorkerError::Poisoned { device_index });
    }

    let (done_tx, done_rx) = oneshot::channel();
    let job = Job::ForwardLogits {
        handle,
        tokens,
        offset,
        reply: done_tx,
    };
    if self.tx.send(job).is_err() {
        return Err(WorkerError::Gone { device_index });
    }

    let outcome = done_rx
        .await
        .map_err(|_| WorkerError::Gone { device_index })?;
    outcome.map_err(WorkerError::from)
}
/// Forward step with image splicing, in a single round-trip
/// (Stage B3).
///
/// Ships the tokens, the preprocessed `images` and the
/// `image_token_id` sentinel to the worker in one
/// `Job::ForwardLogitsWithImages`. Per that job's contract the images
/// are encoded on the worker and spliced in at the `image_token_id`
/// positions; the image embeddings are not returned to the caller.
///
/// # Returns
/// CPU `[vocab]` logits, same shape as [`Self::forward_logits`].
///
/// # Errors
/// - `WorkerError::Poisoned` if the poisoned flag is set (nothing is
///   sent).
/// - `WorkerError::Gone` if the job channel or the reply channel is
///   closed.
/// - `WorkerError::Job` if the forward failed on the worker.
pub async fn forward_logits_with_images(
    &self,
    handle: ArchHandle,
    tokens: Vec<u32>,
    offset: usize,
    images: Vec<crate::harness::device_worker::jobs::ImageInput>,
    image_token_id: u32,
) -> Result<Vec<f32>, WorkerError> {
    let device_index = self.device_index;
    if self.poisoned.load(Ordering::Acquire) {
        return Err(WorkerError::Poisoned { device_index });
    }

    let (done_tx, done_rx) = oneshot::channel();
    let job = Job::ForwardLogitsWithImages {
        handle,
        tokens,
        offset,
        images,
        image_token_id,
        reply: done_tx,
    };
    if self.tx.send(job).is_err() {
        return Err(WorkerError::Gone { device_index });
    }

    let outcome = done_rx
        .await
        .map_err(|_| WorkerError::Gone { device_index })?;
    outcome.map_err(WorkerError::from)
}
/// Encode one preprocessed image through the model's vision tower
/// (Stage A5).
///
/// `pixels` is the row-major `(c, h, w)` f32 image, the layout
/// `harness::preprocess::preprocess` produces; `c`, `h`, `w` carry the
/// shape because the buffer itself is flat.
///
/// # Returns
/// The LM-side image embeddings flattened into a CPU `Vec<f32>`. The
/// caller derives the expected length from
/// `VisionTower::lm_tokens_for(h, w) * hidden_size` and reshapes
/// accordingly.
///
/// # Errors
/// - `WorkerError::Poisoned` if the poisoned flag is set (nothing is
///   sent).
/// - `WorkerError::Gone` if the job channel or the reply channel is
///   closed.
/// - `WorkerError::Job` if encoding failed on the worker.
pub async fn encode_image(
    &self,
    handle: ArchHandle,
    pixels: Vec<f32>,
    c: usize,
    h: usize,
    w: usize,
) -> Result<Vec<f32>, WorkerError> {
    let device_index = self.device_index;
    if self.poisoned.load(Ordering::Acquire) {
        return Err(WorkerError::Poisoned { device_index });
    }

    let (done_tx, done_rx) = oneshot::channel();
    let job = Job::EncodeImage {
        handle,
        pixels,
        c,
        h,
        w,
        reply: done_tx,
    };
    if self.tx.send(job).is_err() {
        return Err(WorkerError::Gone { device_index });
    }

    let outcome = done_rx
        .await
        .map_err(|_| WorkerError::Gone { device_index })?;
    outcome.map_err(WorkerError::from)
}
/// Initialise the leader's NCCL communicator on the worker thread.
///
/// The reply is a `WorkerResponse` — the same shape subprocess workers
/// use over stdio RPC — so `WorkerPool::init_nccl` can aggregate leader
/// and subprocess responses uniformly. A failed init therefore arrives
/// as an `Ok(response)` carrying the failure, not as
/// `WorkerError::Job`. The method is not feature-gated; per
/// `Job::NcclInit`, no-cuda builds answer through a stub.
///
/// # Errors
/// - `WorkerError::Poisoned` if the poisoned flag is set (nothing is
///   sent).
/// - `WorkerError::Gone` if the job channel or the reply channel is
///   closed.
pub async fn nccl_init(
    &self,
    cfg: crate::harness::tp::worker::WorkerConfig,
    comm_id_hex: String,
) -> Result<crate::harness::tp::rpc::WorkerResponse, WorkerError> {
    let device_index = self.device_index;
    if self.poisoned.load(Ordering::Acquire) {
        return Err(WorkerError::Poisoned { device_index });
    }

    let (done_tx, done_rx) = oneshot::channel();
    let job = Job::NcclInit {
        cfg,
        comm_id_hex,
        reply: done_tx,
    };
    if self.tx.send(job).is_err() {
        return Err(WorkerError::Gone { device_index });
    }

    match done_rx.await {
        Ok(response) => Ok(response),
        Err(_) => Err(WorkerError::Gone { device_index }),
    }
}
/// Run the NCCL all_reduce sanity check on the leader's rank 0.
///
/// Same reply shape as [`Self::nccl_init`]: the worker's
/// `WorkerResponse` is returned as-is, so a failed check arrives as
/// `Ok(response)`. Not feature-gated; on no-cuda builds the reply is
/// an error response.
///
/// # Errors
/// - `WorkerError::Poisoned` if the poisoned flag is set (nothing is
///   sent).
/// - `WorkerError::Gone` if the job channel or the reply channel is
///   closed.
pub async fn nccl_sanity(
    &self,
) -> Result<crate::harness::tp::rpc::WorkerResponse, WorkerError> {
    let device_index = self.device_index;
    if self.poisoned.load(Ordering::Acquire) {
        return Err(WorkerError::Poisoned { device_index });
    }

    let (done_tx, done_rx) = oneshot::channel();
    let job = Job::NcclSanity { reply: done_tx };
    if self.tx.send(job).is_err() {
        return Err(WorkerError::Gone { device_index });
    }

    match done_rx.await {
        Ok(response) => Ok(response),
        Err(_) => Err(WorkerError::Gone { device_index }),
    }
}
/// Load the leader's tensor-parallel shard on the worker thread.
///
/// Everything the worker needs travels in one `Job::TpLoadShard`:
/// model id, the config JSON text, the safetensors paths, dtype,
/// optional quantisation tag and the world size. No NCCL `Comm` is
/// passed in — per the job's contract the dispatch handler uses the
/// communicator already held in the worker's own NCCL state.
///
/// # Returns
/// The `TpHandle` the worker assigned to the loaded shard.
///
/// # Errors
/// - `WorkerError::Poisoned` if the poisoned flag is set (nothing is
///   sent).
/// - `WorkerError::Gone` if the job channel or the reply channel is
///   closed.
/// - `WorkerError::Job` if the load failed on the worker.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub async fn tp_load_shard(
    &self,
    model_id: String,
    config_json: String,
    safetensors_paths: Vec<std::path::PathBuf>,
    dtype: candle_core::DType,
    quant: Option<String>,
    world_size: u32,
) -> Result<TpHandle, WorkerError> {
    let device_index = self.device_index;
    if self.poisoned.load(Ordering::Acquire) {
        return Err(WorkerError::Poisoned { device_index });
    }

    let (done_tx, done_rx) = oneshot::channel();
    let job = Job::TpLoadShard {
        model_id,
        config_json,
        safetensors_paths,
        dtype,
        quant,
        world_size,
        reply: done_tx,
    };
    if self.tx.send(job).is_err() {
        return Err(WorkerError::Gone { device_index });
    }

    let outcome = done_rx
        .await
        .map_err(|_| WorkerError::Gone { device_index })?;
    outcome.map_err(WorkerError::from)
}
/// Ask the worker to drop the TP model behind `handle` on the worker
/// thread.
///
/// No poisoned check and a unit reply, like [`Self::drop_arch`].
///
/// # Errors
/// `WorkerError::Gone` if the job channel or the reply channel is
/// closed.
#[cfg(feature = "cuda")]
pub async fn drop_tp(&self, handle: TpHandle) -> Result<(), WorkerError> {
    let device_index = self.device_index;

    let (done_tx, done_rx) = oneshot::channel();
    let job = Job::DropTp {
        handle,
        reply: done_tx,
    };
    if self.tx.send(job).is_err() {
        return Err(WorkerError::Gone { device_index });
    }

    done_rx
        .await
        .map_err(|_| WorkerError::Gone { device_index })
}
/// Reset the leader's KV cache for the TP model behind `handle`.
///
/// TP counterpart of [`Self::clear_kv_cache`].
///
/// # Errors
/// - `WorkerError::Poisoned` if the poisoned flag is set (nothing is
///   sent).
/// - `WorkerError::Gone` if the job channel or the reply channel is
///   closed.
/// - `WorkerError::Job` if the worker replied with an error.
#[cfg(feature = "cuda")]
pub async fn tp_clear_kv(&self, handle: TpHandle) -> Result<(), WorkerError> {
    let device_index = self.device_index;
    if self.poisoned.load(Ordering::Acquire) {
        return Err(WorkerError::Poisoned { device_index });
    }

    let (done_tx, done_rx) = oneshot::channel();
    let job = Job::TpClearKv {
        handle,
        reply: done_tx,
    };
    if self.tx.send(job).is_err() {
        return Err(WorkerError::Gone { device_index });
    }

    let outcome = done_rx
        .await
        .map_err(|_| WorkerError::Gone { device_index })?;
    outcome.map_err(WorkerError::from)
}
/// Capture the leader's TP cache state as a prefix snapshot (#11).
///
/// Unlike [`Self::snapshot_kv`], the id is chosen by the caller: the
/// pool mints `snapshot_id` and the worker stores the snapshot under
/// it.
///
/// # Returns
/// The byte size of the leader shard's snapshot.
///
/// # Errors
/// - `WorkerError::Poisoned` if the poisoned flag is set (nothing is
///   sent).
/// - `WorkerError::Gone` if the job channel or the reply channel is
///   closed.
/// - `WorkerError::Job` if the worker replied with an error.
#[cfg(feature = "cuda")]
pub async fn tp_snapshot_kv(
    &self,
    handle: TpHandle,
    snapshot_id: u64,
) -> Result<u64, WorkerError> {
    let device_index = self.device_index;
    if self.poisoned.load(Ordering::Acquire) {
        return Err(WorkerError::Poisoned { device_index });
    }

    let (done_tx, done_rx) = oneshot::channel();
    let job = Job::TpSnapshotKv {
        handle,
        snapshot_id,
        reply: done_tx,
    };
    if self.tx.send(job).is_err() {
        return Err(WorkerError::Gone { device_index });
    }

    let outcome = done_rx
        .await
        .map_err(|_| WorkerError::Gone { device_index })?;
    outcome.map_err(WorkerError::from)
}
/// Swap the leader's live TP cache state for the snapshot stored under
/// `snapshot_id`.
///
/// Used in place of [`Self::tp_clear_kv`] on a prefix-cache hit.
///
/// # Errors
/// - `WorkerError::Poisoned` if the poisoned flag is set (nothing is
///   sent).
/// - `WorkerError::Gone` if the job channel or the reply channel is
///   closed.
/// - `WorkerError::Job` if the worker replied with an error.
#[cfg(feature = "cuda")]
pub async fn tp_restore_kv(
    &self,
    handle: TpHandle,
    snapshot_id: u64,
) -> Result<(), WorkerError> {
    let device_index = self.device_index;
    if self.poisoned.load(Ordering::Acquire) {
        return Err(WorkerError::Poisoned { device_index });
    }

    let (done_tx, done_rx) = oneshot::channel();
    let job = Job::TpRestoreKv {
        handle,
        snapshot_id,
        reply: done_tx,
    };
    if self.tx.send(job).is_err() {
        return Err(WorkerError::Gone { device_index });
    }

    let outcome = done_rx
        .await
        .map_err(|_| WorkerError::Gone { device_index })?;
    outcome.map_err(WorkerError::from)
}
/// Evict one stored leader TP snapshot.
///
/// Poison-tolerant: the poisoned flag is not checked, and the worker
/// answers with a plain unit reply — the same shape as
/// [`Self::drop_kv_snapshot`].
///
/// # Errors
///
/// [`WorkerError::Gone`] when the job cannot be queued or the reply
/// channel closes before an acknowledgement arrives.
#[cfg(feature = "cuda")]
pub async fn tp_drop_kv_snapshot(
    &self,
    handle: TpHandle,
    snapshot_id: u64,
) -> Result<(), WorkerError> {
    let (reply, ack) = oneshot::channel();
    let job = Job::TpDropKvSnapshot {
        handle,
        snapshot_id,
        reply,
    };
    let queued = self.tx.send(job).is_ok();
    if !queued {
        return Err(WorkerError::Gone {
            device_index: self.device_index,
        });
    }
    ack.await.map_err(|_| WorkerError::Gone {
        device_index: self.device_index,
    })
}
/// Execute a single TP forward step on the leader's shard and return
/// the CPU-side logits as a `Vec<f32>`, ready for sampling.
///
/// Fan-out to, and draining of, the subprocess workers is the caller's
/// job and must happen concurrently with this call.
///
/// # Errors
///
/// * [`WorkerError::Poisoned`] when the poisoned flag is already set.
/// * [`WorkerError::Gone`] when the job cannot be queued or the reply
///   channel closes without an answer.
/// * The worker's own job error, converted via `WorkerError::from`.
#[cfg(feature = "cuda")]
pub async fn tp_forward_logits(
    &self,
    handle: TpHandle,
    tokens: Vec<u32>,
    offset: usize,
) -> Result<Vec<f32>, WorkerError> {
    let device_index = self.device_index;
    if self.poisoned.load(Ordering::Acquire) {
        return Err(WorkerError::Poisoned { device_index });
    }
    let (reply, logits_rx) = oneshot::channel();
    let job = Job::TpForwardLogits {
        handle,
        tokens,
        offset,
        reply,
    };
    if self.tx.send(job).is_err() {
        return Err(WorkerError::Gone { device_index });
    }
    match logits_rx.await {
        Err(_) => Err(WorkerError::Gone { device_index }),
        Ok(logits) => logits.map_err(WorkerError::from),
    }
}
/// TP leader forward for a request that carries images (single-shot
/// vision prefill).
///
/// Queues `Job::TpForwardLogitsWithImages` on the worker thread, whose
/// handler preprocesses, encodes, splices and forwards, then replies
/// with CPU-side `[vocab]` logits. The `WorkerPool` fans the matching
/// `GenerateStepWithImages` out to the subprocess ranks so the
/// row-parallel collectives complete.
///
/// # Errors
///
/// * [`WorkerError::Poisoned`] when the poisoned flag is already set.
/// * [`WorkerError::Gone`] when the job cannot be queued or the reply
///   channel closes without an answer.
/// * The worker's own job error, converted via `WorkerError::from`.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub async fn tp_forward_logits_with_images(
    &self,
    handle: TpHandle,
    tokens: Vec<u32>,
    offset: usize,
    image_token_id: u32,
    image_data_uris: Vec<String>,
    chunk_size: usize,
) -> Result<Vec<f32>, WorkerError> {
    let device_index = self.device_index;
    if self.poisoned.load(Ordering::Acquire) {
        return Err(WorkerError::Poisoned { device_index });
    }
    let (reply, logits_rx) = oneshot::channel();
    let job = Job::TpForwardLogitsWithImages {
        handle,
        tokens,
        offset,
        image_token_id,
        image_data_uris,
        chunk_size,
        reply,
    };
    self.tx
        .send(job)
        .map_err(|_| WorkerError::Gone { device_index })?;
    logits_rx
        .await
        .map_err(|_| WorkerError::Gone { device_index })?
        .map_err(WorkerError::from)
}
/// Queue `Job::Shutdown` and join the worker thread.
///
/// Idempotent: the join handle is taken out of its slot on the first
/// call, so a second call finds nothing to join and returns `Ok(())`.
///
/// # Errors
///
/// Returns an error when the joined thread had panicked.
pub fn shutdown(&self) -> anyhow::Result<()> {
    // Best-effort send. A closed channel (thread already exited after
    // an earlier shutdown or a panic) is fine: the join below surfaces
    // the panic, if there was one.
    let _ = self.tx.send(Job::Shutdown);
    let Some(thread) = self.join.lock().unwrap().take() else {
        return Ok(());
    };
    thread
        .join()
        .map(|_| ())
        .map_err(|_| anyhow::anyhow!("worker thread panicked during shutdown"))
}
}
impl Drop for DeviceWorkerHandle {
    /// Signal the worker thread to exit without waiting for it.
    fn drop(&mut self) {
        // Best-effort Shutdown so the thread leaves its loop. There is
        // deliberately no join: Drop may run on a tokio worker thread,
        // and joining a thread still busy with its last job would block
        // the runtime. The thread is left detached for the OS to reap.
        self.tx.send(Job::Shutdown).ok();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Spawn a worker, round-trip one VRAM query, then shut it down.
    #[tokio::test]
    async fn spawn_query_vram_shutdown() {
        let worker = DeviceWorkerHandle::spawn(0).expect("spawn ok");
        // CPU build (the only one CI runs) returns (0, 0) by design;
        // a CUDA build with a real device would return real values.
        let vram = worker.query_vram().await.expect("query ok");
        // No value is asserted: touching both fields only pins the
        // two-field shape of the reply.
        let _ = vram.0;
        let _ = vram.1;
        worker.shutdown().expect("shutdown ok");
    }

    /// Smoke test for a worker spawned on a non-zero device index.
    #[tokio::test]
    async fn thread_is_named_correctly() {
        // The thread name lets `top -H` / pidstat / gdb show
        // `cuda-dev-N` instead of an opaque tokio worker name. The name
        // itself is not inspected here (that would need /proc); this
        // only confirms that spawning and a job round-trip don't crash.
        let worker = DeviceWorkerHandle::spawn(7).expect("spawn ok");
        // One job through the channel proves the thread is alive.
        worker.query_vram().await.expect("query ok");
        worker.shutdown().expect("shutdown ok");
    }

    /// After shutdown the channel is closed, so a submit must fail fast
    /// with `Gone` rather than block.
    #[tokio::test]
    async fn submit_after_shutdown_returns_gone() {
        let worker = DeviceWorkerHandle::spawn(0).expect("spawn ok");
        worker.shutdown().expect("shutdown ok");
        let outcome = worker.query_vram().await;
        if !matches!(outcome, Err(WorkerError::Gone { device_index: 0 })) {
            panic!("expected Gone, got {outcome:?}");
        }
    }

    /// A poisoned handle rejects submits up front with `Poisoned`.
    #[tokio::test]
    async fn poisoned_flag_short_circuits_submit() {
        let worker = DeviceWorkerHandle::spawn(0).expect("spawn ok");
        worker.set_poisoned();
        let outcome = worker.query_vram().await;
        if !matches!(outcome, Err(WorkerError::Poisoned { device_index: 0 })) {
            panic!("expected Poisoned, got {outcome:?}");
        }
        // The channel is still alive; shutdown should still succeed.
        worker.shutdown().expect("shutdown ok");
    }

    /// Stage A5: the EncodeImage job must round-trip through the worker
    /// channel. No real model is loaded into the slab here, so the
    /// dispatch handler answers with its "no model for handle" error —
    /// which is the desired outcome, since it proves the message was
    /// routed through the channel and reached the handler.
    /// Real-weights validation lives in the Stage A7 / Stage B
    /// post-deploy smoke on beast.
    #[tokio::test]
    async fn encode_image_routes_to_dispatch_and_errors_on_unknown_handle() {
        use crate::harness::device_worker::jobs::ArchHandle;

        let worker = DeviceWorkerHandle::spawn(0).expect("spawn ok");
        let unknown_arch = ArchHandle(99_999);
        // (3, 4, 4) fake image — minimal payload, gets reconstructed
        // on the worker before the handler errors out on the unknown
        // ArchHandle lookup.
        let pixels = vec![0.0_f32; 3 * 4 * 4];
        match worker.encode_image(unknown_arch, pixels, 3, 4, 4).await {
            Err(WorkerError::Job(err)) => {
                let rendered = format!("{err:#}");
                assert!(
                    rendered.contains("EncodeImage: no model for handle"),
                    "expected unknown-handle error, got: {rendered}"
                );
            }
            unexpected => panic!("expected Job(Err), got {unexpected:?}"),
        }
        worker.shutdown().expect("shutdown ok");
    }

    /// Queries racing a shutdown must all resolve without panicking.
    #[tokio::test]
    async fn shutdown_drains_pending_jobs() {
        let handle = DeviceWorkerHandle::spawn(0).expect("spawn ok");
        // Launch many concurrent queries; every one should finish even
        // with a Shutdown racing them.
        let in_flight: Vec<_> = (0..16)
            .map(|_| {
                let h = Arc::clone(&handle);
                tokio::spawn(async move { h.query_vram().await })
            })
            .collect();
        // Brief pause so the senders get a chance to send before the
        // shutdown goes out. Not strictly required because the channel
        // is FIFO, but it makes the intent of the test clearer.
        tokio::time::sleep(Duration::from_millis(10)).await;
        handle.shutdown().expect("shutdown ok");
        for task in in_flight {
            // Each query resolves to Ok or Gone; the task never panics.
            let _ = task.await.expect("task did not panic");
        }
    }
}

View File

@@ -0,0 +1 @@
// llama.cpp harness implementation — Phase 11.

View File

@@ -0,0 +1,163 @@
//! mistral.rs harness implementation.
//!
//! Wraps the mistral.rs HTTP API for model lifecycle management
//! and optionally manages the process via systemd.
use anyhow::Result;
use async_trait::async_trait;
use cortex_core::harness::{Harness, HarnessConfig, HarnessHealth, ModelInfo, ModelSpec};
use reqwest::Client;
use serde::Deserialize;
/// Harness that drives a mistral.rs server over its HTTP API and,
/// when a systemd unit is configured, starts and stops it via `systemctl`.
pub struct MistralRsHarness {
/// Base URL of the mistral.rs HTTP API; request paths such as
/// `/health` and `/v1/models` are appended to it.
endpoint: String,
/// Name of the systemd unit used by `start`/`stop`. When `None`,
/// both of those calls return an error.
systemd_unit: Option<String>,
/// HTTP client shared by all requests (built in `new` with a
/// 30-second request timeout).
client: Client,
}
impl MistralRsHarness {
    /// Create a harness that talks to the mistral.rs server at `endpoint`.
    ///
    /// * `endpoint` — base URL of the HTTP API. Trailing `/` characters
    ///   are stripped, because every request URL is formed as
    ///   `{endpoint}/path` and a trailing slash would otherwise yield
    ///   `//path`.
    /// * `systemd_unit` — optional unit name used to start/stop the
    ///   server process.
    ///
    /// The HTTP client is configured with a 30-second request timeout.
    ///
    /// # Panics
    ///
    /// Panics if the HTTP client cannot be built.
    pub fn new(mut endpoint: String, systemd_unit: Option<String>) -> Self {
        // Normalise once here so `format!("{}/health", endpoint)` and
        // friends never produce a double slash.
        let trimmed_len = endpoint.trim_end_matches('/').len();
        endpoint.truncate(trimmed_len);

        let client = Client::builder()
            .timeout(std::time::Duration::from_secs(30))
            .build()
            .expect("failed to build HTTP client");

        Self {
            endpoint,
            systemd_unit,
            client,
        }
    }
}
/// Response from mistral.rs `GET /v1/models`.
///
/// Only the `data` array is deserialized.
#[derive(Debug, Deserialize)]
struct ModelsResponse {
/// One entry per model reported by the server.
data: Vec<ModelEntry>,
}
/// A single element of the `data` array in [`ModelsResponse`].
#[derive(Debug, Deserialize)]
struct ModelEntry {
/// Model identifier as reported by the server.
id: String,
/// Optional status string; `None` when the field is missing from the
/// response (via `#[serde(default)]`).
#[serde(default)]
status: Option<String>,
}
#[async_trait]
impl Harness for MistralRsHarness {
    /// Fixed identifier of this harness: `"mistralrs"`.
    fn name(&self) -> &str {
        "mistralrs"
    }

    /// Start the configured systemd unit and wait for the server's
    /// `/health` endpoint to answer.
    ///
    /// `_config` is currently unused.
    ///
    /// # Errors
    ///
    /// * no systemd unit is configured;
    /// * `systemctl start` cannot be spawned or exits unsuccessfully
    ///   (its stderr is included in the message);
    /// * the health endpoint does not answer within 30 seconds.
    async fn start(&self, _config: &HarnessConfig) -> Result<()> {
        let Some(unit) = &self.systemd_unit else {
            anyhow::bail!("no systemd unit configured for mistralrs harness");
        };
        let output = tokio::process::Command::new("systemctl")
            .args(["start", unit])
            .output()
            .await?;
        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            anyhow::bail!("systemctl start {unit} failed: {stderr}");
        }

        // Poll the health endpoint once per second. Any HTTP response,
        // regardless of status code, counts as "up".
        let url = format!("{}/health", self.endpoint);
        let wait_for_health = async {
            loop {
                tokio::time::sleep(std::time::Duration::from_secs(1)).await;
                if self.client.get(&url).send().await.is_ok() {
                    break;
                }
            }
        };

        // Bound the *whole* wait at 30s. A single probe may otherwise
        // block for the client's own request timeout, which would let a
        // fixed number of probes run far beyond the advertised limit.
        if tokio::time::timeout(std::time::Duration::from_secs(30), wait_for_health)
            .await
            .is_err()
        {
            anyhow::bail!("mistralrs started but health endpoint did not respond within 30s");
        }
        tracing::info!(unit, "mistralrs started and healthy");
        Ok(())
    }

    /// Stop the configured systemd unit.
    ///
    /// # Errors
    ///
    /// Fails when no unit is configured, or when `systemctl stop`
    /// cannot be spawned or exits unsuccessfully.
    async fn stop(&self) -> Result<()> {
        let Some(unit) = &self.systemd_unit else {
            anyhow::bail!("no systemd unit configured for mistralrs harness");
        };
        let output = tokio::process::Command::new("systemctl")
            .args(["stop", unit])
            .output()
            .await?;
        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            anyhow::bail!("systemctl stop {unit} failed: {stderr}");
        }
        Ok(())
    }

    /// Probe `/health` once and report whether the server answered.
    ///
    /// `running` is `true` for any HTTP response (the status code is not
    /// inspected) and `false` when the request itself fails. Uptime is
    /// not tracked, so `uptime_secs` is always `None`.
    async fn health(&self) -> HarnessHealth {
        let url = format!("{}/health", self.endpoint);
        let running = self.client.get(&url).send().await.is_ok();
        HarnessHealth {
            name: "mistralrs".into(),
            running,
            uptime_secs: None,
        }
    }

    /// List the models reported by `GET /v1/models`.
    ///
    /// Entries without a `status` are reported as `"loaded"`. Device
    /// placement and VRAM usage are not available from this endpoint,
    /// so `devices` is empty and `vram_used_mb` is `None`.
    ///
    /// # Errors
    ///
    /// Fails when the request fails, the status is not a success, or
    /// the body cannot be decoded.
    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        let url = format!("{}/v1/models", self.endpoint);
        let resp = self.client.get(&url).send().await?;
        if !resp.status().is_success() {
            anyhow::bail!("GET /v1/models returned {}", resp.status());
        }
        let models_resp: ModelsResponse = resp.json().await?;
        Ok(models_resp
            .data
            .into_iter()
            .map(|m| ModelInfo {
                id: m.id,
                harness: "mistralrs".into(),
                status: m.status.unwrap_or_else(|| "loaded".into()),
                devices: vec![],
                vram_used_mb: None,
            })
            .collect())
    }

    /// Ask the server to (re)load `spec.model_id` via
    /// `POST /v1/models/reload`.
    ///
    /// # Errors
    ///
    /// Fails when the request fails or the response status is not a
    /// success; the message carries the HTTP status and response body.
    async fn load_model(&self, spec: &ModelSpec) -> Result<()> {
        let url = format!("{}/v1/models/reload", self.endpoint);
        let resp = self
            .client
            .post(&url)
            .json(&serde_json::json!({ "model_id": spec.model_id }))
            .send()
            .await?;
        // Capture the status first: reading the body consumes `resp`.
        let status = resp.status();
        if !status.is_success() {
            let body = resp.text().await.unwrap_or_default();
            anyhow::bail!("POST /v1/models/reload failed ({status}): {body}");
        }
        Ok(())
    }

    /// Ask the server to unload `model_id` via `POST /v1/models/unload`.
    ///
    /// # Errors
    ///
    /// Fails when the request fails or the response status is not a
    /// success; the message carries the HTTP status and response body.
    async fn unload_model(&self, model_id: &str) -> Result<()> {
        let url = format!("{}/v1/models/unload", self.endpoint);
        let resp = self
            .client
            .post(&url)
            .json(&serde_json::json!({ "model_id": model_id }))
            .send()
            .await?;
        // Capture the status first: reading the body consumes `resp`.
        let status = resp.status();
        if !status.is_success() {
            let body = resp.text().await.unwrap_or_default();
            anyhow::bail!("POST /v1/models/unload failed ({status}): {body}");
        }
        Ok(())
    }

    /// Endpoint to send inference requests to for `_model_id`.
    ///
    /// Always `Some(base URL)`: mistral.rs routes internally by the
    /// model name in the request body, so the model id is not needed to
    /// pick an endpoint.
    async fn inference_endpoint(&self, _model_id: &str) -> Option<String> {
        Some(self.endpoint.clone())
    }
}

Some files were not shown because too many files have changed in this diff Show More